From 7d5b820642580e334c3e5118c2c3992581e5aeb6 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 13 Jun 2022 04:26:01 +0200 Subject: [PATCH 001/529] [quic] first TLS changes --- mitmproxy/addons/tlsconfig.py | 93 +++++++++++++++++++++---- mitmproxy/certs.py | 6 +- mitmproxy/proxy/layers/http/__init__.py | 9 ++- mitmproxy/proxy/layers/http/_http3.py | 31 +++++++++ mitmproxy/proxy/layers/quic.py | 83 ++++++++++++++++++++++ setup.py | 1 + 6 files changed, 207 insertions(+), 16 deletions(-) create mode 100644 mitmproxy/proxy/layers/http/_http3.py create mode 100644 mitmproxy/proxy/layers/quic.py diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 0cb7492e28..61245cd8b5 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -1,15 +1,18 @@ import ipaddress import os from pathlib import Path +import ssl from typing import Any, Optional, TypedDict +from aioquic.quic.configuration import QuicConfiguration +from aioquic.tls import CipherSuite from OpenSSL import SSL from mitmproxy import certs, ctx, exceptions, connection, tls from mitmproxy.net import tls as net_tls from mitmproxy.options import CONF_BASENAME from mitmproxy.proxy import context from mitmproxy.proxy.layers import modes -from mitmproxy.proxy.layers import tls as proxy_tls +from mitmproxy.proxy.layers import tls as proxy_tls, quic # We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default. # https://ssl-config.mozilla.org/#config=old @@ -196,6 +199,18 @@ def tls_start_client(self, tls_start: tls.TlsData) -> None: ) tls_start.ssl_conn.set_accept_state() + def _get_client_cert(self, server: connection.Server) -> Optional[str]: + if ctx.options.client_certs: + client_certs = os.path.expanduser(ctx.options.client_certs) + if os.path.isfile(client_certs): + return client_certs + else: + server_name: str = server.sni or server.address[0] + p = os.path.join(client_certs, f"{server_name}.pem") + if os.path.isfile(p): + return p + return None + def tls_start_server(self, tls_start: tls.TlsData) -> None: """Establish TLS between proxy and server.""" if tls_start.ssl_conn is not None: @@ -240,17 +255,6 @@ def tls_start_server(self, tls_start: tls.TlsData) -> None: # don't assign to client.cipher_list, doesn't need to be stored. cipher_list = server.cipher_list or DEFAULT_CIPHERS - client_cert: Optional[str] = None - if ctx.options.client_certs: - client_certs = os.path.expanduser(ctx.options.client_certs) - if os.path.isfile(client_certs): - client_cert = client_certs - else: - server_name: str = server.sni or server.address[0] - p = os.path.join(client_certs, f"{server_name}.pem") - if os.path.isfile(p): - client_cert = p - ssl_ctx = net_tls.create_proxy_server_context( min_version=net_tls.Version[ctx.options.tls_version_client_min], max_version=net_tls.Version[ctx.options.tls_version_client_max], @@ -258,7 +262,7 @@ def tls_start_server(self, tls_start: tls.TlsData) -> None: verify=verify, ca_path=ctx.options.ssl_verify_upstream_trusted_confdir, ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca, - client_cert=client_cert, + client_cert=self._get_client_cert(server), ) tls_start.ssl_conn = SSL.Connection(ssl_ctx) @@ -293,6 +297,69 @@ def tls_start_server(self, tls_start: tls.TlsData) -> None: tls_start.ssl_conn.set_connect_state() + def quic_tls_start_client(self, tls_start: quic.QuicTlsData) -> None: + """Establish QUIC between client and proxy.""" + if tls_start.settings is not None: + return # a user addon has already provided the settings. + tls_start.settings = quic.QuicTlsSettings() + + assert isinstance(tls_start.conn, connection.Client) + + client: connection.Client = tls_start.conn + server: connection.Server = tls_start.context.server + + entry = self.get_cert(tls_start.context) + tls_start.settings.certificate = entry.cert + tls_start.settings.certificate_private_key = entry.privatekey + tls_start.settings.certificate_chain = entry.chain_certs + + if not client.cipher_list and ctx.options.ciphers_client: + client.cipher_list = ctx.options.ciphers_client.split(":") + if client.cipher_list: + tls_start.settings.cipher_suites = [ + CipherSuite(cipher) for cipher in client.cipher_list + ] + if ctx.options.add_upstream_certs_to_client_chain: + tls_start.settings.certificate_chain.extend(server.certificate_list) + + def quic_tls_start_server(self, tls_start: quic.QuicTlsData) -> None: + """Establish QUIC between proxy and server.""" + if tls_start.settings is not None: + return # a user addon has already provided the settings. + tls_start.settings = quic.QuicTlsSettings() + + assert isinstance(tls_start.conn, connection.Server) + + client: connection.Client = tls_start.context.client + server: connection.Server = tls_start.conn + assert server.address + + if ctx.options.ssl_insecure: + tls_start.settings.verify_mode = ssl.CERT_NONE + + if server.sni is None: + server.sni = client.sni or server.address[0] + + if not server.alpn_offers: + server.alpn_offers = client.alpn_offers + + if not server.cipher_list and ctx.options.ciphers_server: + server.cipher_list = ctx.options.ciphers_server.split(":") + tls_start.settings.cipher_suites = [ + CipherSuite(cipher) for cipher in server.cipher_list + ] + + client_cert = self._get_client_cert(server) + if client_cert: + config = QuicConfiguration() + config.load_cert_chain(client_cert) + tls_start.settings.certificate = config.certificate + tls_start.settings.certificate_private_key = config.private_key + tls_start.settings.certificate_chain = config.certificate_chain + + tls_start.settings.ca_path = ctx.options.ssl_verify_upstream_trusted_confdir + tls_start.settings.ca_file = ctx.options.ssl_verify_upstream_trusted_ca + def running(self): # FIXME: We have a weird bug where the contract for configure is not followed and it is never called with # confdir or command_history as updated. diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index dc3787a8b7..ab26b4b8f5 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import NewType, Optional, Union +from aioquic.tls import load_pem_x509_certificates from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec @@ -283,6 +284,7 @@ class CertStoreEntry: cert: Cert privatekey: rsa.RSAPrivateKey chain_file: Optional[Path] + chain_certs: Optional[list[Cert]] TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs) @@ -311,6 +313,7 @@ def __init__( self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file + self.default_chain_certs = load_pem_x509_certificates(self.default_chain_file.read_bytes()) if self.default_chain_file else None self.dhparams = dhparams self.certs = {} self.expire_queue = [] @@ -453,7 +456,7 @@ def add_cert_file( except ValueError: key = self.default_privatekey - self.add_cert(CertStoreEntry(cert, key, path), spec) + self.add_cert(CertStoreEntry(cert, key, path, [cert]), spec) def add_cert(self, entry: CertStoreEntry, *names: str) -> None: """ @@ -516,6 +519,7 @@ def get_cert( ), privatekey=self.default_privatekey, chain_file=self.default_chain_file, + chain_certs=self.default_chain_certs, ) self.certs[(commonname, tuple(sans))] = entry self.expire(entry) diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 2af8aefb76..f5327722f0 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -41,6 +41,7 @@ ) from ._http1 import Http1Client, Http1Connection, Http1Server from ._http2 import Http2Client, Http2Server +from ._http3 import Http3Client, Http3Server from ...context import Context @@ -821,7 +822,9 @@ def __init__(self, context: Context, mode: HTTPMode): self.command_sources = {} http_conn: HttpConnection - if self.context.client.alpn == b"h2": + if self.context.client.alpn == b"h3": + http_conn = Http3Server(context.fork()) + elif self.context.client.alpn == b"h2": http_conn = Http2Server(context.fork()) else: http_conn = Http1Server(context.fork()) @@ -1060,7 +1063,9 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: else: err = yield commands.OpenConnection(self.context.server) if not err: - if self.context.server.alpn == b"h2": + if self.context.server.alpn == b"h3": + self.child_layer = Http3Client(self.context) + elif self.context.server.alpn == b"h2": self.child_layer = Http2Client(self.context) else: self.child_layer = Http1Client(self.context) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py new file mode 100644 index 0000000000..1f57add6eb --- /dev/null +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -0,0 +1,31 @@ +from aioquic.h3.connection import H3Connection +from mitmproxy.connection import Connection +from ._base import HttpConnection +from ..quic import QuicLayer +from ...context import Context + + +class Http3Connection(HttpConnection): + h3_conn: H3Connection + + def __init__(self, context: Context, conn: Connection): + super().__init__(context, conn) + quic = context.layers[0] + assert isinstance(quic, QuicLayer) + self.h3_conn = H3Connection(quic.conn) + + +class Http3Server(Http3Connection): + def __init__(self, context: Context): + super().__init__(context, context.client) + + +class Http3Client(Http3Connection): + def __init__(self, context: Context): + super().__init__(context, context.server) + + +__all__ = [ + "Http3Client", + "Http3Server", +] diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py new file mode 100644 index 0000000000..b938c44916 --- /dev/null +++ b/mitmproxy/proxy/layers/quic.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass +from ssl import VerifyMode +from typing import List, Optional, Union + +from aioquic.tls import CipherSuite +from cryptography import x509 +from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa +from mitmproxy.proxy import layer +from mitmproxy.proxy.commands import StartHook +from mitmproxy.tls import TlsData + + +@dataclass +class QuicTlsSettings: + """ + Settings necessary to establish QUIC's TLS context. + """ + + certificate: Optional[x509.Certificate] = None + """The certificate to use for the connection.""" + certificate_chain: List[x509.Certificate] = [] + """An optional list of additional certificates to send to the peer.""" + certificate_private_key: Optional[ + Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey] + ] = None + """The certificate's private key.""" + cipher_suites: Optional[List[CipherSuite]] = None + """An optional list of allowed/advertised protocols.""" + ca_path: Optional[str] = None + """An optional path to a directory that contains the necessary information to verify the peer certificate.""" + ca_file: Optional[str] = None + """An optional path to a PEM file that will be used to verify the peer certificate.""" + verify_mode: Optional[VerifyMode] = None + """An optional flag that specifies how/if the peer's certificate should be validated.""" + + +@dataclass +class QuicTlsData(TlsData): + """ + Event data for `quic_tls_start_client` and `quic_tls_start_server` event hooks. + """ + + settings: Optional[QuicTlsSettings] = None + """ + The associated `QuicTlsSettings` object. + This will be set by an addon in the `quic_tls_start_*` event hooks. + """ + + +@dataclass +class QuicTlsStartClientHook(StartHook): + """ + TLS negotation between mitmproxy and a client over QUIC is about to start. + + An addon is expected to initialize at least data.certificate and data.certificate_private_key. + (by default, this is done by `mitmproxy.addons.tlsconfig`) + """ + + data: QuicTlsData + + +@dataclass +class QuicTlsStartServerHook(StartHook): + """ + TLS negotation between mitmproxy and a server over QUIC is about to start. + + An addon is expected to initialize at least data.certificate and data.certificate_private_key. + (by default, this is done by `mitmproxy.addons.tlsconfig`) + """ + + data: QuicTlsData + + +class QuicLayer(layer.Layer): + pass + + +class QuicServerLayer(layer.Layer): + pass + + +class QuicClientLayer(layer.Layer): + pass diff --git a/setup.py b/setup.py index 233228273f..575659a277 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ # https://packaging.python.org/en/latest/discussions/install-requires-vs-requirements/#install-requires # It is not considered best practice to use install_requires to pin dependencies to specific versions. install_requires=[ + "aioquic>=0.9.20", "asgiref>=3.2.10,<3.6", "blinker>=1.4, <1.5", "Brotli>=1.0,<1.1", From b4f0e28dc09e767d366457d4b2b4926c406cdb7a Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 13 Jun 2022 04:46:02 +0200 Subject: [PATCH 002/529] [quic] user proper cert type --- mitmproxy/addons/tlsconfig.py | 12 +++++++++--- mitmproxy/certs.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 61245cd8b5..631ba6ad87 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -6,6 +6,8 @@ from aioquic.quic.configuration import QuicConfiguration from aioquic.tls import CipherSuite +from cryptography import x509 +from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa from OpenSSL import SSL from mitmproxy import certs, ctx, exceptions, connection, tls from mitmproxy.net import tls as net_tls @@ -205,6 +207,7 @@ def _get_client_cert(self, server: connection.Server) -> Optional[str]: if os.path.isfile(client_certs): return client_certs else: + assert server.address server_name: str = server.sni or server.address[0] p = os.path.join(client_certs, f"{server_name}.pem") if os.path.isfile(p): @@ -309,9 +312,9 @@ def quic_tls_start_client(self, tls_start: quic.QuicTlsData) -> None: server: connection.Server = tls_start.context.server entry = self.get_cert(tls_start.context) - tls_start.settings.certificate = entry.cert + tls_start.settings.certificate = entry.cert._cert tls_start.settings.certificate_private_key = entry.privatekey - tls_start.settings.certificate_chain = entry.chain_certs + tls_start.settings.certificate_chain = [cert._cert for cert in entry.chain_certs] if not client.cipher_list and ctx.options.ciphers_client: client.cipher_list = ctx.options.ciphers_client.split(":") @@ -353,8 +356,11 @@ def quic_tls_start_server(self, tls_start: quic.QuicTlsData) -> None: if client_cert: config = QuicConfiguration() config.load_cert_chain(client_cert) + assert isinstance(config.certificate, x509.Certificate) tls_start.settings.certificate = config.certificate - tls_start.settings.certificate_private_key = config.private_key + if config.private_key: + assert isinstance(config.private_key, (dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey)) + tls_start.settings.certificate_private_key = config.private_key tls_start.settings.certificate_chain = config.certificate_chain tls_start.settings.ca_path = ctx.options.ssl_verify_upstream_trusted_confdir diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index ab26b4b8f5..0a84f88848 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -284,7 +284,7 @@ class CertStoreEntry: cert: Cert privatekey: rsa.RSAPrivateKey chain_file: Optional[Path] - chain_certs: Optional[list[Cert]] + chain_certs: list[Cert] TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs) @@ -313,7 +313,16 @@ def __init__( self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file - self.default_chain_certs = load_pem_x509_certificates(self.default_chain_file.read_bytes()) if self.default_chain_file else None + self.default_chain_certs = ( + [ + Cert(cert) + for cert in load_pem_x509_certificates( + self.default_chain_file.read_bytes() + ) + ] + if self.default_chain_file + else [] + ) self.dhparams = dhparams self.certs = {} self.expire_queue = [] From 568d03c600a9df429e410a6f0ab5d702304de151 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 13 Jun 2022 18:14:05 +0200 Subject: [PATCH 003/529] [quic] changes to proxyserver --- mitmproxy/addons/proxyserver.py | 142 ++++++++++++++++++++++++----- mitmproxy/addons/tlsconfig.py | 2 +- mitmproxy/proxy/layers/__init__.py | 3 + mitmproxy/proxy/layers/quic.py | 8 +- 4 files changed, 129 insertions(+), 26 deletions(-) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 70ae28855d..2d20a23d7e 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -3,8 +3,15 @@ import ipaddress import re import struct -from typing import Optional +from typing import Callable, Optional +from aioquic.buffer import Buffer as QuicBuffer +from aioquic.quic.packet import ( + PACKET_TYPE_INITIAL, + QuicProtocolVersion, + encode_quic_version_negotiation, + pull_quic_header, +) from mitmproxy import ( command, ctx, @@ -21,8 +28,8 @@ from mitmproxy.connection import Address from mitmproxy.flow import Flow from mitmproxy.net import udp -from mitmproxy.proxy import commands, events, layers, server_hooks -from mitmproxy.proxy import server +from mitmproxy.proxy import commands, events, layer, layers, server, server_hooks +from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.tcp import TcpMessageInjected from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected from mitmproxy.utils import asyncio_utils, human @@ -62,6 +69,7 @@ class Proxyserver: tcp_server: Optional[base_events.Server] dns_server: Optional[udp.UdpServer] + quic_server: Optional[udp.UdpServer] connect_addr: Optional[Address] listen_port: int dns_reverse_addr: Optional[tuple[str, int]] @@ -74,6 +82,7 @@ def __init__(self): self._lock = asyncio.Lock() self.tcp_server = None self.dns_server = None + self.quic_server = None self.connect_addr = None self.dns_reverse_addr = None self.is_running = False @@ -99,6 +108,14 @@ def _server_desc(self): self.options.dns_listen_port, transparent=self.options.dns_mode == "transparent", ) + yield "QUIC", self.quic_server, lambda x: setattr( + self, "quic_server", x + ), ctx.options.quic_server, lambda: udp.start_server( + self.handle_quic_datagram, + self.options.listen_host or "127.0.0.1", + self.options.listen_port, + transparent=self.options.mode == "transparent", + ) @property def running_servers(self): @@ -196,6 +213,15 @@ def load(self, loader): transparent: transparent mode """, ) + loader.add_option( + "quic_server", bool, False, """Start a QUIC server. Disabled by default.""" + ) + loader.add_option( + "quic_connection_id_length", + int, + 8, + """The length in bytes of local QUIC connection IDs.""", + ) async def running(self): self.master = ctx.master @@ -261,6 +287,7 @@ def configure(self, updated): "dns_mode", "dns_listen_host", "dns_listen_port", + "quic_server", ] ): asyncio.create_task(self.refresh_server()) @@ -326,34 +353,26 @@ async def handle_tcp_connection( ) await self.handle_connection(connection_id) - def handle_dns_datagram( + def handle_udp_connection( self, transport: asyncio.DatagramTransport, data: bytes, remote_addr: Address, - local_addr: Address, + connection_id: tuple, + layer_cb: Callable[[Context], layer.Layer], + server_addr: Optional[Address] = None, + timeout: Optional[int] = None, ) -> None: - try: - dns_id = struct.unpack_from("!H", data, 0) - except struct.error: - ctx.log.info( - f"Invalid DNS datagram received from {human.format_address(remote_addr)}." - ) - return - connection_id = ("udp", dns_id, remote_addr, local_addr) if connection_id not in self._connections: reader = udp.DatagramReader() writer = udp.DatagramWriter(transport, remote_addr, reader) handler = ProxyConnectionHandler( - self.master, reader, writer, self.options, 20 - ) - handler.layer = layers.DNSLayer(handler.layer.context) - handler.layer.context.server.address = ( - local_addr - if self.options.dns_mode == "transparent" - else self.dns_reverse_addr + self.master, reader, writer, self.options, timeout ) - handler.layer.context.server.transport_protocol = "udp" + handler.layer = layer_cb(handler.layer.context) + if server_addr is not None: + handler.layer.context.server.address = server_addr + handler.layer.context.server.transport_protocol = "udp" self._connections[connection_id] = handler asyncio.create_task(self.handle_connection(connection_id)) else: @@ -363,6 +382,87 @@ def handle_dns_datagram( reader = client_reader reader.feed_data(data, remote_addr) + def handle_dns_datagram( + self, + transport: asyncio.DatagramTransport, + data: bytes, + remote_addr: Address, + local_addr: Address, + ) -> None: + try: + dns_id = struct.unpack_from("!H", data, 0) + except struct.error: + ctx.log.info( + f"Invalid DNS datagram received from {human.format_address(remote_addr)}." + ) + return + self.handle_udp_connection( + transport=transport, + date=data, + remote_addr=remote_addr, + server_addr=( + local_addr + if self.options.dns_mode == "transparent" + else self.dns_reverse_addr + ), + connection_id=("udp", dns_id, remote_addr, local_addr), + layer_cb=layers.DNSLayer, + timeout=20, + ) + + def handle_quic_datagram( + self, + transport: asyncio.DatagramTransport, + data: bytes, + remote_addr: Address, + local_addr: Address, + ) -> None: + # largely taken from aioquic's own asyncio server code + buffer = QuicBuffer(data=data) + try: + header = pull_quic_header( + buffer, host_cid_length=self.options.quic_connection_id_length + ) + except ValueError: + ctx.log.info( + f"Invalid QUIC datagram received from {human.format_address(remote_addr)}." + ) + return + + # negotiate version, support all versions known to aioquic + supported_versions = ( + version.value + for version in QuicProtocolVersion + if version is not QuicProtocolVersion.NEGOTIATION + ) + if header.version is not None and header.version not in supported_versions: + transport.sendto( + encode_quic_version_negotiation( + source_cid=header.destination_cid, + destination_cid=header.source_cid, + supported_versions=supported_versions, + ), + remote_addr, + ) + return + + # create or resume the connection + connection_id = ("quic", header.destination_cid) + if connection_id not in self._connections: + if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL: + ctx.log.info( + f"QUIC packet received from {human.format_address(remote_addr)} with an unknown connection id." + ) + return + self.handle_udp_connection( + transport=transport, + date=data, + remote_addr=remote_addr, + server_addr=local_addr if self.options.mode == "transparent" else None, + connection_id=connection_id, + layer_cb=layers.ServerQuicLayer, + ) + def inject_event(self, event: events.MessageInjected): connection_id = ( "tcp", diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 631ba6ad87..34ca048c16 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -323,7 +323,7 @@ def quic_tls_start_client(self, tls_start: quic.QuicTlsData) -> None: CipherSuite(cipher) for cipher in client.cipher_list ] if ctx.options.add_upstream_certs_to_client_chain: - tls_start.settings.certificate_chain.extend(server.certificate_list) + tls_start.settings.certificate_chain.extend(cert._cert for cert in server.certificate_list) def quic_tls_start_server(self, tls_start: quic.QuicTlsData) -> None: """Establish QUIC between proxy and server.""" diff --git a/mitmproxy/proxy/layers/__init__.py b/mitmproxy/proxy/layers/__init__.py index 55553b258c..7746ed9657 100644 --- a/mitmproxy/proxy/layers/__init__.py +++ b/mitmproxy/proxy/layers/__init__.py @@ -1,6 +1,7 @@ from . import modes from .dns import DNSLayer from .http import HttpLayer +from .quic import ClientQuicLayer, ServerQuicLayer from .tcp import TCPLayer from .tls import ClientTLSLayer, ServerTLSLayer from .websocket import WebsocketLayer @@ -9,6 +10,8 @@ "modes", "DNSLayer", "HttpLayer", + "ClientQuicLayer", + "ServerQuicLayer", "TCPLayer", "ClientTLSLayer", "ServerTLSLayer", diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index b938c44916..829ef56929 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -52,7 +52,7 @@ class QuicTlsStartClientHook(StartHook): """ TLS negotation between mitmproxy and a client over QUIC is about to start. - An addon is expected to initialize at least data.certificate and data.certificate_private_key. + An addon is expected to initialize data.settings. (by default, this is done by `mitmproxy.addons.tlsconfig`) """ @@ -64,7 +64,7 @@ class QuicTlsStartServerHook(StartHook): """ TLS negotation between mitmproxy and a server over QUIC is about to start. - An addon is expected to initialize at least data.certificate and data.certificate_private_key. + An addon is expected to initialize data.settings. (by default, this is done by `mitmproxy.addons.tlsconfig`) """ @@ -75,9 +75,9 @@ class QuicLayer(layer.Layer): pass -class QuicServerLayer(layer.Layer): +class ServerQuicLayer(QuicLayer): pass -class QuicClientLayer(layer.Layer): +class ClientQuicLayer(QuicLayer): pass From 5926c45aa51316e0d9031eeeaaebcaa22d1055be Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Tue, 14 Jun 2022 16:16:46 +0200 Subject: [PATCH 004/529] [quic] replicate DestinationKnown in proxyserver --- mitmproxy/addons/proxyserver.py | 35 +++++++++---- mitmproxy/proxy/layers/http/__init__.py | 4 +- mitmproxy/proxy/layers/quic.py | 67 +++++++++++++++++++++++-- 3 files changed, 90 insertions(+), 16 deletions(-) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 2d20a23d7e..2e016eb209 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -27,7 +27,7 @@ ) from mitmproxy.connection import Address from mitmproxy.flow import Flow -from mitmproxy.net import udp +from mitmproxy.net import server_spec, udp from mitmproxy.proxy import commands, events, layer, layers, server, server_hooks from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.tcp import TcpMessageInjected @@ -359,8 +359,9 @@ def handle_udp_connection( data: bytes, remote_addr: Address, connection_id: tuple, - layer_cb: Callable[[Context], layer.Layer], + layer_factory: Callable[[Context], layer.Layer], server_addr: Optional[Address] = None, + server_sni: Optional[str] = None, timeout: Optional[int] = None, ) -> None: if connection_id not in self._connections: @@ -369,10 +370,10 @@ def handle_udp_connection( handler = ProxyConnectionHandler( self.master, reader, writer, self.options, timeout ) - handler.layer = layer_cb(handler.layer.context) - if server_addr is not None: - handler.layer.context.server.address = server_addr - handler.layer.context.server.transport_protocol = "udp" + handler.layer = layer_factory(handler.layer.context) + handler.layer.context.server.transport_protocol = "udp" + handler.layer.context.server.address = server_addr + handler.layer.context.server.sni = server_sni self._connections[connection_id] = handler asyncio.create_task(self.handle_connection(connection_id)) else: @@ -406,7 +407,7 @@ def handle_dns_datagram( else self.dns_reverse_addr ), connection_id=("udp", dns_id, remote_addr, local_addr), - layer_cb=layers.DNSLayer, + layer_factory=layers.DNSLayer, timeout=20, ) @@ -446,7 +447,7 @@ def handle_quic_datagram( ) return - # create or resume the connection + # check if a new connection is possible connection_id = ("quic", header.destination_cid) if connection_id not in self._connections: if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL: @@ -454,13 +455,27 @@ def handle_quic_datagram( f"QUIC packet received from {human.format_address(remote_addr)} with an unknown connection id." ) return + + # determine the server settings (similar to modes.DestinationKnown) + server_addr: Optional[Address] = None + server_sni: Optional[str] = None + if self.options.mode == "transparent": + server_addr = local_addr + elif self.options.mode.startswith("reverse:"): + spec = server_spec.parse_with_mode(self.options.mode)[1] + server_addr = spec.address + if not self.options.keep_host_header: + server_sni = spec.address[0] + + # create or resume the connection self.handle_udp_connection( transport=transport, date=data, remote_addr=remote_addr, - server_addr=local_addr if self.options.mode == "transparent" else None, + server_addr=server_addr, + server_sni=server_sni, connection_id=connection_id, - layer_cb=layers.ServerQuicLayer, + layer_factory=layers.ClientQuicLayer, ) def inject_event(self, event: events.MessageInjected): diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index f5327722f0..36e0969a9d 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -879,7 +879,9 @@ def _handle_event(self, event: events.Event): elif isinstance(event, events.DataReceived): # The peer has sent data. This can happen with HTTP/2 servers that already send a settings frame. child_layer: HttpConnection - if self.context.server.alpn == b"h2": + if self.context.server.alpn == b"h3": + child_layer = Http3Client(self.context.fork()) + elif self.context.server.alpn == b"h2": child_layer = Http2Client(self.context.fork()) else: child_layer = Http1Client(self.context.fork()) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 829ef56929..14437074f8 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,12 +1,17 @@ +from abc import abstractmethod from dataclasses import dataclass +import io from ssl import VerifyMode from typing import List, Optional, Union -from aioquic.tls import CipherSuite +from aioquic.buffer import Buffer as QuicBuffer +from aioquic.quic.connection import QuicConnection +from aioquic.tls import CipherSuite, HandshakeType from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa from mitmproxy.proxy import layer from mitmproxy.proxy.commands import StartHook +from mitmproxy.proxy.context import Context from mitmproxy.tls import TlsData @@ -19,13 +24,13 @@ class QuicTlsSettings: certificate: Optional[x509.Certificate] = None """The certificate to use for the connection.""" certificate_chain: List[x509.Certificate] = [] - """An optional list of additional certificates to send to the peer.""" + """A list of additional certificates to send to the peer.""" certificate_private_key: Optional[ Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey] ] = None """The certificate's private key.""" cipher_suites: Optional[List[CipherSuite]] = None - """An optional list of allowed/advertised protocols.""" + """An optional list of allowed/advertised cipher suites.""" ca_path: Optional[str] = None """An optional path to a directory that contains the necessary information to verify the peer certificate.""" ca_file: Optional[str] = None @@ -72,12 +77,64 @@ class QuicTlsStartServerHook(StartHook): class QuicLayer(layer.Layer): - pass + conn: QuicConnection + + +self._protocols[connection.host_cid] = protocol + + def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol): + self._protocols[cid] = protocol + + def _connection_id_retired( + self, cid: bytes, protocol: QuicConnectionProtocol + ) -> None: + assert self._protocols[cid] == protocol + del self._protocols[cid] + + def _connection_terminated(self, protocol: QuicConnectionProtocol): + for cid, proto in list(self._protocols.items()): + if proto == protocol: + del self._protocols[cid] class ServerQuicLayer(QuicLayer): + """ + This layer establishes QUIC for a single server connection. + """ pass +@dataclass +class ClientHelloException(Exception): + data: bytes + + class ClientQuicLayer(QuicLayer): - pass + _intercept_client_hello: bool + + """ + This layer establishes QUIC on a single client connection. + """ + + def __init__(self, context: Context) -> None: + super().__init__(context) + + # patch aioquic to intercept the client hello + orig_initialize = self.conn._initialize + def initialize_replacement(peer_cid: bytes) -> None: + try: + return orig_initialize(peer_cid) + finally: + orig_server_handle_hello = self.conn.tls._server_handle_hello + def server_handle_hello_replacement( + input_buf: QuicBuffer, + initial_buf: QuicBuffer, + handshake_buf: QuicBuffer, + onertt_buf: QuicBuffer, + ) -> None: + if self._intercept_client_hello and input_buf.pull_uint8() == HandshakeType.CLIENT_HELLO: + raise ClientHelloException(input_buf.data[:input_buf.tell()]) + else: + orig_server_handle_hello(input_buf, initial_buf, handshake_buf, onertt_buf) + self.conn.tls._server_handle_hello = server_handle_hello_replacement + self.conn._initialize = initialize_replacement From ff42d291147eb444c785842b71aac07ee7bfa46e Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 16 Jun 2022 04:03:28 +0200 Subject: [PATCH 005/529] [quic] connection_id handling --- mitmproxy/addons/proxyserver.py | 40 ++++-- mitmproxy/addons/tlsconfig.py | 6 +- mitmproxy/proxy/layers/quic.py | 211 ++++++++++++++++++++++++-------- 3 files changed, 194 insertions(+), 63 deletions(-) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 2e016eb209..b856ac7daf 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -29,7 +29,6 @@ from mitmproxy.flow import Flow from mitmproxy.net import server_spec, udp from mitmproxy.proxy import commands, events, layer, layers, server, server_hooks -from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.tcp import TcpMessageInjected from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected from mitmproxy.utils import asyncio_utils, human @@ -359,23 +358,26 @@ def handle_udp_connection( data: bytes, remote_addr: Address, connection_id: tuple, - layer_factory: Callable[[Context], layer.Layer], + layer_factory: Callable[[ProxyConnectionHandler], layer.Layer], server_addr: Optional[Address] = None, server_sni: Optional[str] = None, + done_callback: Optional[Callable[[ProxyConnectionHandler]]] = None, timeout: Optional[int] = None, - ) -> None: + ) -> Optional[asyncio.Task[None]]: if connection_id not in self._connections: reader = udp.DatagramReader() writer = udp.DatagramWriter(transport, remote_addr, reader) handler = ProxyConnectionHandler( self.master, reader, writer, self.options, timeout ) - handler.layer = layer_factory(handler.layer.context) + handler.layer = layer_factory(handler) handler.layer.context.server.transport_protocol = "udp" handler.layer.context.server.address = server_addr handler.layer.context.server.sni = server_sni self._connections[connection_id] = handler - asyncio.create_task(self.handle_connection(connection_id)) + task = asyncio.create_task(self.handle_connection(connection_id)) + if done_callback is not None: + task.add_done_callback(lambda _: done_callback(handler)) else: handler = self._connections[connection_id] client_reader = handler.transports[handler.client].reader @@ -407,7 +409,7 @@ def handle_dns_datagram( else self.dns_reverse_addr ), connection_id=("udp", dns_id, remote_addr, local_addr), - layer_factory=layers.DNSLayer, + layer_factory=lambda handler: layers.DNSLayer(handler.layer.context), timeout=20, ) @@ -467,6 +469,25 @@ def handle_quic_datagram( if not self.options.keep_host_header: server_sni = spec.address[0] + # define the callback functions + connection_ids = set([connection_id]) + + def cleanup_connection_ids(handler: ProxyConnectionHandler) -> None: + for connection_id in connection_ids: + if connection_id in self._connections: + del self._connections[connection_id] + + def issue_connection_id(handler: ProxyConnectionHandler, cid: bytes) -> None: + connection_id = ("quic", cid) + assert connection_id not in self._connections + self._connections[connection_id] = handler + connection_ids.add(connection_id) + + def retire_connection_id(handler: ProxyConnectionHandler, cid: bytes) -> None: + connection_id = ("quic", cid) + connection_ids.remove(connection_id) + del self._connections[connection_id] + # create or resume the connection self.handle_udp_connection( transport=transport, @@ -475,7 +496,12 @@ def handle_quic_datagram( server_addr=server_addr, server_sni=server_sni, connection_id=connection_id, - layer_factory=layers.ClientQuicLayer, + done_callback=cleanup_connection_ids, + layer_factory=lambda handler: layers.ClientQuicLayer( + context=handler.layer.context, + issue_cid=lambda cid: issue_connection_id(handler, cid), + retire_cid=lambda cid: retire_connection_id(handler, cid), + ), ) def inject_event(self, event: events.MessageInjected): diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 34ca048c16..7a7c7fbffc 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -201,7 +201,7 @@ def tls_start_client(self, tls_start: tls.TlsData) -> None: ) tls_start.ssl_conn.set_accept_state() - def _get_client_cert(self, server: connection.Server) -> Optional[str]: + def get_client_cert(self, server: connection.Server) -> Optional[str]: if ctx.options.client_certs: client_certs = os.path.expanduser(ctx.options.client_certs) if os.path.isfile(client_certs): @@ -265,7 +265,7 @@ def tls_start_server(self, tls_start: tls.TlsData) -> None: verify=verify, ca_path=ctx.options.ssl_verify_upstream_trusted_confdir, ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca, - client_cert=self._get_client_cert(server), + client_cert=self.get_client_cert(server), ) tls_start.ssl_conn = SSL.Connection(ssl_ctx) @@ -352,7 +352,7 @@ def quic_tls_start_server(self, tls_start: quic.QuicTlsData) -> None: CipherSuite(cipher) for cipher in server.cipher_list ] - client_cert = self._get_client_cert(server) + client_cert = self.get_client_cert(server) if client_cert: config = QuicConfiguration() config.load_cert_chain(client_cert) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 14437074f8..9a068a032c 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,17 +1,22 @@ -from abc import abstractmethod from dataclasses import dataclass -import io from ssl import VerifyMode -from typing import List, Optional, Union +from typing import Callable, List, Optional, TextIO, Union from aioquic.buffer import Buffer as QuicBuffer +from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.connection import QuicConnection -from aioquic.tls import CipherSuite, HandshakeType +from aioquic.tls import ( + CipherSuite, + Context as QuicTlsContext, + HandshakeType, + ServerHello, + pull_server_hello, +) from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa -from mitmproxy.proxy import layer -from mitmproxy.proxy.commands import StartHook -from mitmproxy.proxy.context import Context +from mitmproxy import connection +from mitmproxy.net import tls +from mitmproxy.proxy import commands, context, layer from mitmproxy.tls import TlsData @@ -53,9 +58,9 @@ class QuicTlsData(TlsData): @dataclass -class QuicTlsStartClientHook(StartHook): +class QuicTlsStartClientHook(connection.StartHook): """ - TLS negotation between mitmproxy and a client over QUIC is about to start. + TLS negotiation between mitmproxy and a client over QUIC is about to start. An addon is expected to initialize data.settings. (by default, this is done by `mitmproxy.addons.tlsconfig`) @@ -65,9 +70,9 @@ class QuicTlsStartClientHook(StartHook): @dataclass -class QuicTlsStartServerHook(StartHook): +class QuicTlsStartServerHook(connection.StartHook): """ - TLS negotation between mitmproxy and a server over QUIC is about to start. + TLS negotiation between mitmproxy and a server over QUIC is about to start. An addon is expected to initialize data.settings. (by default, this is done by `mitmproxy.addons.tlsconfig`) @@ -76,65 +81,165 @@ class QuicTlsStartServerHook(StartHook): data: QuicTlsData -class QuicLayer(layer.Layer): - conn: QuicConnection +class QuicSecretsLogger(TextIO): + conn: connection.Connection + logger: tls.MasterSecretLogger + + def __init__( + self, conn: connection.Connection, logger: tls.MasterSecretLogger + ) -> None: + super().__init__() + self.conn = conn + self.logger = logger + + def write(self, s: str) -> int: + self.logger(self.conn, s.encode()) + + def flush(self) -> None: + # done by the logger during write + pass + + +@dataclass +class QuicClientHelloException(Exception): + data: bytes + + +def hook_quic_tls(quic: QuicConnection, cb: Callable[[QuicTlsContext]]) -> None: + assert quic.tls is None + + # patch aioquic to intercept the client/server hello + orig_initialize = quic._initialize + + def initialize_replacement(peer_cid: bytes) -> None: + try: + return orig_initialize(peer_cid) + finally: + cb(quic.tls) + + quic._initialize = initialize_replacement + +def throw_on_client_hello(tls: QuicTlsContext) -> None: + def server_handle_hello_replacement( + input_buf: QuicBuffer, + initial_buf: QuicBuffer, + handshake_buf: QuicBuffer, + onertt_buf: QuicBuffer, + ) -> None: + assert input_buf.pull_uint8() == HandshakeType.CLIENT_HELLO + length = 0 + for b in input_buf.pull_bytes(3): + length = (length << 8) | b + offset = input_buf.tell() + raise QuicClientHelloException( + data=input_buf.data_slice(offset, offset + length) + ) + + tls._server_handle_hello = server_handle_hello_replacement -self._protocols[connection.host_cid] = protocol - def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol): - self._protocols[cid] = protocol +def callback_on_server_hello(tls: QuicTlsContext, cb: Callable[[ServerHello]]) -> None: + orig_client_handle_hello = tls._client_handle_hello - def _connection_id_retired( - self, cid: bytes, protocol: QuicConnectionProtocol + def _client_handle_hello_replacement( + input_buf: QuicBuffer, + output_buf: QuicBuffer, ) -> None: - assert self._protocols[cid] == protocol - del self._protocols[cid] + offset = input_buf.tell() + cb(pull_server_hello(input_buf)) + input_buf.seek(offset) + orig_client_handle_hello(input_buf, output_buf) + + tls._client_handle_hello = _client_handle_hello_replacement + - def _connection_terminated(self, protocol: QuicConnectionProtocol): - for cid, proto in list(self._protocols.items()): - if proto == protocol: - del self._protocols[cid] +class QuicLayer(layer.Layer): + buffer: List[bytes] + quic: Optional[QuicConnection] + conn: connection.Connection + issue_cid: Callable[[bytes]] + retire_cid: Callable[[bytes]] + + def __init__( + self, + context: context.Context, + conn: connection.Connection, + issue_cid: Callable[[bytes]], + retire_cid: Callable[[bytes]], + ) -> None: + super().__init__(context) + self.buffer = [] + self.quic = None + self.conn = conn + + def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: + return QuicConfiguration( + alpn_protocols=self.conn.alpn_offers, + connection_id_length=self.context.options.quic_connection_id_length, + is_client=self.conn == self.context.server, + secrets_log_file=QuicSecretsLogger(self.conn, tls.log_master_secret) + if tls.log_master_secret is not None + else None, + server_name=self.conn.sni, + cafile=settings.ca_file, + capath=settings.ca_path, + certificate=settings.certificate, + certificate_chain=settings.certificate_chain, + cipher_suites=settings.cipher_suites, + private_key=settings.certificate_private_key, + verify_mode=settings.verify_mode, + ) + + def initialize_connection( + self, original_destination_connection_id: Union[bytes, None] + ) -> layer.CommandGenerator[None]: + assert not self.quic + + # (almost) identical to _TLSLayer.start_tls + tls_data = QuicTlsData(self.conn, self.context) + if self.conn == self.context.client: + yield QuicTlsStartClientHook(tls_data) + else: + yield QuicTlsStartServerHook(tls_data) + if not tls_data.settings: + yield commands.Log( + "No TLS settings were provided, failing connection.", "error" + ) + yield commands.CloseConnection(self.conn) + return + assert tls_data.settings + + self.quic = QuicConnection( + configuration=self.build_configuration(tls_data.settings), + original_destination_connection_id=original_destination_connection_id, + ) + self.issue_cid(self.quic.host_cid) class ServerQuicLayer(QuicLayer): """ This layer establishes QUIC for a single server connection. """ - pass - -@dataclass -class ClientHelloException(Exception): - data: bytes + def __init__( + self, + context: context.Context, + issue_cid: Callable[[bytes]], + retire_cid: Callable[[bytes]], + ) -> None: + super().__init__(context, context.server, issue_cid, retire_cid) class ClientQuicLayer(QuicLayer): - _intercept_client_hello: bool - """ This layer establishes QUIC on a single client connection. """ - def __init__(self, context: Context) -> None: - super().__init__(context) - - # patch aioquic to intercept the client hello - orig_initialize = self.conn._initialize - def initialize_replacement(peer_cid: bytes) -> None: - try: - return orig_initialize(peer_cid) - finally: - orig_server_handle_hello = self.conn.tls._server_handle_hello - def server_handle_hello_replacement( - input_buf: QuicBuffer, - initial_buf: QuicBuffer, - handshake_buf: QuicBuffer, - onertt_buf: QuicBuffer, - ) -> None: - if self._intercept_client_hello and input_buf.pull_uint8() == HandshakeType.CLIENT_HELLO: - raise ClientHelloException(input_buf.data[:input_buf.tell()]) - else: - orig_server_handle_hello(input_buf, initial_buf, handshake_buf, onertt_buf) - self.conn.tls._server_handle_hello = server_handle_hello_replacement - self.conn._initialize = initialize_replacement + def __init__( + self, + context: context.Context, + issue_cid: Callable[[bytes]], + retire_cid: Callable[[bytes]], + ) -> None: + super().__init__(context, context.client, issue_cid, retire_cid) From c271f246e247d2cf847141a80fdd31a595e99993 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 16 Jun 2022 05:54:09 +0200 Subject: [PATCH 006/529] [quic] parse client hello --- mitmproxy/proxy/layers/quic.py | 89 ++++++++++++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 9a068a032c..41f69e7aa4 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,10 +1,11 @@ +import asyncio from dataclasses import dataclass from ssl import VerifyMode from typing import Callable, List, Optional, TextIO, Union from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.connection import QuicConnection +from aioquic.quic.connection import QuicConnection, QuicConnectionError from aioquic.tls import ( CipherSuite, Context as QuicTlsContext, @@ -12,12 +13,14 @@ ServerHello, pull_server_hello, ) +from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa from mitmproxy import connection from mitmproxy.net import tls -from mitmproxy.proxy import commands, context, layer -from mitmproxy.tls import TlsData +from mitmproxy.proxy import commands, context, events, layer, layers +from mitmproxy.proxy.utils import expect +from mitmproxy.tls import ClientHello, ClientHelloData, TlsData @dataclass @@ -120,7 +123,7 @@ def initialize_replacement(peer_cid: bytes) -> None: quic._initialize = initialize_replacement -def throw_on_client_hello(tls: QuicTlsContext) -> None: +def raise_on_client_hello(tls: QuicTlsContext) -> None: def server_handle_hello_replacement( input_buf: QuicBuffer, initial_buf: QuicBuffer, @@ -154,7 +157,31 @@ def _client_handle_hello_replacement( tls._client_handle_hello = _client_handle_hello_replacement +def read_client_hello(data: bytes, connection_id_length: int) -> ClientHello: + buffer = QuicBuffer(data=data) + header = pull_quic_header( + buffer, host_cid_length=connection_id_length + ) + assert header.packet_type == PACKET_TYPE_INITIAL + temp_quic = QuicConnection( + configuration=QuicConfiguration(connection_id_length=connection_id_length), + original_destination_connection_id=header.destination_cid, + ) + hook_quic_tls(temp_quic, raise_on_client_hello) + try: + temp_quic.receive_datagram(data, ("0.0.0.0", 0), now=0) + except QuicClientHelloException as hello: + try: + return ClientHello(hello.data) + except EOFError as e: + raise ValueError("Invalid ClientHello data.") from e + except QuicConnectionError as e: + raise ValueError(e.reason_phrase) from e + raise ValueError("No ClientHello returned.") + + class QuicLayer(layer.Layer): + loop: asyncio.AbstractEventLoop buffer: List[bytes] quic: Optional[QuicConnection] conn: connection.Connection @@ -169,9 +196,12 @@ def __init__( retire_cid: Callable[[bytes]], ) -> None: super().__init__(context) + self.loop = asyncio.get_event_loop() self.buffer = [] self.quic = None self.conn = conn + self.issue_cid = issue_cid + self.retire_cid = retire_cid def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: return QuicConfiguration( @@ -243,3 +273,54 @@ def __init__( retire_cid: Callable[[bytes]], ) -> None: super().__init__(context, context.client, issue_cid, retire_cid) + + @expect(events.Start) + def handle_start(self, _: events.Event) -> layer.CommandGenerator[None]: + self._handle_event = self.handle_client_hello + yield from () + + @expect(events.DataReceived, events.ConnectionClosed) + def handle_client_hello(self, event: events.Event) -> layer.CommandGenerator[None]: + if isinstance(event, events.DataReceived): + assert event.connection == self.conn + + # extract the client hello + try: + client_hello = read_client_hello(event.data, connection_id_length=self.context.options.quic_connection_id_length) + except ValueError as e: + yield commands.Log( + f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})", "warn" + ) + yield commands.CloseConnection(self.conn) + else: + self.conn.sni = client_hello.sni + self.conn.alpn_offers = client_hello.alpn_protocols + + # check with addons what we shall do + hook_data = ClientHelloData(self.context, client_hello) + yield layers.tls.TlsClienthelloHook(hook_data) + if hook_data.ignore_connection: + # simply relay everything (including the client hello) + relay_layer = layers.TCPLayer(self.context, ignore=True) + self._handle_event = relay_layer.handle_event + yield from relay_layer.handle_event(events.Start()) + yield from relay_layer.handle_event(event) + + elif hook_data.establish_server_tls_first: + pass + + else: + pass + + elif isinstance(event, events.ConnectionClosed): + assert event.connection == self.conn + self._handle_event = self.handle_done + + else: + raise AssertionError(f"Unexpected event: {event}") + + @expect(events.DataReceived, events.ConnectionClosed) + def handle_done(self, _) -> layer.CommandGenerator[None]: + yield from () + + _handle_event = handle_start From 0e49eebe62dbe3e527ba95934e05ade39fde62d8 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 16 Jun 2022 14:09:38 +0200 Subject: [PATCH 007/529] [quic] support roaming --- mitmproxy/addons/proxyserver.py | 27 +++++++++++++++---------- mitmproxy/net/udp.py | 16 +++++++-------- mitmproxy/proxy/commands.py | 8 +++++--- mitmproxy/proxy/events.py | 3 ++- mitmproxy/proxy/layers/quic.py | 36 ++++++++++++++++----------------- mitmproxy/proxy/server.py | 16 +++++++++++---- 6 files changed, 61 insertions(+), 45 deletions(-) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index b856ac7daf..d163c293d4 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -3,7 +3,7 @@ import ipaddress import re import struct -from typing import Callable, Optional +from typing import Any, Callable, Optional from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic.packet import ( @@ -361,9 +361,9 @@ def handle_udp_connection( layer_factory: Callable[[ProxyConnectionHandler], layer.Layer], server_addr: Optional[Address] = None, server_sni: Optional[str] = None, - done_callback: Optional[Callable[[ProxyConnectionHandler]]] = None, + done_callback: Optional[Callable[[ProxyConnectionHandler], Any]] = None, timeout: Optional[int] = None, - ) -> Optional[asyncio.Task[None]]: + ) -> None: if connection_id not in self._connections: reader = udp.DatagramReader() writer = udp.DatagramWriter(transport, remote_addr, reader) @@ -375,9 +375,11 @@ def handle_udp_connection( handler.layer.context.server.address = server_addr handler.layer.context.server.sni = server_sni self._connections[connection_id] = handler - task = asyncio.create_task(self.handle_connection(connection_id)) - if done_callback is not None: - task.add_done_callback(lambda _: done_callback(handler)) + asyncio.create_task( + self.handle_connection(connection_id) + ).add_done_callback( + lambda _: None if done_callback is None else done_callback(handler) + ) else: handler = self._connections[connection_id] client_reader = handler.transports[handler.client].reader @@ -401,7 +403,7 @@ def handle_dns_datagram( return self.handle_udp_connection( transport=transport, - date=data, + data=data, remote_addr=remote_addr, server_addr=( local_addr @@ -420,6 +422,9 @@ def handle_quic_datagram( remote_addr: Address, local_addr: Address, ) -> None: + def build_connection_id(cid: bytes) -> tuple: + return ("quic", cid, local_addr) + # largely taken from aioquic's own asyncio server code buffer = QuicBuffer(data=data) try: @@ -450,7 +455,7 @@ def handle_quic_datagram( return # check if a new connection is possible - connection_id = ("quic", header.destination_cid) + connection_id = build_connection_id(header.destination_cid) if connection_id not in self._connections: if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL: ctx.log.info( @@ -478,20 +483,20 @@ def cleanup_connection_ids(handler: ProxyConnectionHandler) -> None: del self._connections[connection_id] def issue_connection_id(handler: ProxyConnectionHandler, cid: bytes) -> None: - connection_id = ("quic", cid) + connection_id = build_connection_id(cid) assert connection_id not in self._connections self._connections[connection_id] = handler connection_ids.add(connection_id) def retire_connection_id(handler: ProxyConnectionHandler, cid: bytes) -> None: - connection_id = ("quic", cid) + connection_id = build_connection_id(cid) connection_ids.remove(connection_id) del self._connections[connection_id] # create or resume the connection self.handle_udp_connection( transport=transport, - date=data, + data=data, remote_addr=remote_addr, server_addr=server_addr, server_sni=server_sni, diff --git a/mitmproxy/net/udp.py b/mitmproxy/net/udp.py index c70647800e..61cf7cf336 100644 --- a/mitmproxy/net/udp.py +++ b/mitmproxy/net/udp.py @@ -4,7 +4,7 @@ import ipaddress import socket import struct -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Callable, Optional, Tuple, Union, cast from mitmproxy import ctx from mitmproxy.connection import Address from mitmproxy.utils import human @@ -228,7 +228,7 @@ def close(self) -> None: class DatagramReader: - _packets: asyncio.Queue + _packets: asyncio.Queue[Tuple[bytes, Address]] _eof: bool def __init__(self) -> None: @@ -243,7 +243,7 @@ def feed_data(self, data: bytes, remote_addr: Address) -> None: ) else: try: - self._packets.put_nowait(data) + self._packets.put_nowait((data, remote_addr)) except asyncio.QueueFull: ctx.log.debug( f"Dropped UDP packet from {human.format_address(remote_addr)}." @@ -252,17 +252,17 @@ def feed_data(self, data: bytes, remote_addr: Address) -> None: def feed_eof(self) -> None: self._eof = True try: - self._packets.put_nowait(b"") + self._packets.put_nowait((b"", None)) # type: ignore except asyncio.QueueFull: pass - async def read(self, n: int) -> bytes: + async def read(self, n: int) -> Tuple[bytes, Address]: assert n >= MAX_DATAGRAM_SIZE if self._eof: try: return self._packets.get_nowait() except asyncio.QueueEmpty: - return b"" + return (b"", None) # type: ignore else: return await self._packets.get() @@ -295,8 +295,8 @@ def __init__( def _protocol(self) -> DrainableDatagramProtocol: return cast(DrainableDatagramProtocol, self._transport.get_protocol()) - def write(self, data: bytes) -> None: - self._transport.sendto(data, self._remote_addr) + def write(self, data: bytes, remote_addr: Optional[Address] = None) -> None: + self._transport.sendto(data, self._remote_addr if remote_addr is None else remote_addr) def write_eof(self) -> None: raise NotImplementedError("UDP does not support half-closing.") diff --git a/mitmproxy/proxy/commands.py b/mitmproxy/proxy/commands.py index 388abf9fe8..3845a7774e 100644 --- a/mitmproxy/proxy/commands.py +++ b/mitmproxy/proxy/commands.py @@ -6,10 +6,10 @@ The counterpart to commands are events. """ -from typing import Literal, Union, TYPE_CHECKING +from typing import Literal, Optional, Union, TYPE_CHECKING import mitmproxy.hooks -from mitmproxy.connection import Connection, Server +from mitmproxy.connection import Address, Connection, Server if TYPE_CHECKING: import mitmproxy.proxy.layer @@ -67,10 +67,12 @@ class SendData(ConnectionCommand): """ data: bytes + remote_addr: Optional[Address] - def __init__(self, connection: Connection, data: bytes): + def __init__(self, connection: Connection, data: bytes, remote_addr: Optional[Address] = None): super().__init__(connection) self.data = data + self.remote_addr = remote_addr def __repr__(self): target = str(self.connection).split("(", 1)[0].lower() diff --git a/mitmproxy/proxy/events.py b/mitmproxy/proxy/events.py index b767483f0e..fb1e925f23 100644 --- a/mitmproxy/proxy/events.py +++ b/mitmproxy/proxy/events.py @@ -10,7 +10,7 @@ from mitmproxy import flow from mitmproxy.proxy import commands -from mitmproxy.connection import Connection +from mitmproxy.connection import Address, Connection class Event: @@ -45,6 +45,7 @@ class DataReceived(ConnectionEvent): """ data: bytes + remote_addr: Optional[Address] = None def __repr__(self): target = type(self.connection).__name__.lower() diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 41f69e7aa4..484013de54 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -61,7 +61,7 @@ class QuicTlsData(TlsData): @dataclass -class QuicTlsStartClientHook(connection.StartHook): +class QuicTlsStartClientHook(commands.StartHook): """ TLS negotiation between mitmproxy and a client over QUIC is about to start. @@ -73,7 +73,7 @@ class QuicTlsStartClientHook(connection.StartHook): @dataclass -class QuicTlsStartServerHook(connection.StartHook): +class QuicTlsStartServerHook(commands.StartHook): """ TLS negotiation between mitmproxy and a server over QUIC is about to start. @@ -84,19 +84,21 @@ class QuicTlsStartServerHook(connection.StartHook): data: QuicTlsData -class QuicSecretsLogger(TextIO): - conn: connection.Connection +class QuicSecretsLogger: logger: tls.MasterSecretLogger def __init__( - self, conn: connection.Connection, logger: tls.MasterSecretLogger + self, logger: tls.MasterSecretLogger ) -> None: super().__init__() - self.conn = conn self.logger = logger def write(self, s: str) -> int: - self.logger(self.conn, s.encode()) + if s.endswith("\n"): + s = s[:-1] + data = s.encode() + self.logger(None, data) # type: ignore + return len(data) + 1 def flush(self) -> None: # done by the logger during write @@ -108,7 +110,7 @@ class QuicClientHelloException(Exception): data: bytes -def hook_quic_tls(quic: QuicConnection, cb: Callable[[QuicTlsContext]]) -> None: +def hook_quic_tls(quic: QuicConnection, cb: Callable[[QuicTlsContext], None]) -> None: assert quic.tls is None # patch aioquic to intercept the client/server hello @@ -142,7 +144,7 @@ def server_handle_hello_replacement( tls._server_handle_hello = server_handle_hello_replacement -def callback_on_server_hello(tls: QuicTlsContext, cb: Callable[[ServerHello]]) -> None: +def callback_on_server_hello(tls: QuicTlsContext, cb: Callable[[ServerHello], None]) -> None: orig_client_handle_hello = tls._client_handle_hello def _client_handle_hello_replacement( @@ -185,15 +187,13 @@ class QuicLayer(layer.Layer): buffer: List[bytes] quic: Optional[QuicConnection] conn: connection.Connection - issue_cid: Callable[[bytes]] - retire_cid: Callable[[bytes]] def __init__( self, context: context.Context, conn: connection.Connection, - issue_cid: Callable[[bytes]], - retire_cid: Callable[[bytes]], + issue_cid: Callable[[bytes], None], + retire_cid: Callable[[bytes], None], ) -> None: super().__init__(context) self.loop = asyncio.get_event_loop() @@ -208,7 +208,7 @@ def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: alpn_protocols=self.conn.alpn_offers, connection_id_length=self.context.options.quic_connection_id_length, is_client=self.conn == self.context.server, - secrets_log_file=QuicSecretsLogger(self.conn, tls.log_master_secret) + secrets_log_file=QuicSecretsLogger(tls.log_master_secret) if tls.log_master_secret is not None else None, server_name=self.conn.sni, @@ -255,8 +255,8 @@ class ServerQuicLayer(QuicLayer): def __init__( self, context: context.Context, - issue_cid: Callable[[bytes]], - retire_cid: Callable[[bytes]], + issue_cid: Callable[[bytes], None], + retire_cid: Callable[[bytes], None], ) -> None: super().__init__(context, context.server, issue_cid, retire_cid) @@ -269,8 +269,8 @@ class ClientQuicLayer(QuicLayer): def __init__( self, context: context.Context, - issue_cid: Callable[[bytes]], - retire_cid: Callable[[bytes]], + issue_cid: Callable[[bytes], None], + retire_cid: Callable[[bytes], None], ) -> None: super().__init__(context, context.client, issue_cid, retire_cid) diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 20fd4233bf..8f213a55bb 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -249,9 +249,13 @@ async def handle_connection(self, connection: Connection) -> None: cancelled = None reader = self.transports[connection].reader assert reader + has_remote_addr = isinstance(reader, udp.DatagramReader) while True: try: - data = await reader.read(65535) + if has_remote_addr: + data, remote_addr = await reader.read(65535) + else: + data, remote_addr = await reader.read(65535), None if not data: raise OSError("Connection closed by peer.") except OSError: @@ -260,7 +264,7 @@ async def handle_connection(self, connection: Connection) -> None: cancelled = e break - self.server_event(events.DataReceived(connection, data)) + self.server_event(events.DataReceived(connection, data, remote_addr)) try: await self.drain_writers() @@ -353,8 +357,12 @@ def server_event(self, event: events.Event) -> None: pass # The connection has already been closed. elif isinstance(command, commands.SendData): writer = self.transports[command.connection].writer - assert writer - writer.write(command.data) + if command.remote_addr is not None: + assert isinstance(writer, udp.DatagramWriter) + writer.write(command.data, command.remote_addr) + else: + assert writer + writer.write(command.data) elif isinstance(command, commands.CloseConnection): self.close_connection(command.connection, command.half_close) elif isinstance(command, commands.GetSocket): From f4cb656b431a5667dbc09f1fff4971e3fa67fafa Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Sat, 18 Jun 2022 03:18:45 +0200 Subject: [PATCH 008/529] [quic] more work on TLS --- mitmproxy/proxy/layers/quic.py | 230 +++++++++++++++++++++------------ 1 file changed, 147 insertions(+), 83 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 484013de54..af30a272fa 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,26 +1,21 @@ import asyncio from dataclasses import dataclass from ssl import VerifyMode -from typing import Callable, List, Optional, TextIO, Union +from typing import Callable, List, Optional, Tuple, Union from aioquic.buffer import Buffer as QuicBuffer +from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.connection import QuicConnection, QuicConnectionError -from aioquic.tls import ( - CipherSuite, - Context as QuicTlsContext, - HandshakeType, - ServerHello, - pull_server_hello, -) +from aioquic.tls import CipherSuite, HandshakeType from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa from mitmproxy import connection from mitmproxy.net import tls from mitmproxy.proxy import commands, context, events, layer, layers -from mitmproxy.proxy.utils import expect from mitmproxy.tls import ClientHello, ClientHelloData, TlsData +from mitmproxy.utils import human @dataclass @@ -84,17 +79,25 @@ class QuicTlsStartServerHook(commands.StartHook): data: QuicTlsData +@dataclass +class QuicStreamDataReceived(quic_events.StreamDataReceived, events.ConnectionEvent): + pass + + +@dataclass +class QuicStreamReset(quic_events.StreamReset, events.ConnectionEvent): + pass + + class QuicSecretsLogger: logger: tls.MasterSecretLogger - def __init__( - self, logger: tls.MasterSecretLogger - ) -> None: + def __init__(self, logger: tls.MasterSecretLogger) -> None: super().__init__() self.logger = logger def write(self, s: str) -> int: - if s.endswith("\n"): + if s[-1:] == "\n": s = s[:-1] data = s.encode() self.logger(None, data) # type: ignore @@ -106,26 +109,24 @@ def flush(self) -> None: @dataclass -class QuicClientHelloException(Exception): +class QuicClientHello(Exception): data: bytes -def hook_quic_tls(quic: QuicConnection, cb: Callable[[QuicTlsContext], None]) -> None: - assert quic.tls is None - - # patch aioquic to intercept the client/server hello - orig_initialize = quic._initialize - - def initialize_replacement(peer_cid: bytes) -> None: - try: - return orig_initialize(peer_cid) - finally: - cb(quic.tls) - - quic._initialize = initialize_replacement +def read_client_hello(data: bytes) -> ClientHello: + # ensure the first packet is indeed the initial one + buffer = QuicBuffer(data=data) + header = pull_quic_header(buffer) + if header.packet_type != PACKET_TYPE_INITIAL: + raise ValueError("Packet is not initial one.") + # patch aioquic to intercept the client hello + quic = QuicConnection( + configuration=QuicConfiguration(), + original_destination_connection_id=header.destination_cid, + ) + _initialize = quic._initialize -def raise_on_client_hello(tls: QuicTlsContext) -> None: def server_handle_hello_replacement( input_buf: QuicBuffer, initial_buf: QuicBuffer, @@ -137,42 +138,18 @@ def server_handle_hello_replacement( for b in input_buf.pull_bytes(3): length = (length << 8) | b offset = input_buf.tell() - raise QuicClientHelloException( - data=input_buf.data_slice(offset, offset + length) - ) - - tls._server_handle_hello = server_handle_hello_replacement - - -def callback_on_server_hello(tls: QuicTlsContext, cb: Callable[[ServerHello], None]) -> None: - orig_client_handle_hello = tls._client_handle_hello - - def _client_handle_hello_replacement( - input_buf: QuicBuffer, - output_buf: QuicBuffer, - ) -> None: - offset = input_buf.tell() - cb(pull_server_hello(input_buf)) - input_buf.seek(offset) - orig_client_handle_hello(input_buf, output_buf) - - tls._client_handle_hello = _client_handle_hello_replacement + raise QuicClientHello(data=input_buf.data_slice(offset, offset + length)) + def initialize_replacement(peer_cid: bytes) -> None: + try: + return _initialize(peer_cid) + finally: + quic.tls._server_handle_hello = server_handle_hello_replacement -def read_client_hello(data: bytes, connection_id_length: int) -> ClientHello: - buffer = QuicBuffer(data=data) - header = pull_quic_header( - buffer, host_cid_length=connection_id_length - ) - assert header.packet_type == PACKET_TYPE_INITIAL - temp_quic = QuicConnection( - configuration=QuicConfiguration(connection_id_length=connection_id_length), - original_destination_connection_id=header.destination_cid, - ) - hook_quic_tls(temp_quic, raise_on_client_hello) + quic._initialize = initialize_replacement try: - temp_quic.receive_datagram(data, ("0.0.0.0", 0), now=0) - except QuicClientHelloException as hello: + quic.receive_datagram(data, ("0.0.0.0", 0), now=0) + except QuicClientHello as hello: try: return ClientHello(hello.data) except EOFError as e: @@ -184,9 +161,9 @@ def read_client_hello(data: bytes, connection_id_length: int) -> ClientHello: class QuicLayer(layer.Layer): loop: asyncio.AbstractEventLoop - buffer: List[bytes] quic: Optional[QuicConnection] conn: connection.Connection + original_destination_connection_id: Optional[bytes] def __init__( self, @@ -207,8 +184,8 @@ def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: return QuicConfiguration( alpn_protocols=self.conn.alpn_offers, connection_id_length=self.context.options.quic_connection_id_length, - is_client=self.conn == self.context.server, - secrets_log_file=QuicSecretsLogger(tls.log_master_secret) + is_client=self.conn is self.context.server, + secrets_log_file=QuicSecretsLogger(tls.log_master_secret) # type: ignore if tls.log_master_secret is not None else None, server_name=self.conn.sni, @@ -221,14 +198,13 @@ def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: verify_mode=settings.verify_mode, ) - def initialize_connection( - self, original_destination_connection_id: Union[bytes, None] - ) -> layer.CommandGenerator[None]: + def initialize_connection(self) -> layer.CommandGenerator[None]: assert not self.quic + self._handle_event = self.handle_connected # (almost) identical to _TLSLayer.start_tls tls_data = QuicTlsData(self.conn, self.context) - if self.conn == self.context.client: + if self.conn is self.context.client: yield QuicTlsStartClientHook(tls_data) else: yield QuicTlsStartServerHook(tls_data) @@ -237,15 +213,22 @@ def initialize_connection( "No TLS settings were provided, failing connection.", "error" ) yield commands.CloseConnection(self.conn) + self._handle_event = self.handle_done return assert tls_data.settings self.quic = QuicConnection( configuration=self.build_configuration(tls_data.settings), - original_destination_connection_id=original_destination_connection_id, + original_destination_connection_id=self.original_destination_connection_id, ) self.issue_cid(self.quic.host_cid) + def process_events(self) -> layer.CommandGenerator[None]: + assert self.quic + + def handle_done(self, _) -> layer.CommandGenerator[None]: + yield from () + class ServerQuicLayer(QuicLayer): """ @@ -260,12 +243,34 @@ def __init__( ) -> None: super().__init__(context, context.server, issue_cid, retire_cid) + def handle_start(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, events.Start) + + # ensure there is an UDP connection + if not self.conn.connected: + err = yield commands.OpenConnection(self.conn) + if err is not None: + yield commands.Log( + f"Failed to establish connection to {human.format_address(self.conn)}: {err}" + ) + self._handle_event = self.handle_done + return + + # try to connect + yield from self.initialize_connection() + if self.quic is not None: + self.quic.connect(addr=self.conn.peername, now=self.loop.time()) + yield from self.process_events() + class ClientQuicLayer(QuicLayer): """ This layer establishes QUIC on a single client connection. """ + server_layer: Optional[ServerQuicLayer] + buffered_packets: Optional[List[Tuple[bytes, connection.Address, float]]] + def __init__( self, context: context.Context, @@ -273,24 +278,40 @@ def __init__( retire_cid: Callable[[bytes], None], ) -> None: super().__init__(context, context.client, issue_cid, retire_cid) + self.server_layer = None + self.buffered_packets = None + + def start_client_connection(self) -> layer.CommandGenerator[None]: + assert self.buffered_packets is not None + + yield from self.initialize_connection() + if self.quic is not None: + for data, addr, now in self.buffered_packets: + self.quic.receive_datagram( + data=data, + addr=addr, + now=now, + ) + yield from self.process_events() - @expect(events.Start) - def handle_start(self, _: events.Event) -> layer.CommandGenerator[None]: + def handle_start(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, events.Start) self._handle_event = self.handle_client_hello yield from () - @expect(events.DataReceived, events.ConnectionClosed) def handle_client_hello(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, events.DataReceived): - assert event.connection == self.conn + assert isinstance(event, events.ConnectionEvent) + assert event.connection is self.conn + if isinstance(event, events.DataReceived): # extract the client hello try: - client_hello = read_client_hello(event.data, connection_id_length=self.context.options.quic_connection_id_length) + client_hello = read_client_hello(event.data) except ValueError as e: yield commands.Log( f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})", "warn" ) + self._handle_event = self.handle_done yield commands.CloseConnection(self.conn) else: self.conn.sni = client_hello.sni @@ -299,6 +320,7 @@ def handle_client_hello(self, event: events.Event) -> layer.CommandGenerator[Non # check with addons what we shall do hook_data = ClientHelloData(self.context, client_hello) yield layers.tls.TlsClienthelloHook(hook_data) + if hook_data.ignore_connection: # simply relay everything (including the client hello) relay_layer = layers.TCPLayer(self.context, ignore=True) @@ -306,21 +328,63 @@ def handle_client_hello(self, event: events.Event) -> layer.CommandGenerator[Non yield from relay_layer.handle_event(events.Start()) yield from relay_layer.handle_event(event) - elif hook_data.establish_server_tls_first: - pass - else: - pass + # buffer the client hello + self.buffered_packets = [ + (event.data, event.remote_addr, self.loop.time()) + ] + + # contact the upstream server first if so desired + if hook_data.establish_server_tls_first: + self.server_layer = ServerQuicLayer( + context=self.context, + issue_cid=self.issue_cid, + retire_cid=self.retire_cid, + ) + self._handle_event = self.handle_wait_for_server + yield from self.handle_wait_for_server(events.Start()) + else: + yield from self.start_client_connection() elif isinstance(event, events.ConnectionClosed): - assert event.connection == self.conn + # this is odd since this layer should only be created if there is a packet self._handle_event = self.handle_done else: raise AssertionError(f"Unexpected event: {event}") - @expect(events.DataReceived, events.ConnectionClosed) - def handle_done(self, _) -> layer.CommandGenerator[None]: - yield from () + def handle_wait_for_server( + self, event: events.Event + ) -> layer.CommandGenerator[None]: + assert self.buffered_packets is not None + assert self.server_layer is not None + + # filter DataReceived and ConnectionClosed relating to the client connection + if isinstance(event, events.ConnectionEvent): + if event.connection is self.context.client: + if isinstance(event, events.DataReceived): + # still waiting for the server, buffer the data + self.buffered_packets.append( + (event.data, event.remote_addr, self.loop.time()) + ) + + elif isinstance(event, events.ConnectionClosed): + # close the upstream connection as well and be done + yield commands.CloseConnection(self.context.server) + self._handle_event = self.handle_done + + else: + raise AssertionError(f"Unexpected event: {event}") + + # forward the event and check it's results + yield from self.server_layer.handle_event(event) + if not self.context.server.connected: + yield commands.Log( + f"Unable to establish QUIC connection with server ({self.context.server.error or 'Connection closed.'}). " + f"Trying to establish QUIC with client anyway." + ) + yield from self.start_client_connection() + elif self.context.server.tls_established: + yield from self.start_client_connection() _handle_event = handle_start From 366e696538061d13540ced440998193e4c1e4c31 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Sat, 18 Jun 2022 12:33:45 +0200 Subject: [PATCH 009/529] [quic] add child layering --- mitmproxy/proxy/layers/quic.py | 187 ++++++++++++++++++++++----------- 1 file changed, 127 insertions(+), 60 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index af30a272fa..3e671356ad 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,7 +1,9 @@ +from abc import abstractmethod import asyncio from dataclasses import dataclass from ssl import VerifyMode -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union +from urllib.parse import non_hierarchical from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events @@ -80,13 +82,19 @@ class QuicTlsStartServerHook(commands.StartHook): @dataclass -class QuicStreamDataReceived(quic_events.StreamDataReceived, events.ConnectionEvent): - pass +class QuicConnectionEvent(events.ConnectionEvent): + event: quic_events.QuicEvent @dataclass -class QuicStreamReset(quic_events.StreamReset, events.ConnectionEvent): - pass +class QuicGetConnection(commands.ConnectionCommand): # -> QuicConnection + blocking = True + + +@dataclass(repr=False) +class OpenGetConnectionCompleted(events.CommandCompleted): + command: QuicGetConnection + connection: QuicConnection class QuicSecretsLogger: @@ -113,7 +121,7 @@ class QuicClientHello(Exception): data: bytes -def read_client_hello(data: bytes) -> ClientHello: +def pull_client_hello_and_connection_id(data: bytes) -> Tuple[ClientHello, bytes]: # ensure the first packet is indeed the initial one buffer = QuicBuffer(data=data) header = pull_quic_header(buffer) @@ -151,7 +159,7 @@ def initialize_replacement(peer_cid: bytes) -> None: quic.receive_datagram(data, ("0.0.0.0", 0), now=0) except QuicClientHello as hello: try: - return ClientHello(hello.data) + return (ClientHello(hello.data), header.destination_cid) except EOFError as e: raise ValueError("Invalid ClientHello data.") from e except QuicConnectionError as e: @@ -160,25 +168,29 @@ def initialize_replacement(peer_cid: bytes) -> None: class QuicLayer(layer.Layer): - loop: asyncio.AbstractEventLoop - quic: Optional[QuicConnection] + child_layer: Optional[layer.Layer] conn: connection.Connection + loop: asyncio.AbstractEventLoop original_destination_connection_id: Optional[bytes] + quic: Optional[QuicConnection] + waiting_get_connection_commands: List[QuicGetConnection] def __init__( self, context: context.Context, conn: connection.Connection, - issue_cid: Callable[[bytes], None], - retire_cid: Callable[[bytes], None], + issue_cid: Optional[Callable[[bytes], None]] = None, + retire_cid: Optional[Callable[[bytes], None]] = None, ) -> None: super().__init__(context) + self.child_layer = None + self.conn = conn self.loop = asyncio.get_event_loop() - self.buffer = [] + self.original_destination_connection_id = None self.quic = None - self.conn = conn - self.issue_cid = issue_cid - self.retire_cid = retire_cid + self.waiting_get_connection_commands = [] + self._issue_cid = issue_cid + self._retire_cid = retire_cid def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: return QuicConfiguration( @@ -198,9 +210,42 @@ def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: verify_mode=settings.verify_mode, ) + def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: + assert self.child_layer is not None + + # answer the child layers request for the connection + for command in self.child_layer.handle_event(event): + if ( + isinstance(command, QuicGetConnection) + and command.connection is self.conn + ): + if self.quic is None: + self.waiting_get_connection_commands.append(command) + else: + yield from self.child_layer.handle_event( + OpenGetConnectionCompleted( + command=command, + connection=self.quic, + ) + ) + else: + yield command + + def fail_connection( + self, + reason: str, + level: Literal["error", "warn", "info", "alert", "debug"] = "warn", + ) -> layer.CommandGenerator[None]: + yield commands.Log( + message=f"Failing connection {self.conn}: {reason}", level=level + ) + if self.conn.connected: + yield commands.CloseConnection(self.conn) + self._handle_event = self.state_done + def initialize_connection(self) -> layer.CommandGenerator[None]: - assert not self.quic - self._handle_event = self.handle_connected + assert self.quic is None + self._handle_event = self.state_ready # (almost) identical to _TLSLayer.start_tls tls_data = QuicTlsData(self.conn, self.context) @@ -209,51 +254,71 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: else: yield QuicTlsStartServerHook(tls_data) if not tls_data.settings: - yield commands.Log( - "No TLS settings were provided, failing connection.", "error" + yield from self.fail_connection( + "No TLS settings were provided, failing connection.", level="error" ) - yield commands.CloseConnection(self.conn) - self._handle_event = self.handle_done return - assert tls_data.settings + assert tls_data.settings is not None + # create the connection and let the waiters know about it self.quic = QuicConnection( configuration=self.build_configuration(tls_data.settings), original_destination_connection_id=self.original_destination_connection_id, ) - self.issue_cid(self.quic.host_cid) + if self._issue_cid: + self._issue_cid(self.quic.host_cid) + while self.waiting_get_connection_commands: + assert self.quic is not None + assert self.child_layer is not None + yield from self.child_layer.handle_event( + OpenGetConnectionCompleted( + command=self.waiting_get_connection_commands.pop(), + connection=self.quic, + ) + ) def process_events(self) -> layer.CommandGenerator[None]: - assert self.quic + assert self.quic is not None + yield from () + + @abstractmethod + def start(self) -> layer.CommandGenerator[None]: + yield from () # pragma: no cover + + def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, events.Start) + + # start this layer and the child layer + yield from self.start() + if self.child_layer is not None: + yield from self.child_layer.handle_event(event) + + def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: + assert self.quic is not None + yield from () - def handle_done(self, _) -> layer.CommandGenerator[None]: + def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: yield from () + _handle_event = state_start + class ServerQuicLayer(QuicLayer): """ This layer establishes QUIC for a single server connection. """ - def __init__( - self, - context: context.Context, - issue_cid: Callable[[bytes], None], - retire_cid: Callable[[bytes], None], - ) -> None: - super().__init__(context, context.server, issue_cid, retire_cid) - - def handle_start(self, event: events.Event) -> layer.CommandGenerator[None]: - assert isinstance(event, events.Start) + def __init__(self, context: context.Context) -> None: + super().__init__(context, context.server) + def start(self) -> layer.CommandGenerator[None]: # ensure there is an UDP connection if not self.conn.connected: err = yield commands.OpenConnection(self.conn) if err is not None: - yield commands.Log( + self.fail_connection( f"Failed to establish connection to {human.format_address(self.conn)}: {err}" ) - self._handle_event = self.handle_done return # try to connect @@ -294,25 +359,29 @@ def start_client_connection(self) -> layer.CommandGenerator[None]: ) yield from self.process_events() - def handle_start(self, event: events.Event) -> layer.CommandGenerator[None]: - assert isinstance(event, events.Start) - self._handle_event = self.handle_client_hello + def start(self) -> layer.CommandGenerator[None]: + self._handle_event = self.state_wait_for_client_hello yield from () - def handle_client_hello(self, event: events.Event) -> layer.CommandGenerator[None]: + def state_wait_for_client_hello( + self, event: events.Event + ) -> layer.CommandGenerator[None]: assert isinstance(event, events.ConnectionEvent) assert event.connection is self.conn if isinstance(event, events.DataReceived): + assert event.remote_addr is not None + # extract the client hello try: - client_hello = read_client_hello(event.data) + ( + client_hello, + self.original_destination_connection_id, + ) = pull_client_hello_and_connection_id(event.data) except ValueError as e: - yield commands.Log( - f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})", "warn" + yield from self.fail_connection( + f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})" ) - self._handle_event = self.handle_done - yield commands.CloseConnection(self.conn) else: self.conn.sni = client_hello.sni self.conn.alpn_offers = client_hello.alpn_protocols @@ -336,24 +405,20 @@ def handle_client_hello(self, event: events.Event) -> layer.CommandGenerator[Non # contact the upstream server first if so desired if hook_data.establish_server_tls_first: - self.server_layer = ServerQuicLayer( - context=self.context, - issue_cid=self.issue_cid, - retire_cid=self.retire_cid, - ) - self._handle_event = self.handle_wait_for_server - yield from self.handle_wait_for_server(events.Start()) + self.server_layer = ServerQuicLayer(self.context) + self._handle_event = self.state_wait_for_upstream_server + yield from self.state_wait_for_upstream_server(events.Start()) else: yield from self.start_client_connection() elif isinstance(event, events.ConnectionClosed): # this is odd since this layer should only be created if there is a packet - self._handle_event = self.handle_done + self._handle_event = self.state_done else: raise AssertionError(f"Unexpected event: {event}") - def handle_wait_for_server( + def state_wait_for_upstream_server( self, event: events.Event ) -> layer.CommandGenerator[None]: assert self.buffered_packets is not None @@ -361,8 +426,10 @@ def handle_wait_for_server( # filter DataReceived and ConnectionClosed relating to the client connection if isinstance(event, events.ConnectionEvent): - if event.connection is self.context.client: + if event.connection is self.conn: if isinstance(event, events.DataReceived): + assert event.remote_addr is not None + # still waiting for the server, buffer the data self.buffered_packets.append( (event.data, event.remote_addr, self.loop.time()) @@ -370,8 +437,10 @@ def handle_wait_for_server( elif isinstance(event, events.ConnectionClosed): # close the upstream connection as well and be done - yield commands.CloseConnection(self.context.server) - self._handle_event = self.handle_done + self._handle_event = self.state_done + yield from self.server_layer.fail_connection( + "Client closed the connection." + ) else: raise AssertionError(f"Unexpected event: {event}") @@ -386,5 +455,3 @@ def handle_wait_for_server( yield from self.start_client_connection() elif self.context.server.tls_established: yield from self.start_client_connection() - - _handle_event = handle_start From 08aa838e9681746c64beb0f3f3f085863501ac49 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Sat, 18 Jun 2022 23:28:00 +0200 Subject: [PATCH 010/529] [quic] handle aioquic events --- mitmproxy/proxy/layers/quic.py | 130 +++++++++++++++++++++++++++------ 1 file changed, 107 insertions(+), 23 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 3e671356ad..41f4dfaa7d 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from ssl import VerifyMode from typing import Callable, List, Literal, Optional, Tuple, Union -from urllib.parse import non_hierarchical from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events @@ -13,7 +12,7 @@ from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa -from mitmproxy import connection +from mitmproxy import certs, connection from mitmproxy.net import tls from mitmproxy.proxy import commands, context, events, layer, layers from mitmproxy.tls import ClientHello, ClientHelloData, TlsData @@ -92,7 +91,7 @@ class QuicGetConnection(commands.ConnectionCommand): # -> QuicConnection @dataclass(repr=False) -class OpenGetConnectionCompleted(events.CommandCompleted): +class QuicGetConnectionCompleted(events.CommandCompleted): command: QuicGetConnection connection: QuicConnection @@ -173,7 +172,7 @@ class QuicLayer(layer.Layer): loop: asyncio.AbstractEventLoop original_destination_connection_id: Optional[bytes] quic: Optional[QuicConnection] - waiting_get_connection_commands: List[QuicGetConnection] + tls: Optional[QuicTlsSettings] def __init__( self, @@ -188,11 +187,17 @@ def __init__( self.loop = asyncio.get_event_loop() self.original_destination_connection_id = None self.quic = None - self.waiting_get_connection_commands = [] + self.tls = None + self._get_connection_commands: List[QuicGetConnection] = [] self._issue_cid = issue_cid + self._request_wakeup_command_and_timer: Optional[ + Tuple[commands.RequestWakeup, float] + ] = None self._retire_cid = retire_cid - def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: + def build_configuration(self) -> QuicConfiguration: + assert self.tls is not None + return QuicConfiguration( alpn_protocols=self.conn.alpn_offers, connection_id_length=self.context.options.quic_connection_id_length, @@ -201,13 +206,13 @@ def build_configuration(self, settings: QuicTlsSettings) -> QuicConfiguration: if tls.log_master_secret is not None else None, server_name=self.conn.sni, - cafile=settings.ca_file, - capath=settings.ca_path, - certificate=settings.certificate, - certificate_chain=settings.certificate_chain, - cipher_suites=settings.cipher_suites, - private_key=settings.certificate_private_key, - verify_mode=settings.verify_mode, + cafile=self.tls.ca_file, + capath=self.tls.ca_path, + certificate=self.tls.certificate, + certificate_chain=self.tls.certificate_chain, + cipher_suites=self.tls.cipher_suites, + private_key=self.tls.certificate_private_key, + verify_mode=self.tls.verify_mode, ) def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: @@ -220,10 +225,10 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: and command.connection is self.conn ): if self.quic is None: - self.waiting_get_connection_commands.append(command) + self._get_connection_commands.append(command) else: yield from self.child_layer.handle_event( - OpenGetConnectionCompleted( + QuicGetConnectionCompleted( command=command, connection=self.quic, ) @@ -245,7 +250,6 @@ def fail_connection( def initialize_connection(self) -> layer.CommandGenerator[None]: assert self.quic is None - self._handle_event = self.state_ready # (almost) identical to _TLSLayer.start_tls tls_data = QuicTlsData(self.conn, self.context) @@ -253,33 +257,113 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: yield QuicTlsStartClientHook(tls_data) else: yield QuicTlsStartServerHook(tls_data) - if not tls_data.settings: + if tls_data.settings is None: yield from self.fail_connection( "No TLS settings were provided, failing connection.", level="error" ) return assert tls_data.settings is not None + self.tls = tls_data.settings - # create the connection and let the waiters know about it + # create the aioquic connection self.quic = QuicConnection( - configuration=self.build_configuration(tls_data.settings), + configuration=self.build_configuration(), original_destination_connection_id=self.original_destination_connection_id, ) if self._issue_cid: self._issue_cid(self.quic.host_cid) - while self.waiting_get_connection_commands: + self._handle_event = self.state_ready + + # let the waiters know about the available connection + while self._get_connection_commands: assert self.quic is not None assert self.child_layer is not None yield from self.child_layer.handle_event( - OpenGetConnectionCompleted( - command=self.waiting_get_connection_commands.pop(), + QuicGetConnectionCompleted( + command=self._get_connection_commands.pop(), connection=self.quic, ) ) def process_events(self) -> layer.CommandGenerator[None]: assert self.quic is not None - yield from () + assert self.tls is not None + + event = self.quic.next_event() + while event is not None: + if isinstance(event, quic_events.ConnectionIdIssued): + if self._issue_cid is not None: + self._issue_cid(event.connection_id) + + elif isinstance(event, quic_events.ConnectionIdRetired): + if self._retire_cid is not None: + self._retire_cid(event.connection_id) + + elif isinstance(event, quic_events.ConnectionTerminated): + # report as TLS failure if the termination happened before the handshake + if not self.conn.tls_established: + self.conn.error = event.reason_phrase + tls_data = QuicTlsData( + conn=self.conn, context=self.context, settings=self.tls + ) + if self.conn is self.context.client: + yield layers.tls.TlsFailedClientHook(tls_data) + else: + yield layers.tls.TlsFailedServerHook(tls_data) + + # always close the connection + yield from self.fail_connection(event.reason_phrase) + + elif isinstance(event, quic_events.HandshakeCompleted): + # concatenate all peer certificates + all_certs = [] + if self.quic.tls._peer_certificate is not None: + all_certs.append(self.quic.tls._peer_certificate) + if self.quic.tls._peer_certificate_chain is not None: + all_certs.extend(self.quic.tls._peer_certificate_chain) + + # set the connection's TLS properties + self.conn.timestamp_tls_setup = self.loop.time() + self.conn.certificate_list = [ + certs.Cert.from_pyopenssl(x) for x in all_certs + ] + self.conn.alpn = event.alpn_protocol.encode() + self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name + self.conn.tls_version = "QUIC" + + # report the success to addons + tls_data = QuicTlsData( + conn=self.conn, context=self.context, settings=self.tls + ) + if self.conn is self.context.client: + yield layers.tls.TlsEstablishedClientHook(tls_data) + else: + yield layers.tls.TlsEstablishedServerHook(tls_data) + + # forward the event as a QuicConnectionEvent to the child layer + yield from self.event_to_child( + QuicConnectionEvent(connection=self.conn, event=event) + ) + + # handle the next event + event = self.quic.next_event() + + # send all queued datagrams + for data, addr in self.quic.datagrams_to_send(now=self.loop.time()): + yield commands.SendData(connection=self.conn, data=data, remote_addr=addr) + + # ensure the wakeup is set and still correct + timer = self.quic.get_timer() + if timer is None: + self._request_wakeup_command_and_timer = None + else: + if self._request_wakeup_command_and_timer is not None: + _, existing_timer = self._request_wakeup_command_and_timer + if existing_timer == timer: + return + command = commands.RequestWakeup(timer - self.loop.time()) + self._request_wakeup_command_and_timer = (command, timer) + yield command @abstractmethod def start(self) -> layer.CommandGenerator[None]: From 34776c1298f571abc7581054bbac642a9699269b Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Sun, 19 Jun 2022 01:13:56 +0200 Subject: [PATCH 011/529] [quic] use next layer and event filtering --- mitmproxy/proxy/layers/quic.py | 70 +++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 41f4dfaa7d..619c6233d6 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -167,7 +167,7 @@ def initialize_replacement(peer_cid: bytes) -> None: class QuicLayer(layer.Layer): - child_layer: Optional[layer.Layer] + child_layer: layer.Layer conn: connection.Connection loop: asyncio.AbstractEventLoop original_destination_connection_id: Optional[bytes] @@ -182,7 +182,7 @@ def __init__( retire_cid: Optional[Callable[[bytes], None]] = None, ) -> None: super().__init__(context) - self.child_layer = None + self.child_layer = layer.NextLayer(context) self.conn = conn self.loop = asyncio.get_event_loop() self.original_destination_connection_id = None @@ -216,10 +216,10 @@ def build_configuration(self) -> QuicConfiguration: ) def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: - assert self.child_layer is not None - - # answer the child layers request for the connection + # filter commands coming from the child layer for command in self.child_layer.handle_event(event): + + # answer or queue requests for the aioquic connection instanc if ( isinstance(command, QuicGetConnection) and command.connection is self.conn @@ -233,6 +233,19 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: connection=self.quic, ) ) + + # properly close QUIC connections + elif ( + isinstance(command, commands.CloseConnection) + and command.connection is self.conn + ): + if self.conn.connected and self.quic is not None: + self.quic.close() + yield from self.process_events() + self._handle_event = self.state_done + yield command + + # return other commands else: yield command @@ -251,7 +264,7 @@ def fail_connection( def initialize_connection(self) -> layer.CommandGenerator[None]: assert self.quic is None - # (almost) identical to _TLSLayer.start_tls + # query addons to provide the necessary TLS settings tls_data = QuicTlsData(self.conn, self.context) if self.conn is self.context.client: yield QuicTlsStartClientHook(tls_data) @@ -262,7 +275,6 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: "No TLS settings were provided, failing connection.", level="error" ) return - assert tls_data.settings is not None self.tls = tls_data.settings # create the aioquic connection @@ -270,14 +282,13 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: configuration=self.build_configuration(), original_destination_connection_id=self.original_destination_connection_id, ) - if self._issue_cid: + if self._issue_cid is not None: self._issue_cid(self.quic.host_cid) self._handle_event = self.state_ready # let the waiters know about the available connection while self._get_connection_commands: assert self.quic is not None - assert self.child_layer is not None yield from self.child_layer.handle_event( QuicGetConnectionCompleted( command=self._get_connection_commands.pop(), @@ -289,6 +300,7 @@ def process_events(self) -> layer.CommandGenerator[None]: assert self.quic is not None assert self.tls is not None + # handle all buffered aioquic connection events event = self.quic.next_event() while event is not None: if isinstance(event, quic_events.ConnectionIdIssued): @@ -340,6 +352,10 @@ def process_events(self) -> layer.CommandGenerator[None]: else: yield layers.tls.TlsEstablishedServerHook(tls_data) + # perform next layer decisions now + if isinstance(self.child_layer, layer.NextLayer): + yield from self.child_layer._ask() + # forward the event as a QuicConnectionEvent to the child layer yield from self.event_to_child( QuicConnectionEvent(connection=self.conn, event=event) @@ -374,15 +390,43 @@ def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: # start this layer and the child layer yield from self.start() - if self.child_layer is not None: - yield from self.child_layer.handle_event(event) + yield from self.child_layer.handle_event(event) def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None - yield from () + + if isinstance(event, events.DataReceived): + # forward incoming data only to aioquic + if event.connection is self.conn: + self.quic.receive_datagram( + data=event.data, addr=event.remote_addr, now=self.loop.time() + ) + yield from self.process_events() + return + + elif isinstance(event, events.ConnectionClosed): + if event.connection is self.conn: + # connection closed unexpectedly + yield from self.fail_connection( + "Client closed UDP connection.", level="info" + ) + + elif isinstance(event, events.Wakeup): + # make sure we intercept wakeup events for aioquic + if self._request_wakeup_command_and_timer is not None: + command, timer = self._request_wakeup_command_and_timer + if event.command is command: + self._request_wakeup_command_and_timer = None + self.quic.handle_timer(now=max(timer, self.loop.time())) + yield from self.process_events() + return + + # forward other events to the child layer + yield from self.event_to_child(event) def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: - yield from () + # when done, just forward the event + yield from self.child_layer.handle_event(event) _handle_event = state_start From 3aaa2f9b9b9c519291df455e1922c614e285613c Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Sun, 19 Jun 2022 03:42:58 +0200 Subject: [PATCH 012/529] [quic] introduce entry layer --- mitmproxy/addons/proxyserver.py | 2 +- mitmproxy/proxy/layers/__init__.py | 5 +- mitmproxy/proxy/layers/quic.py | 278 +++++++++++++++-------------- 3 files changed, 151 insertions(+), 134 deletions(-) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index d163c293d4..e52390be57 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -502,7 +502,7 @@ def retire_connection_id(handler: ProxyConnectionHandler, cid: bytes) -> None: server_sni=server_sni, connection_id=connection_id, done_callback=cleanup_connection_ids, - layer_factory=lambda handler: layers.ClientQuicLayer( + layer_factory=lambda handler: layers.QuicLayer( context=handler.layer.context, issue_cid=lambda cid: issue_connection_id(handler, cid), retire_cid=lambda cid: retire_connection_id(handler, cid), diff --git a/mitmproxy/proxy/layers/__init__.py b/mitmproxy/proxy/layers/__init__.py index 7746ed9657..ae31304775 100644 --- a/mitmproxy/proxy/layers/__init__.py +++ b/mitmproxy/proxy/layers/__init__.py @@ -1,7 +1,7 @@ from . import modes from .dns import DNSLayer from .http import HttpLayer -from .quic import ClientQuicLayer, ServerQuicLayer +from .quic import QuicLayer from .tcp import TCPLayer from .tls import ClientTLSLayer, ServerTLSLayer from .websocket import WebsocketLayer @@ -10,8 +10,7 @@ "modes", "DNSLayer", "HttpLayer", - "ClientQuicLayer", - "ServerQuicLayer", + "QuicLayer", "TCPLayer", "ClientTLSLayer", "ServerTLSLayer", diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 619c6233d6..dd0f55cd1e 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -166,34 +166,28 @@ def initialize_replacement(peer_cid: bytes) -> None: raise ValueError("No ClientHello returned.") -class QuicLayer(layer.Layer): +class _QuicLayer(layer.Layer): child_layer: layer.Layer conn: connection.Connection - loop: asyncio.AbstractEventLoop - original_destination_connection_id: Optional[bytes] - quic: Optional[QuicConnection] - tls: Optional[QuicTlsSettings] + issue_connection_id_callback: Optional[Callable[[bytes], None]] = None + original_destination_connection_id: Optional[bytes] = None + quic: Optional[QuicConnection] = None + retire_connection_id_callback: Optional[Callable[[bytes], None]] = None + tls: Optional[QuicTlsSettings] = None def __init__( self, context: context.Context, conn: connection.Connection, - issue_cid: Optional[Callable[[bytes], None]] = None, - retire_cid: Optional[Callable[[bytes], None]] = None, ) -> None: super().__init__(context) self.child_layer = layer.NextLayer(context) self.conn = conn - self.loop = asyncio.get_event_loop() - self.original_destination_connection_id = None - self.quic = None - self.tls = None + self._loop = asyncio.get_event_loop() self._get_connection_commands: List[QuicGetConnection] = [] - self._issue_cid = issue_cid self._request_wakeup_command_and_timer: Optional[ Tuple[commands.RequestWakeup, float] ] = None - self._retire_cid = retire_cid def build_configuration(self) -> QuicConfiguration: assert self.tls is not None @@ -282,8 +276,8 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: configuration=self.build_configuration(), original_destination_connection_id=self.original_destination_connection_id, ) - if self._issue_cid is not None: - self._issue_cid(self.quic.host_cid) + if self.issue_connection_id_callback is not None: + self.issue_connection_id_callback(self.quic.host_cid) self._handle_event = self.state_ready # let the waiters know about the available connection @@ -304,12 +298,12 @@ def process_events(self) -> layer.CommandGenerator[None]: event = self.quic.next_event() while event is not None: if isinstance(event, quic_events.ConnectionIdIssued): - if self._issue_cid is not None: - self._issue_cid(event.connection_id) + if self.issue_connection_id_callback is not None: + self.issue_connection_id_callback(event.connection_id) elif isinstance(event, quic_events.ConnectionIdRetired): - if self._retire_cid is not None: - self._retire_cid(event.connection_id) + if self.retire_connection_id_callback is not None: + self.retire_connection_id_callback(event.connection_id) elif isinstance(event, quic_events.ConnectionTerminated): # report as TLS failure if the termination happened before the handshake @@ -335,7 +329,7 @@ def process_events(self) -> layer.CommandGenerator[None]: all_certs.extend(self.quic.tls._peer_certificate_chain) # set the connection's TLS properties - self.conn.timestamp_tls_setup = self.loop.time() + self.conn.timestamp_tls_setup = self._loop.time() self.conn.certificate_list = [ certs.Cert.from_pyopenssl(x) for x in all_certs ] @@ -365,7 +359,7 @@ def process_events(self) -> layer.CommandGenerator[None]: event = self.quic.next_event() # send all queued datagrams - for data, addr in self.quic.datagrams_to_send(now=self.loop.time()): + for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): yield commands.SendData(connection=self.conn, data=data, remote_addr=addr) # ensure the wakeup is set and still correct @@ -377,7 +371,7 @@ def process_events(self) -> layer.CommandGenerator[None]: _, existing_timer = self._request_wakeup_command_and_timer if existing_timer == timer: return - command = commands.RequestWakeup(timer - self.loop.time()) + command = commands.RequestWakeup(timer - self._loop.time()) self._request_wakeup_command_and_timer = (command, timer) yield command @@ -395,31 +389,34 @@ def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None - if isinstance(event, events.DataReceived): - # forward incoming data only to aioquic - if event.connection is self.conn: - self.quic.receive_datagram( - data=event.data, addr=event.remote_addr, now=self.loop.time() - ) - yield from self.process_events() - return + # forward incoming data only to aioquic + if isinstance(event, events.DataReceived) and event.connection is self.conn: + assert event.remote_addr is not None + self.quic.receive_datagram( + data=event.data, addr=event.remote_addr, now=self._loop.time() + ) + yield from self.process_events() + return - elif isinstance(event, events.ConnectionClosed): - if event.connection is self.conn: - # connection closed unexpectedly - yield from self.fail_connection( - "Client closed UDP connection.", level="info" - ) + # check if the connection was closed by peer + elif ( + isinstance(event, events.ConnectionClosed) and event.connection is self.conn + ): + yield from self.fail_connection( + "Client closed UDP connection.", level="info" + ) - elif isinstance(event, events.Wakeup): - # make sure we intercept wakeup events for aioquic - if self._request_wakeup_command_and_timer is not None: - command, timer = self._request_wakeup_command_and_timer - if event.command is command: - self._request_wakeup_command_and_timer = None - self.quic.handle_timer(now=max(timer, self.loop.time())) - yield from self.process_events() - return + # intercept wakeup events for aioquic + elif ( + isinstance(event, events.Wakeup) + and self._request_wakeup_command_and_timer is not None + ): + command, timer = self._request_wakeup_command_and_timer + if event.command is command: + self._request_wakeup_command_and_timer = None + self.quic.handle_timer(now=max(timer, self._loop.time())) + yield from self.process_events() + return # forward other events to the child layer yield from self.event_to_child(event) @@ -431,7 +428,7 @@ def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: _handle_event = state_start -class ServerQuicLayer(QuicLayer): +class ServerQuicLayer(_QuicLayer): """ This layer establishes QUIC for a single server connection. """ @@ -452,29 +449,26 @@ def start(self) -> layer.CommandGenerator[None]: # try to connect yield from self.initialize_connection() if self.quic is not None: - self.quic.connect(addr=self.conn.peername, now=self.loop.time()) + self.quic.connect(addr=self.conn.peername, now=self._loop.time()) yield from self.process_events() -class ClientQuicLayer(QuicLayer): +class ClientQuicLayer(_QuicLayer): """ This layer establishes QUIC on a single client connection. """ - server_layer: Optional[ServerQuicLayer] buffered_packets: Optional[List[Tuple[bytes, connection.Address, float]]] def __init__( self, context: context.Context, - issue_cid: Callable[[bytes], None], - retire_cid: Callable[[bytes], None], + wait_for_upstream: bool, ) -> None: - super().__init__(context, context.client, issue_cid, retire_cid) - self.server_layer = None - self.buffered_packets = None + super().__init__(context, context.client) + self.buffered_packets = [] if wait_for_upstream else None - def start_client_connection(self) -> layer.CommandGenerator[None]: + def initialize_connection_and_flush_buffer(self) -> layer.CommandGenerator[None]: assert self.buffered_packets is not None yield from self.initialize_connection() @@ -488,98 +482,122 @@ def start_client_connection(self) -> layer.CommandGenerator[None]: yield from self.process_events() def start(self) -> layer.CommandGenerator[None]: - self._handle_event = self.state_wait_for_client_hello - yield from () + if self.buffered_packets is None: + yield from self.initialize_connection() + else: + self._handle_event = self.state_wait_for_upstream - def state_wait_for_client_hello( + def state_wait_for_upstream( self, event: events.Event ) -> layer.CommandGenerator[None]: - assert isinstance(event, events.ConnectionEvent) - assert event.connection is self.conn + assert self.buffered_packets is not None - if isinstance(event, events.DataReceived): + # buffer incoming packets until the upstream handshake completed + if isinstance(event, events.DataReceived) and event.connection is self.conn: assert event.remote_addr is not None + self.buffered_packets.append( + (event.data, event.remote_addr, self._loop.time()) + ) + return - # extract the client hello - try: - ( - client_hello, - self.original_destination_connection_id, - ) = pull_client_hello_and_connection_id(event.data) - except ValueError as e: + # watch for closed connections on both legs + elif isinstance(event, events.ConnectionClosed): + if event.connection is self.conn: yield from self.fail_connection( - f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})" + "Client closed UDP connection before upstream server handshake completed.", + level="info", ) - else: - self.conn.sni = client_hello.sni - self.conn.alpn_offers = client_hello.alpn_protocols + elif event.connection is self.context.server: + yield commands.Log( + f"Unable to establish QUIC connection with server ({self.context.server.error or 'Connection closed.'}). " + f"Trying to establish QUIC with client anyway." + ) + yield from self.initialize_connection_and_flush_buffer() - # check with addons what we shall do - hook_data = ClientHelloData(self.context, client_hello) - yield layers.tls.TlsClienthelloHook(hook_data) + # continue if upstream completed the handshake + elif ( + isinstance(event, QuicConnectionEvent) + and event.connection is self.context.server + and isinstance(event.event, quic_events.HandshakeCompleted) + ): + yield from self.initialize_connection_and_flush_buffer() - if hook_data.ignore_connection: - # simply relay everything (including the client hello) - relay_layer = layers.TCPLayer(self.context, ignore=True) - self._handle_event = relay_layer.handle_event - yield from relay_layer.handle_event(events.Start()) - yield from relay_layer.handle_event(event) + # forward other events to the child layer + yield from self.event_to_child(event) - else: - # buffer the client hello - self.buffered_packets = [ - (event.data, event.remote_addr, self.loop.time()) - ] - - # contact the upstream server first if so desired - if hook_data.establish_server_tls_first: - self.server_layer = ServerQuicLayer(self.context) - self._handle_event = self.state_wait_for_upstream_server - yield from self.state_wait_for_upstream_server(events.Start()) - else: - yield from self.start_client_connection() - elif isinstance(event, events.ConnectionClosed): - # this is odd since this layer should only be created if there is a packet - self._handle_event = self.state_done +class QuicLayer(layer.Layer): + """ + Entry layer for QUIC proxy server. + """ - else: - raise AssertionError(f"Unexpected event: {event}") + def __init__( + self, + context: context.Context, + issue_cid: Callable[[bytes], None], + retire_cid: Callable[[bytes], None], + ) -> None: + super().__init__(context) + self._issue_cid = issue_cid + self._retire_cid = retire_cid - def state_wait_for_upstream_server( - self, event: events.Event - ) -> layer.CommandGenerator[None]: - assert self.buffered_packets is not None - assert self.server_layer is not None + def build_client_layer( + self, connection_id: bytes, wait_for_upstream: bool + ) -> ClientQuicLayer: + layer = ClientQuicLayer( + context=self.context, wait_for_upstream=wait_for_upstream + ) + layer.original_destination_connection_id = connection_id + layer.issue_connection_id_callback = self._issue_cid + layer.retire_connection_id_callback = self._retire_cid + return layer + + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + # only handle the first packet from the client + if ( + not isinstance(event, events.DataReceived) + or event.connection is not self.context.client + ): + return - # filter DataReceived and ConnectionClosed relating to the client connection - if isinstance(event, events.ConnectionEvent): - if event.connection is self.conn: - if isinstance(event, events.DataReceived): - assert event.remote_addr is not None + # extract the client hello + try: + client_hello, connection_id = pull_client_hello_and_connection_id( + event.data + ) + except ValueError as e: + yield commands.Log( + f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})" + ) + yield commands.CloseConnection(self.context.client) + return - # still waiting for the server, buffer the data - self.buffered_packets.append( - (event.data, event.remote_addr, self.loop.time()) - ) + # copy the information + self.context.client.sni = client_hello.sni + self.context.client.alpn_offers = client_hello.alpn_protocols - elif isinstance(event, events.ConnectionClosed): - # close the upstream connection as well and be done - self._handle_event = self.state_done - yield from self.server_layer.fail_connection( - "Client closed the connection." - ) + # check with addons what we shall do + next_layer: layer.Layer + hook_data = ClientHelloData(self.context, client_hello) + yield layers.tls.TlsClienthelloHook(hook_data) - else: - raise AssertionError(f"Unexpected event: {event}") + # simply relay everything + if hook_data.ignore_connection: + next_layer = layers.TCPLayer(self.context, ignore=True) - # forward the event and check it's results - yield from self.server_layer.handle_event(event) - if not self.context.server.connected: - yield commands.Log( - f"Unable to establish QUIC connection with server ({self.context.server.error or 'Connection closed.'}). " - f"Trying to establish QUIC with client anyway." + # contact the upstream server first + elif hook_data.establish_server_tls_first: + next_layer = ServerQuicLayer(self.context) + next_layer.child_layer = self.build_client_layer( + connection_id, wait_for_upstream=True ) - yield from self.start_client_connection() - elif self.context.server.tls_established: - yield from self.start_client_connection() + + # perform the client handshake immediately + else: + next_layer = self.build_client_layer(connection_id, wait_for_upstream=False) + + # replace this layer and start the next one + self.handle_event = next_layer.handle_event + self._handle_event = next_layer._handle_event + yield from next_layer.handle_event(events.Start()) + yield from next_layer.handle_event(event) From ac21eac71ebfed537933a7842c84557ec4a84b9e Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 20 Jun 2022 04:25:42 +0200 Subject: [PATCH 013/529] [quic] expose transmit improve connection shutdown --- mitmproxy/proxy/layers/quic.py | 99 ++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 40 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index dd0f55cd1e..1fc14282aa 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -7,7 +7,7 @@ from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.connection import QuicConnection, QuicConnectionError +from aioquic.quic.connection import QuicConnection, QuicConnectionError, QuicErrorCode from aioquic.tls import CipherSuite, HandshakeType from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header from cryptography import x509 @@ -16,7 +16,6 @@ from mitmproxy.net import tls from mitmproxy.proxy import commands, context, events, layer, layers from mitmproxy.tls import ClientHello, ClientHelloData, TlsData -from mitmproxy.utils import human @dataclass @@ -90,6 +89,11 @@ class QuicGetConnection(commands.ConnectionCommand): # -> QuicConnection blocking = True +@dataclass +class QuicTransmit: + connection: QuicConnection + + @dataclass(repr=False) class QuicGetConnectionCompleted(events.CommandCompleted): command: QuicGetConnection @@ -228,33 +232,26 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: ) ) + # transmit buffered data and re-arm timer + elif isinstance(command, QuicTransmit) and command.connection is self.quic: + yield from self.transmit() + # properly close QUIC connections elif ( isinstance(command, commands.CloseConnection) and command.connection is self.conn ): + reason = "CloseConnection command received." if self.conn.connected and self.quic is not None: - self.quic.close() + self.quic.close(reason_phrase=reason) yield from self.process_events() - self._handle_event = self.state_done - yield command + else: + yield from self.shutdown_connection(reason, level="info") # return other commands else: yield command - def fail_connection( - self, - reason: str, - level: Literal["error", "warn", "info", "alert", "debug"] = "warn", - ) -> layer.CommandGenerator[None]: - yield commands.Log( - message=f"Failing connection {self.conn}: {reason}", level=level - ) - if self.conn.connected: - yield commands.CloseConnection(self.conn) - self._handle_event = self.state_done - def initialize_connection(self) -> layer.CommandGenerator[None]: assert self.quic is None @@ -265,7 +262,7 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: else: yield QuicTlsStartServerHook(tls_data) if tls_data.settings is None: - yield from self.fail_connection( + yield from self.shutdown_connection( "No TLS settings were provided, failing connection.", level="error" ) return @@ -318,7 +315,12 @@ def process_events(self) -> layer.CommandGenerator[None]: yield layers.tls.TlsFailedServerHook(tls_data) # always close the connection - yield from self.fail_connection(event.reason_phrase) + yield from self.shutdown_connection( + event.reason_phrase, + level=( + "info" if event.error_code is QuicErrorCode.NO_ERROR else "warn" + ), + ) elif isinstance(event, quic_events.HandshakeCompleted): # concatenate all peer certificates @@ -358,22 +360,20 @@ def process_events(self) -> layer.CommandGenerator[None]: # handle the next event event = self.quic.next_event() - # send all queued datagrams - for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): - yield commands.SendData(connection=self.conn, data=data, remote_addr=addr) + # transmit buffered data and re-arm timer + yield from self.transmit() - # ensure the wakeup is set and still correct - timer = self.quic.get_timer() - if timer is None: - self._request_wakeup_command_and_timer = None - else: - if self._request_wakeup_command_and_timer is not None: - _, existing_timer = self._request_wakeup_command_and_timer - if existing_timer == timer: - return - command = commands.RequestWakeup(timer - self._loop.time()) - self._request_wakeup_command_and_timer = (command, timer) - yield command + def shutdown_connection( + self, + reason: str, + level: Literal["error", "warn", "info", "alert", "debug"], + ) -> layer.CommandGenerator[None]: + yield commands.Log( + message=f"Connection {self.conn} closed: {reason}", level=level + ) + if self.conn.connected: + yield commands.CloseConnection(self.conn) + self._handle_event = self.state_done @abstractmethod def start(self) -> layer.CommandGenerator[None]: @@ -402,8 +402,8 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): - yield from self.fail_connection( - "Client closed UDP connection.", level="info" + yield from self.shutdown_connection( + "Peer UDP connection timed out.", level="info" ) # intercept wakeup events for aioquic @@ -425,6 +425,24 @@ def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: # when done, just forward the event yield from self.child_layer.handle_event(event) + def transmit(self) -> layer.CommandGenerator[None]: + # send all queued datagrams + for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): + yield commands.SendData(connection=self.conn, data=data, remote_addr=addr) + + # ensure the wakeup is set and still correct + timer = self.quic.get_timer() + if timer is None: + self._request_wakeup_command_and_timer = None + else: + if self._request_wakeup_command_and_timer is not None: + _, existing_timer = self._request_wakeup_command_and_timer + if existing_timer == timer: + return + command = commands.RequestWakeup(timer - self._loop.time()) + self._request_wakeup_command_and_timer = (command, timer) + yield command + _handle_event = state_start @@ -441,8 +459,9 @@ def start(self) -> layer.CommandGenerator[None]: if not self.conn.connected: err = yield commands.OpenConnection(self.conn) if err is not None: - self.fail_connection( - f"Failed to establish connection to {human.format_address(self.conn)}: {err}" + self.shutdown_connection( + f"Failed to connect: {err}", + level="warn", ) return @@ -503,8 +522,8 @@ def state_wait_for_upstream( # watch for closed connections on both legs elif isinstance(event, events.ConnectionClosed): if event.connection is self.conn: - yield from self.fail_connection( - "Client closed UDP connection before upstream server handshake completed.", + yield from self.shutdown_connection( + "Client UDP connection timeout out before upstream server handshake completed.", level="info", ) elif event.connection is self.context.server: From f129a1e5a30fa3534028b667af90c2f5ece517d1 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 20 Jun 2022 04:26:36 +0200 Subject: [PATCH 014/529] [quic] generalize h2 and h3 headers --- mitmproxy/http.py | 4 ++ mitmproxy/proxy/layers/http/_base.py | 71 ++++++++++++++++++++++++++- mitmproxy/proxy/layers/http/_http2.py | 66 ++++--------------------- 3 files changed, 84 insertions(+), 57 deletions(-) diff --git a/mitmproxy/http.py b/mitmproxy/http.py index e88242530c..394cddbdd2 100644 --- a/mitmproxy/http.py +++ b/mitmproxy/http.py @@ -285,6 +285,10 @@ def is_http11(self) -> bool: def is_http2(self) -> bool: return self.data.http_version == b"HTTP/2.0" + @property + def is_http3(self) -> bool: + return self.data.http_version == b"HTTP/3" + @property def headers(self) -> Headers: """ diff --git a/mitmproxy/proxy/layers/http/_base.py b/mitmproxy/proxy/layers/http/_base.py index b5f66d46ba..a0f82a2349 100644 --- a/mitmproxy/proxy/layers/http/_base.py +++ b/mitmproxy/proxy/layers/http/_base.py @@ -2,10 +2,13 @@ import textwrap from dataclasses import dataclass -from mitmproxy import http +import h2.utilities + +from mitmproxy import ctx, http from mitmproxy.connection import Connection from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.context import Context +from mitmproxy.proxy.layers.http import RequestHeaders, ResponseHeaders StreamId = int @@ -57,3 +60,69 @@ def format_error(status_code: int, message: str) -> bytes: .strip() .encode("utf8", "replace") ) + + +def get_request_headers( + event: RequestHeaders, +) -> layer.CommandGenerator[list[tuple[bytes, bytes]]]: + pseudo_headers = [ + (b":method", event.request.data.method), + (b":scheme", event.request.data.scheme), + (b":path", event.request.data.path), + ] + if event.request.authority: + pseudo_headers.append((b":authority", event.request.data.authority)) + + if event.request.is_http2 or event.request.is_http3: + hdrs = list(event.request.headers.fields) + if ctx.options.normalize_outbound_headers: + yield from normalize_h2_or_h3_headers(hdrs) + else: + headers = event.request.headers + if not event.request.authority and "host" in headers: + headers = headers.copy() + pseudo_headers.append((b":authority", headers.pop(b"host"))) + hdrs = normalize_h1_headers(list(headers.fields), True) + + return pseudo_headers + hdrs + + +def get_response_headers( + event: ResponseHeaders, +) -> layer.CommandGenerator[list[tuple[bytes, bytes]]]: + headers = [ + (b":status", b"%d" % event.response.status_code), + *event.response.headers.fields, + ] + if event.response.is_http2 or event.request.is_http3: + if ctx.options.normalize_outbound_headers: + yield from normalize_h2_or_h3_headers(headers) + else: + headers = normalize_h1_headers(headers, False) + return headers + + +def normalize_h1_headers( + headers: list[tuple[bytes, bytes]], is_client: bool +) -> list[tuple[bytes, bytes]]: + # HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length), + # which isn't valid HTTP/2 or HTTP/3. As such we normalize. + headers = h2.utilities.normalize_outbound_headers( + headers, + h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False), + ) + # make sure that this is not just an iterator but an iterable, + # otherwise hyper-h2 will silently drop headers. + headers = list(headers) + return headers + + +def normalize_h2_or_h3_headers( + headers: list[tuple[bytes, bytes]] +) -> layer.CommandGenerator[None]: + for i in range(len(headers)): + if not headers[i][0].islower(): + yield commands.Log( + f"Lowercased {repr(headers[i][0]).lstrip('b')} header as uppercase is not allowed with HTTP/2 nor HTTP/3." + ) + headers[i] = (headers[i][0].lower(), headers[i][1]) diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index 3e34c325da..bbb1b19f81 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -29,7 +29,14 @@ ResponseTrailers, ResponseProtocolError, ) -from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error +from ._base import ( + HttpConnection, + HttpEvent, + ReceiveHttp, + format_error, + get_request_headers, + get_response_headers, +) from ._http_h2 import BufferedH2Connection, H2ConnectionLogger from ...commands import CloseConnection, Log, SendData, RequestWakeup from ...context import Context @@ -289,30 +296,6 @@ def done(self, _) -> CommandGenerator[None]: yield from () -def normalize_h1_headers( - headers: list[tuple[bytes, bytes]], is_client: bool -) -> list[tuple[bytes, bytes]]: - # HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length), - # which isn't valid HTTP/2. As such we normalize. - headers = h2.utilities.normalize_outbound_headers( - headers, - h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False), - ) - # make sure that this is not just an iterator but an iterable, - # otherwise hyper-h2 will silently drop headers. - headers = list(headers) - return headers - - -def normalize_h2_headers(headers: list[tuple[bytes, bytes]]) -> CommandGenerator[None]: - for i in range(len(headers)): - if not headers[i][0].islower(): - yield Log( - f"Lowercased {repr(headers[i][0]).lstrip('b')} header as uppercase is not allowed with HTTP/2." - ) - headers[i] = (headers[i][0].lower(), headers[i][1]) - - class Http2Server(Http2Connection): h2_conf = h2.config.H2Configuration( **Http2Connection.h2_conf_defaults, @@ -330,19 +313,9 @@ def __init__(self, context: Context): def _handle_event(self, event: Event) -> CommandGenerator[None]: if isinstance(event, ResponseHeaders): if self.is_open_for_us(event.stream_id): - headers = [ - (b":status", b"%d" % event.response.status_code), - *event.response.headers.fields, - ] - if event.response.is_http2: - if self.context.options.normalize_outbound_headers: - yield from normalize_h2_headers(headers) - else: - headers = normalize_h1_headers(headers, False) - self.h2_conn.send_headers( event.stream_id, - headers, + headers=(yield from get_response_headers(event)), end_stream=event.end_stream, ) yield SendData(self.conn, self.h2_conn.data_to_send()) @@ -485,28 +458,9 @@ def _handle_event2(self, event: Event) -> CommandGenerator[None]: yield RequestWakeup(self.context.options.http2_ping_keepalive) yield from super()._handle_event(event) elif isinstance(event, RequestHeaders): - pseudo_headers = [ - (b":method", event.request.data.method), - (b":scheme", event.request.data.scheme), - (b":path", event.request.data.path), - ] - if event.request.authority: - pseudo_headers.append((b":authority", event.request.data.authority)) - - if event.request.is_http2: - hdrs = list(event.request.headers.fields) - if self.context.options.normalize_outbound_headers: - yield from normalize_h2_headers(hdrs) - else: - headers = event.request.headers - if not event.request.authority and "host" in headers: - headers = headers.copy() - pseudo_headers.append((b":authority", headers.pop(b"host"))) - hdrs = normalize_h1_headers(list(headers.fields), True) - self.h2_conn.send_headers( event.stream_id, - pseudo_headers + hdrs, + headers=(yield from get_request_headers(event)), end_stream=event.end_stream, ) self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS From 71645ddc8d4236bb23de8f35793a25929dad1b39 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 20 Jun 2022 04:27:42 +0200 Subject: [PATCH 015/529] [quic] first work on H3 connections --- mitmproxy/proxy/layers/http/_http3.py | 168 ++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 13 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 1f57add6eb..3e0aea5646 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -1,29 +1,171 @@ -from aioquic.h3.connection import H3Connection -from mitmproxy.connection import Connection -from ._base import HttpConnection -from ..quic import QuicLayer -from ...context import Context +from abc import abstractmethod +from typing import Optional, Union + +from aioquic.quic.connection import QuicConnection +from aioquic.h3.connection import ( + H3Connection, + FrameUnexpected, + ErrorCode as H3ErrorCode, +) +from aioquic.h3 import events as h3_events + +from mitmproxy import version +from mitmproxy.net.http import status_codes +from mitmproxy.proxy import context, events, layer +from mitmproxy.proxy.layers.quic import ( + QuicConnectionEvent, + QuicGetConnection, + QuicTransmit, +) + +from . import ( + RequestData, + RequestEndOfMessage, + RequestHeaders, + RequestProtocolError, + ResponseData, + ResponseEndOfMessage, + ResponseHeaders, + RequestTrailers, + ResponseTrailers, + ResponseProtocolError, +) +from ._base import ( + HttpConnection, + HttpEvent, + format_error, + get_request_headers, + get_response_headers, +) class Http3Connection(HttpConnection): - h3_conn: H3Connection + quic: Optional[QuicConnection] = None + h3_conn: Optional[H3Connection] = None + + EventData: type[Union[RequestData, ResponseData]] + ReceiveData: type[Union[RequestData, ResponseData]] + EventEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]] + ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]] + EventHeaders: type[Union[RequestHeaders, ResponseHeaders]] + ReceiveHeaders: type[Union[RequestHeaders, ResponseHeaders]] + EventProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] + ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] + EventTrailers: type[Union[RequestTrailers, ResponseTrailers]] + ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]] + + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + if isinstance(event, events.Start): + self.quic = yield QuicGetConnection() + assert isinstance(self.quic, H3Connection) + self.h3_conn = H3Connection(self.quic, enable_webtransport=False) + + else: + assert self.quic is not None + assert self.h3_conn is not None + + if isinstance(event, HttpEvent): + try: + + if isinstance(event, self.EventData): + self.h3_conn.send_data( + stream_id=event.stream_id, data=event.data, end_stream=False + ) + elif isinstance(event, self.EventHeaders): + get_headers = ( + get_request_headers + if isinstance(event, RequestHeaders) + else get_response_headers + ) + self.h3_conn.send_headers( + stream_id=event.stream_id, + headers=(yield from get_headers(event)), + end_stream=event.end_stream, + ) + elif isinstance(event, self.EventTrailers): + trailers = [*event.trailers.fields] + self.h3_conn.send_headers( + stream_id=event.stream_id, headers=trailers, end_stream=True + ) + elif isinstance(event, self.EventEndOfMessage): + self.h3_conn.send_data( + stream_id=event.stream_id, data=b"", end_stream=True + ) + elif isinstance(event, self.EventProtocolError): + self.protocol_error(event) + else: + raise AssertionError(f"Unexpected event: {event}") + + except FrameUnexpected: + # Http2Connection also ignores events that violate the current stream state + return - def __init__(self, context: Context, conn: Connection): - super().__init__(context, conn) - quic = context.layers[0] - assert isinstance(quic, QuicLayer) - self.h3_conn = H3Connection(quic.conn) + # transmit buffered data and re-arm timer + yield QuicTransmit(self.quic) + + elif isinstance(event, QuicConnectionEvent): + for h3_event in self.h3_conn.handle_event(event.event): + if isinstance(h3_event, h3_events.DataReceived): + pass + + elif isinstance(h3_event, h3_events.HeadersReceived): + pass + + else: + pass + + @abstractmethod + def protocol_error( + self, event: Union[RequestProtocolError, ResponseProtocolError] + ) -> None: + yield from () # pragma: no cover class Http3Server(Http3Connection): - def __init__(self, context: Context): + def __init__(self, context: context.Context): super().__init__(context, context.client) + def protocol_error( + self, event: Union[RequestProtocolError, ResponseProtocolError] + ) -> None: + assert self.h3_conn is not None + assert isinstance(event, ResponseProtocolError) + + # same as HTTP/2 + code = event.code + if code != status_codes.CLIENT_CLOSED_REQUEST: + code = status_codes.INTERNAL_SERVER_ERROR + self.h3_conn.send_headers( + stream_id=event.stream_id, + headers=[ + (b":status", b"%d" % code), + (b"server", version.MITMPROXY.encode()), + (b"content-type", b"text/html"), + ], + ) + self.h3_conn.send_data( + stream_id=event.stream_id, + data=format_error(code, event.message), + end_stream=True, + ) + class Http3Client(Http3Connection): - def __init__(self, context: Context): + def __init__(self, context: context.Context): super().__init__(context, context.server) + def protocol_error( + self, event: Union[RequestProtocolError, ResponseProtocolError] + ) -> None: + assert isinstance(event, RequestProtocolError) + assert self.quic is not None + + # same as HTTP/2 + code = event.code + if code != H3ErrorCode.H3_REQUEST_CANCELLED: + code = H3ErrorCode.H3_INTERNAL_ERROR + self.quic.reset_stream(stream_id=event.stream_id, error_code=code) + __all__ = [ "Http3Client", From 5454a71bb69b0836c135074725fe60ee7cd044fc Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 20 Jun 2022 17:05:32 +0200 Subject: [PATCH 016/529] [quic] more work on H3 --- mitmproxy/proxy/layers/http/_base.py | 71 +------------ mitmproxy/proxy/layers/http/_http2.py | 83 +++++++++++++-- mitmproxy/proxy/layers/http/_http3.py | 143 ++++++++++++++++++++------ mitmproxy/proxy/layers/quic.py | 4 +- 4 files changed, 189 insertions(+), 112 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_base.py b/mitmproxy/proxy/layers/http/_base.py index a0f82a2349..b5f66d46ba 100644 --- a/mitmproxy/proxy/layers/http/_base.py +++ b/mitmproxy/proxy/layers/http/_base.py @@ -2,13 +2,10 @@ import textwrap from dataclasses import dataclass -import h2.utilities - -from mitmproxy import ctx, http +from mitmproxy import http from mitmproxy.connection import Connection from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.context import Context -from mitmproxy.proxy.layers.http import RequestHeaders, ResponseHeaders StreamId = int @@ -60,69 +57,3 @@ def format_error(status_code: int, message: str) -> bytes: .strip() .encode("utf8", "replace") ) - - -def get_request_headers( - event: RequestHeaders, -) -> layer.CommandGenerator[list[tuple[bytes, bytes]]]: - pseudo_headers = [ - (b":method", event.request.data.method), - (b":scheme", event.request.data.scheme), - (b":path", event.request.data.path), - ] - if event.request.authority: - pseudo_headers.append((b":authority", event.request.data.authority)) - - if event.request.is_http2 or event.request.is_http3: - hdrs = list(event.request.headers.fields) - if ctx.options.normalize_outbound_headers: - yield from normalize_h2_or_h3_headers(hdrs) - else: - headers = event.request.headers - if not event.request.authority and "host" in headers: - headers = headers.copy() - pseudo_headers.append((b":authority", headers.pop(b"host"))) - hdrs = normalize_h1_headers(list(headers.fields), True) - - return pseudo_headers + hdrs - - -def get_response_headers( - event: ResponseHeaders, -) -> layer.CommandGenerator[list[tuple[bytes, bytes]]]: - headers = [ - (b":status", b"%d" % event.response.status_code), - *event.response.headers.fields, - ] - if event.response.is_http2 or event.request.is_http3: - if ctx.options.normalize_outbound_headers: - yield from normalize_h2_or_h3_headers(headers) - else: - headers = normalize_h1_headers(headers, False) - return headers - - -def normalize_h1_headers( - headers: list[tuple[bytes, bytes]], is_client: bool -) -> list[tuple[bytes, bytes]]: - # HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length), - # which isn't valid HTTP/2 or HTTP/3. As such we normalize. - headers = h2.utilities.normalize_outbound_headers( - headers, - h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False), - ) - # make sure that this is not just an iterator but an iterable, - # otherwise hyper-h2 will silently drop headers. - headers = list(headers) - return headers - - -def normalize_h2_or_h3_headers( - headers: list[tuple[bytes, bytes]] -) -> layer.CommandGenerator[None]: - for i in range(len(headers)): - if not headers[i][0].islower(): - yield commands.Log( - f"Lowercased {repr(headers[i][0]).lstrip('b')} header as uppercase is not allowed with HTTP/2 nor HTTP/3." - ) - headers[i] = (headers[i][0].lower(), headers[i][1]) diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index bbb1b19f81..ec4aed6b72 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -13,7 +13,7 @@ import h2.stream import h2.utilities -from mitmproxy import http, version +from mitmproxy import ctx, http, version from mitmproxy.connection import Connection from mitmproxy.net.http import status_codes, url from mitmproxy.utils import human @@ -29,14 +29,7 @@ ResponseTrailers, ResponseProtocolError, ) -from ._base import ( - HttpConnection, - HttpEvent, - ReceiveHttp, - format_error, - get_request_headers, - get_response_headers, -) +from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error from ._http_h2 import BufferedH2Connection, H2ConnectionLogger from ...commands import CloseConnection, Log, SendData, RequestWakeup from ...context import Context @@ -296,6 +289,70 @@ def done(self, _) -> CommandGenerator[None]: yield from () +def normalize_h1_headers( + headers: list[tuple[bytes, bytes]], is_client: bool +) -> list[tuple[bytes, bytes]]: + # HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length), + # which isn't valid HTTP/2. As such we normalize. + headers = h2.utilities.normalize_outbound_headers( + headers, + h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False), + ) + # make sure that this is not just an iterator but an iterable, + # otherwise hyper-h2 will silently drop headers. + headers = list(headers) + return headers + + +def normalize_h2_headers(headers: list[tuple[bytes, bytes]]) -> CommandGenerator[None]: + for i in range(len(headers)): + if not headers[i][0].islower(): + yield Log( + f"Lowercased {repr(headers[i][0]).lstrip('b')} header as uppercase is not allowed with HTTP/2." + ) + headers[i] = (headers[i][0].lower(), headers[i][1]) + + +def format_h2_request_headers( + event: RequestHeaders, +) -> CommandGenerator[list[tuple[bytes, bytes]]]: + pseudo_headers = [ + (b":method", event.request.data.method), + (b":scheme", event.request.data.scheme), + (b":path", event.request.data.path), + ] + if event.request.authority: + pseudo_headers.append((b":authority", event.request.data.authority)) + + if event.request.is_http2 or event.request.is_http3: + hdrs = list(event.request.headers.fields) + if ctx.options.normalize_outbound_headers: + yield from normalize_h2_headers(hdrs) + else: + headers = event.request.headers + if not event.request.authority and "host" in headers: + headers = headers.copy() + pseudo_headers.append((b":authority", headers.pop(b"host"))) + hdrs = normalize_h1_headers(list(headers.fields), True) + + return pseudo_headers + hdrs + + +def format_h2_response_headers( + event: ResponseHeaders, +) -> CommandGenerator[list[tuple[bytes, bytes]]]: + headers = [ + (b":status", b"%d" % event.response.status_code), + *event.response.headers.fields, + ] + if event.response.is_http2: + if ctx.options.normalize_outbound_headers: + yield from normalize_h2_headers(headers) + else: + headers = normalize_h1_headers(headers, False) + return headers + + class Http2Server(Http2Connection): h2_conf = h2.config.H2Configuration( **Http2Connection.h2_conf_defaults, @@ -315,7 +372,7 @@ def _handle_event(self, event: Event) -> CommandGenerator[None]: if self.is_open_for_us(event.stream_id): self.h2_conn.send_headers( event.stream_id, - headers=(yield from get_response_headers(event)), + headers=(yield from format_h2_response_headers(event)), end_stream=event.end_stream, ) yield SendData(self.conn, self.h2_conn.data_to_send()) @@ -460,7 +517,7 @@ def _handle_event2(self, event: Event) -> CommandGenerator[None]: elif isinstance(event, RequestHeaders): self.h2_conn.send_headers( event.stream_id, - headers=(yield from get_request_headers(event)), + headers=(yield from format_h2_request_headers(event)), end_stream=event.end_stream, ) self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS @@ -596,6 +653,10 @@ def parse_h2_response_headers( __all__ = [ + "format_h2_request_headers", + "format_h2_response_headers", + "parse_h2_request_headers", + "parse_h2_response_headers", "Http2Client", "Http2Server", ] diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 3e0aea5646..ec600296e5 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -1,4 +1,5 @@ from abc import abstractmethod +import time from typing import Optional, Union from aioquic.quic.connection import QuicConnection @@ -6,12 +7,13 @@ H3Connection, FrameUnexpected, ErrorCode as H3ErrorCode, + HeadersState as H3HeadersState, ) from aioquic.h3 import events as h3_events -from mitmproxy import version +from mitmproxy import http, version from mitmproxy.net.http import status_codes -from mitmproxy.proxy import context, events, layer +from mitmproxy.proxy import commands, context, events, layer from mitmproxy.proxy.layers.quic import ( QuicConnectionEvent, QuicGetConnection, @@ -33,9 +35,14 @@ from ._base import ( HttpConnection, HttpEvent, + ReceiveHttp, format_error, - get_request_headers, - get_response_headers, +) +from ._http2 import ( + format_h2_request_headers, + format_h2_response_headers, + parse_h2_request_headers, + parse_h2_response_headers, ) @@ -43,22 +50,14 @@ class Http3Connection(HttpConnection): quic: Optional[QuicConnection] = None h3_conn: Optional[H3Connection] = None - EventData: type[Union[RequestData, ResponseData]] - ReceiveData: type[Union[RequestData, ResponseData]] - EventEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]] - ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]] - EventHeaders: type[Union[RequestHeaders, ResponseHeaders]] - ReceiveHeaders: type[Union[RequestHeaders, ResponseHeaders]] - EventProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] - ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] - EventTrailers: type[Union[RequestTrailers, ResponseTrailers]] ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]] def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, events.Start): - self.quic = yield QuicGetConnection() - assert isinstance(self.quic, H3Connection) - self.h3_conn = H3Connection(self.quic, enable_webtransport=False) + quic = yield QuicGetConnection(self.conn) + assert isinstance(quic, QuicConnection) + self.quic = quic + self.h3_conn = H3Connection(quic, enable_webtransport=False) else: assert self.quic is not None @@ -67,31 +66,34 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, HttpEvent): try: - if isinstance(event, self.EventData): + if isinstance(event, (RequestData, ResponseData)): self.h3_conn.send_data( stream_id=event.stream_id, data=event.data, end_stream=False ) - elif isinstance(event, self.EventHeaders): - get_headers = ( - get_request_headers - if isinstance(event, RequestHeaders) - else get_response_headers - ) + elif isinstance(event, (RequestHeaders, ResponseHeaders)): self.h3_conn.send_headers( stream_id=event.stream_id, - headers=(yield from get_headers(event)), + headers=( + yield from ( + format_h2_request_headers(event) + if isinstance(event, RequestHeaders) + else format_h2_response_headers(event) + ) + ), end_stream=event.end_stream, ) - elif isinstance(event, self.EventTrailers): + elif isinstance(event, (RequestTrailers, ResponseTrailers)): trailers = [*event.trailers.fields] self.h3_conn.send_headers( stream_id=event.stream_id, headers=trailers, end_stream=True ) - elif isinstance(event, self.EventEndOfMessage): + elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)): self.h3_conn.send_data( stream_id=event.stream_id, data=b"", end_stream=True ) - elif isinstance(event, self.EventProtocolError): + elif isinstance( + event, (RequestProtocolError, ResponseProtocolError) + ): self.protocol_error(event) else: raise AssertionError(f"Unexpected event: {event}") @@ -108,20 +110,49 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(h3_event, h3_events.DataReceived): pass + # handle headers and trailers elif isinstance(h3_event, h3_events.HeadersReceived): - pass + if ( + self.h3_conn._stream[h3_event.stream_id].headers_recv_state + is H3HeadersState.AFTER_TRAILERS + ): + yield ReceiveHttp( + self.ReceiveTrailers( + stream_id=h3_event.stream_id, + trailers=http.Headers(h3_event.headers), + ) + ) + else: + try: + receive_event = self.headers_received(h3_event) + except ValueError as e: + # TODO + pass + else: + yield ReceiveHttp(receive_event) + # we don't support push, web transport, etc. else: - pass + yield commands.Log( + f"Ignored unsupported H3 event: {h3_event!r}" + ) @abstractmethod def protocol_error( self, event: Union[RequestProtocolError, ResponseProtocolError] ) -> None: - yield from () # pragma: no cover + pass # pragma: no cover + + @abstractmethod + def headers_received( + self, event: h3_events.HeadersReceived + ) -> Union[RequestHeaders, ResponseHeaders]: + pass # pragma: no cover class Http3Server(Http3Connection): + ReceiveTrailers = RequestTrailers + def __init__(self, context: context.Context): super().__init__(context, context.client) @@ -149,8 +180,41 @@ def protocol_error( end_stream=True, ) + def headers_received( + self, event: h3_events.HeadersReceived + ) -> Union[RequestHeaders, ResponseHeaders]: + # same as HTTP/2 + ( + host, + port, + method, + scheme, + authority, + path, + headers, + ) = parse_h2_request_headers(event) + request = http.Request( + host=host, + port=port, + method=method, + scheme=scheme, + authority=authority, + path=path, + http_version=b"HTTP/3", + headers=headers, + content=None, + trailers=None, + timestamp_start=time.time(), + timestamp_end=None, + ) + return RequestHeaders( + stream_id=event.stream_id, request=request, end_stream=event.stream_ended + ) + class Http3Client(Http3Connection): + ReceiveTrailers = ResponseTrailers + def __init__(self, context: context.Context): super().__init__(context, context.server) @@ -166,6 +230,25 @@ def protocol_error( code = H3ErrorCode.H3_INTERNAL_ERROR self.quic.reset_stream(stream_id=event.stream_id, error_code=code) + def headers_received( + self, event: h3_events.HeadersReceived + ) -> Union[RequestHeaders, ResponseHeaders]: + # same as HTTP/2 + status_code, headers = parse_h2_response_headers(event.headers) + response = http.Response( + http_version=b"HTTP/3", + status_code=status_code, + reason=b"", + headers=headers, + content=None, + trailers=None, + timestamp_start=time.time(), + timestamp_end=None, + ) + return ResponseHeaders( + stream_id=event.stream_id, response=response, end_stream=event.stream_ended + ) + __all__ = [ "Http3Client", diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 1fc14282aa..ee656686f5 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -90,7 +90,7 @@ class QuicGetConnection(commands.ConnectionCommand): # -> QuicConnection @dataclass -class QuicTransmit: +class QuicTransmit(commands.Command): connection: QuicConnection @@ -426,6 +426,8 @@ def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: yield from self.child_layer.handle_event(event) def transmit(self) -> layer.CommandGenerator[None]: + assert self.quic + # send all queued datagrams for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): yield commands.SendData(connection=self.conn, data=data, remote_addr=addr) From 9a9c962caa0ededd0f39b84bd7a25bc8633ae577 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 20 Jun 2022 22:09:22 +0200 Subject: [PATCH 017/529] [quic] improve close connection handling --- mitmproxy/proxy/layers/quic.py | 74 +++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index ee656686f5..b53f163d7a 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -7,7 +7,12 @@ from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.connection import QuicConnection, QuicConnectionError, QuicErrorCode +from aioquic.quic.connection import ( + QuicConnection, + QuicConnectionError, + QuicConnectionState, + QuicErrorCode, +) from aioquic.tls import CipherSuite, HandshakeType from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header from cryptography import x509 @@ -97,7 +102,7 @@ class QuicTransmit(commands.Command): @dataclass(repr=False) class QuicGetConnectionCompleted(events.CommandCompleted): command: QuicGetConnection - connection: QuicConnection + reply: QuicConnection class QuicSecretsLogger: @@ -217,7 +222,7 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: # filter commands coming from the child layer for command in self.child_layer.handle_event(event): - # answer or queue requests for the aioquic connection instanc + # answer or queue requests for the aioquic connection instance if ( isinstance(command, QuicGetConnection) and command.connection is self.conn @@ -228,7 +233,7 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: yield from self.child_layer.handle_event( QuicGetConnectionCompleted( command=command, - connection=self.quic, + reply=self.quic, ) ) @@ -242,11 +247,11 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: and command.connection is self.conn ): reason = "CloseConnection command received." - if self.conn.connected and self.quic is not None: + if self.quic is None: + yield from self.shutdown_connection(reason=reason, level="info") + else: self.quic.close(reason_phrase=reason) yield from self.process_events() - else: - yield from self.shutdown_connection(reason, level="info") # return other commands else: @@ -263,7 +268,8 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: yield QuicTlsStartServerHook(tls_data) if tls_data.settings is None: yield from self.shutdown_connection( - "No TLS settings were provided, failing connection.", level="error" + reason="No TLS settings were provided, failing connection.", + level="error", ) return self.tls = tls_data.settings @@ -283,7 +289,7 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: yield from self.child_layer.handle_event( QuicGetConnectionCompleted( command=self._get_connection_commands.pop(), - connection=self.quic, + reply=self.quic, ) ) @@ -303,20 +309,8 @@ def process_events(self) -> layer.CommandGenerator[None]: self.retire_connection_id_callback(event.connection_id) elif isinstance(event, quic_events.ConnectionTerminated): - # report as TLS failure if the termination happened before the handshake - if not self.conn.tls_established: - self.conn.error = event.reason_phrase - tls_data = QuicTlsData( - conn=self.conn, context=self.context, settings=self.tls - ) - if self.conn is self.context.client: - yield layers.tls.TlsFailedClientHook(tls_data) - else: - yield layers.tls.TlsFailedServerHook(tls_data) - - # always close the connection yield from self.shutdown_connection( - event.reason_phrase, + reason=event.reason_phrase or str(event.error_code), level=( "info" if event.error_code is QuicErrorCode.NO_ERROR else "warn" ), @@ -368,6 +362,21 @@ def shutdown_connection( reason: str, level: Literal["error", "warn", "info", "alert", "debug"], ) -> layer.CommandGenerator[None]: + # ensure QUIC has been properly shut down + assert self.quic is None or self.quic._state is QuicConnectionState.TERMINATED + + # report as TLS failure if the termination happened before the handshake + if not self.conn.tls_established and self.tls is not None: + self.conn.error = reason + tls_data = QuicTlsData( + conn=self.conn, context=self.context, settings=self.tls + ) + if self.conn is self.context.client: + yield layers.tls.TlsFailedClientHook(tls_data) + else: + yield layers.tls.TlsFailedServerHook(tls_data) + + # log the reason, ensure the connection is closed and no longer handle events yield commands.Log( message=f"Connection {self.conn} closed: {reason}", level=level ) @@ -398,13 +407,22 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: yield from self.process_events() return - # check if the connection was closed by peer + # handle connections closed by peer elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): - yield from self.shutdown_connection( - "Peer UDP connection timed out.", level="info" - ) + reason = "Peer UDP connection timed out." + if self.quic is not None: + # there is no point in calling quic.close, as it cannot send packets anymore + # so we simply set the state and simulate a ConnectionTerminated event + self.quic._set_state(QuicConnectionState.TERMINATED) + yield from self.event_to_child( + QuicConnectionEvent( + connection=self.conn, + event=quic_events.ConnectionTerminated(reason_phrase=reason), + ) + ) + yield from self.shutdown_connection(reason=reason, level="info") # intercept wakeup events for aioquic elif ( @@ -462,7 +480,7 @@ def start(self) -> layer.CommandGenerator[None]: err = yield commands.OpenConnection(self.conn) if err is not None: self.shutdown_connection( - f"Failed to connect: {err}", + reason=f"Failed to connect: {err}", level="warn", ) return @@ -525,7 +543,7 @@ def state_wait_for_upstream( elif isinstance(event, events.ConnectionClosed): if event.connection is self.conn: yield from self.shutdown_connection( - "Client UDP connection timeout out before upstream server handshake completed.", + reason="Client UDP connection timeout out before upstream server handshake completed.", level="info", ) elif event.connection is self.context.server: From 5608b0629a3a82493f8ecad7d41200f02115b863 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 20 Jun 2022 22:27:42 +0200 Subject: [PATCH 018/529] [quic] H3 stream ID translation on error handling --- mitmproxy/proxy/layers/http/_http3.py | 231 ++++++++++++++++++-------- 1 file changed, 164 insertions(+), 67 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index ec600296e5..32a75d0e91 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -1,8 +1,7 @@ from abc import abstractmethod import time -from typing import Optional, Union +from typing import Dict, Optional, Union -from aioquic.quic.connection import QuicConnection from aioquic.h3.connection import ( H3Connection, FrameUnexpected, @@ -10,6 +9,9 @@ HeadersState as H3HeadersState, ) from aioquic.h3 import events as h3_events +from aioquic.quic import events as quic_events +from aioquic.quic.connection import QuicConnection +from aioquic.quic.packet import QuicErrorCode from mitmproxy import http, version from mitmproxy.net.http import status_codes @@ -50,6 +52,9 @@ class Http3Connection(HttpConnection): quic: Optional[QuicConnection] = None h3_conn: Optional[H3Connection] = None + ReceiveData: type[Union[RequestData, ResponseData]] + ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]] + ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]] def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: @@ -59,83 +64,149 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: self.quic = quic self.h3_conn = H3Connection(quic, enable_webtransport=False) - else: + if isinstance(event, events.ConnectionClosed): + self._handle_event = self.done + + # send mitmproxy HTTP events over the H3 connection + elif isinstance(event, HttpEvent): assert self.quic is not None assert self.h3_conn is not None + try: - if isinstance(event, HttpEvent): - try: + if isinstance(event, (RequestData, ResponseData)): + self.h3_conn.send_data( + stream_id=event.stream_id, data=event.data, end_stream=False + ) + elif isinstance(event, (RequestHeaders, ResponseHeaders)): + self.h3_conn.send_headers( + stream_id=event.stream_id, + headers=( + yield from ( + format_h2_request_headers(event) + if isinstance(event, RequestHeaders) + else format_h2_response_headers(event) + ) + ), + end_stream=event.end_stream, + ) + elif isinstance(event, (RequestTrailers, ResponseTrailers)): + self.h3_conn.send_headers( + stream_id=event.stream_id, + headers=[*event.trailers.fields], + end_stream=True, + ) + elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)): + self.h3_conn.send_data( + stream_id=event.stream_id, data=b"", end_stream=True + ) + elif isinstance( + event, (RequestProtocolError, ResponseProtocolError) + ): + self.protocol_error(event) + else: + raise AssertionError(f"Unexpected event: {event}") - if isinstance(event, (RequestData, ResponseData)): - self.h3_conn.send_data( - stream_id=event.stream_id, data=event.data, end_stream=False - ) - elif isinstance(event, (RequestHeaders, ResponseHeaders)): - self.h3_conn.send_headers( + except FrameUnexpected: + # Http2Connection also ignores HttpEvents that violate the current stream state + return + + # transmit buffered data and re-arm timer + yield QuicTransmit(self.quic) + + # handle events from the underlying QUIC connection + elif isinstance(event, QuicConnectionEvent): + assert self.quic is not None + assert self.h3_conn is not None + + # report abrupt stream resets + if isinstance(event, quic_events.StreamReset): + if event.stream_id in self.h3_conn._stream: + try: + reason = H3ErrorCode(event.error_code).name + except ValueError: + try: + reason = QuicErrorCode(event.error_code).name + except ValueError: + reason = str(event.error_code) + code = ( + status_codes.CLIENT_CLOSED_REQUEST + if event.error_code == H3ErrorCode.H3_REQUEST_CANCELLED + else self.ReceiveProtocolError.code + ) + yield ReceiveHttp( + self.ReceiveProtocolError( stream_id=event.stream_id, - headers=( - yield from ( - format_h2_request_headers(event) - if isinstance(event, RequestHeaders) - else format_h2_response_headers(event) - ) - ), - end_stream=event.end_stream, + message=f"stream reset by client ({reason})", + code=code, + ) + ) + + # report a protocol error for all remaining open streams when a connection is terminated + elif isinstance(event, quic_events.ConnectionTerminated): + for stream in self.h3_conn._stream.values(): + if not stream.ended: + yield ReceiveHttp( + self.ReceiveProtocolError( + stream_id=stream.stream_id, + message=event.reason_phrase, + code=event.error_code, + ) ) - elif isinstance(event, (RequestTrailers, ResponseTrailers)): - trailers = [*event.trailers.fields] - self.h3_conn.send_headers( - stream_id=event.stream_id, headers=trailers, end_stream=True + + # forward QUIC events to the H3 connection + for h3_event in self.h3_conn.handle_event(event.event): + + # report received data + if isinstance(h3_event, h3_events.DataReceived): + yield ReceiveHttp( + self.ReceiveData( + stream_id=h3_event.stream_id, data=h3_event.data ) - elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)): - self.h3_conn.send_data( - stream_id=event.stream_id, data=b"", end_stream=True + ) + if h3_event.stream_ended: + yield ReceiveHttp( + self.ReceiveEndOfMessage(stream_id=event.stream_id) ) - elif isinstance( - event, (RequestProtocolError, ResponseProtocolError) + + # report headers and trailers + elif isinstance(h3_event, h3_events.HeadersReceived): + if ( + self.h3_conn._stream[h3_event.stream_id].headers_recv_state + is H3HeadersState.AFTER_TRAILERS ): - self.protocol_error(event) + yield ReceiveHttp( + self.ReceiveTrailers( + stream_id=h3_event.stream_id, + trailers=http.Headers(h3_event.headers), + ) + ) else: - raise AssertionError(f"Unexpected event: {event}") - - except FrameUnexpected: - # Http2Connection also ignores events that violate the current stream state - return - - # transmit buffered data and re-arm timer - yield QuicTransmit(self.quic) - - elif isinstance(event, QuicConnectionEvent): - for h3_event in self.h3_conn.handle_event(event.event): - if isinstance(h3_event, h3_events.DataReceived): - pass - - # handle headers and trailers - elif isinstance(h3_event, h3_events.HeadersReceived): - if ( - self.h3_conn._stream[h3_event.stream_id].headers_recv_state - is H3HeadersState.AFTER_TRAILERS - ): - yield ReceiveHttp( - self.ReceiveTrailers( - stream_id=h3_event.stream_id, - trailers=http.Headers(h3_event.headers), - ) + try: + receive_event = self.headers_received(h3_event) + except ValueError as e: + # this will result in a ConnectionTerminated event + self.quic.close( + error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR, + reason_phrase=f"Invalid HTTP/3 request headers: {e}", ) else: - try: - receive_event = self.headers_received(h3_event) - except ValueError as e: - # TODO - pass - else: - yield ReceiveHttp(receive_event) - - # we don't support push, web transport, etc. - else: - yield commands.Log( - f"Ignored unsupported H3 event: {h3_event!r}" - ) + yield ReceiveHttp(receive_event) + if h3_event.stream_ended: + yield ReceiveHttp( + self.ReceiveEndOfMessage(stream_id=event.stream_id) + ) + + # we don't support push, web transport, etc. + else: + yield commands.Log( + f"Ignored unsupported H3 event: {h3_event!r}" + ) + + else: + raise AssertionError(f"Unexpected event: {event!r}") + + def done(self, event: events.Event) -> layer.CommandGenerator[None]: + yield from () @abstractmethod def protocol_error( @@ -151,6 +222,9 @@ def headers_received( class Http3Server(Http3Connection): + ReceiveData = RequestData + ReceiveEndOfMessage = RequestEndOfMessage + ReceiveProtocolError = RequestProtocolError ReceiveTrailers = RequestTrailers def __init__(self, context: context.Context): @@ -213,10 +287,18 @@ def headers_received( class Http3Client(Http3Connection): + ReceiveData = ResponseData + ReceiveEndOfMessage = ResponseEndOfMessage + ReceiveProtocolError = ResponseProtocolError ReceiveTrailers = ResponseTrailers + our_stream_id: Dict[int, int] + their_stream_id: Dict[int, int] + def __init__(self, context: context.Context): super().__init__(context, context.server) + self.our_stream_id = {} + self.their_stream_id = {} def protocol_error( self, event: Union[RequestProtocolError, ResponseProtocolError] @@ -249,6 +331,21 @@ def headers_received( stream_id=event.stream_id, response=response, end_stream=event.stream_ended ) + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + # translate stream IDs just like HTTP/2 client + if isinstance(event, HttpEvent): + assert self.quic + ours = self.our_stream_id.get(event.stream_id, None) + if ours is None: + ours = self.quic.get_next_available_stream_id() + self.our_stream_id[event.stream_id] = ours + self.their_stream_id[ours] = event.stream_id + event.stream_id = ours + for cmd in super()._handle_event(event): + if isinstance(cmd, ReceiveHttp): + cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id] + yield cmd + __all__ = [ "Http3Client", From 6e66875e73ad1cb3dd2b93b69fd6079ea348da87 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Tue, 21 Jun 2022 00:10:39 +0200 Subject: [PATCH 019/529] [quic] first connectable version --- mitmproxy/proxy/layers/quic.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index b53f163d7a..78b7b18378 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,6 +1,6 @@ from abc import abstractmethod import asyncio -from dataclasses import dataclass +from dataclasses import dataclass, field from ssl import VerifyMode from typing import Callable, List, Literal, Optional, Tuple, Union @@ -31,7 +31,7 @@ class QuicTlsSettings: certificate: Optional[x509.Certificate] = None """The certificate to use for the connection.""" - certificate_chain: List[x509.Certificate] = [] + certificate_chain: List[x509.Certificate] = field(default_factory=list) """A list of additional certificates to send to the peer.""" certificate_private_key: Optional[ Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey] @@ -115,7 +115,7 @@ def __init__(self, logger: tls.MasterSecretLogger) -> None: def write(self, s: str) -> int: if s[-1:] == "\n": s = s[:-1] - data = s.encode() + data = s.encode("ascii") self.logger(None, data) # type: ignore return len(data) + 1 @@ -138,7 +138,11 @@ def pull_client_hello_and_connection_id(data: bytes) -> Tuple[ClientHello, bytes # patch aioquic to intercept the client hello quic = QuicConnection( - configuration=QuicConfiguration(), + configuration=QuicConfiguration( + is_client=False, + certificate="", + private_key="", + ), original_destination_connection_id=header.destination_cid, ) _initialize = quic._initialize @@ -202,7 +206,7 @@ def build_configuration(self) -> QuicConfiguration: assert self.tls is not None return QuicConfiguration( - alpn_protocols=self.conn.alpn_offers, + alpn_protocols=[offer.decode("ascii") for offer in self.conn.alpn_offers], connection_id_length=self.context.options.quic_connection_id_length, is_client=self.conn is self.context.server, secrets_log_file=QuicSecretsLogger(tls.log_master_secret) # type: ignore @@ -329,7 +333,7 @@ def process_events(self) -> layer.CommandGenerator[None]: self.conn.certificate_list = [ certs.Cert.from_pyopenssl(x) for x in all_certs ] - self.conn.alpn = event.alpn_protocol.encode() + self.conn.alpn = event.alpn_protocol.encode("ascii") self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name self.conn.tls_version = "QUIC" From 97e482998b76e8a3f18821a409b6eaa5c52d3754 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Wed, 22 Jun 2022 00:45:51 +0200 Subject: [PATCH 020/529] [quic] implement relay layer --- mitmproxy/proxy/layers/http/_http3.py | 2 +- mitmproxy/proxy/layers/quic.py | 147 +++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 4 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 32a75d0e91..dfca3658ff 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -334,7 +334,7 @@ def headers_received( def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # translate stream IDs just like HTTP/2 client if isinstance(event, HttpEvent): - assert self.quic + assert self.quic is not None ours = self.our_stream_id.get(event.stream_id, None) if ours is None: ours = self.quic.get_next_available_stream_id() diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 78b7b18378..3ada9da3be 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass, field from ssl import VerifyMode -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events @@ -17,9 +17,10 @@ from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa -from mitmproxy import certs, connection +from mitmproxy import certs, connection, flow as mitm_flow, tcp from mitmproxy.net import tls from mitmproxy.proxy import commands, context, events, layer, layers +from mitmproxy.proxy.layers import tcp as tcp_layer from mitmproxy.tls import ClientHello, ClientHelloData, TlsData @@ -179,6 +180,146 @@ def initialize_replacement(peer_cid: bytes) -> None: raise ValueError("No ClientHello returned.") +class QuicRelayLayer(layer.Layer): + # for now we're (ab)using the TCPFlow until https://github.com/mitmproxy/mitmproxy/pull/5414 is resolved + datagram_flow: Optional[tcp.TCPFlow] = None + lookup_server: Dict[int, Tuple[int, tcp.TCPFlow]] + lookup_client: Dict[int, Tuple[int, tcp.TCPFlow]] + quic_server: Optional[QuicConnection] = None + quic_client: Optional[QuicConnection] = None + + def __init__(self, context: context.Context) -> None: + super().__init__(context) + self.lookup_server = {} + self.lookup_client = {} + + def end_flow(self, flow: tcp.TCPFlow, event: quic_events.ConnectionTerminated) -> layer.CommandGenerator[None]: + if event.error_code == QuicErrorCode.NO_ERROR: + yield tcp_layer.TcpEndHook(flow) + else: + flow.error = mitm_flow.Error(event.reason_phrase) + yield tcp_layer.TcpErrorHook(flow) + flow.live = False + + def get_quic( + self, conn: connection.Connection + ) -> layer.CommandGenerator[QuicConnection]: + quic = yield QuicGetConnection(conn) + assert isinstance(quic, QuicConnection) + return quic + + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + if isinstance(event, events.Start): + self.quic_server = yield from self.get_quic(self.context.server) + self.quic_client = yield from self.get_quic(self.context.client) + + elif isinstance(event, QuicConnectionEvent): + assert self.quic_server is not None + assert self.quic_client is not None + + quic_event = event.event + from_client = event.connection is self.context.client + lookup_in = self.lookup_client if from_client else self.lookup_server + lookup_out = self.lookup_server if from_client else self.lookup_client + # quic_in = self.quic_client if from_client else self.quic_server + quic_out = self.quic_server if from_client else self.quic_client + + # forward close and end all flows + if isinstance(quic_event, quic_events.ConnectionTerminated): + quic_out.close( + error_code=quic_event.error_code, + frame_type=quic_event.frame_type, + reason_phrase=quic_event.reason_phrase, + ) + while lookup_in: + stream_id_in = next(iter(lookup_in)) + stream_id_out, flow = lookup_in[stream_id_in] + yield from self.end_flow(flow=flow, event=quic_event) + del lookup_in[stream_id_in] + del lookup_out[stream_id_out] + + if self.datagram_flow is not None: + yield from self.end_flow(flow=flow, event=quic_event) + self.datagram_flow = None + + # forward datagrams (that are not stream-bound) + elif isinstance(quic_event, quic_events.DatagramFrameReceived): + if self.datagram_flow is None: + self.datagram_flow = tcp.TCPFlow( + client_conn=self.context.client, + server_conn=self.context.server, + live=True, + ) + yield tcp_layer.TcpStartHook(self.datagram_flow) + message = tcp.TCPMessage( + from_client=from_client, content=quic_event.data + ) + self.datagram_flow.messages.append(message) + yield tcp_layer.TcpMessageHook(self.datagram_flow) + quic_out.send_datagram_frame(data=message.content) + + # forward stream data + elif isinstance(quic_event, quic_events.StreamDataReceived): + # get or create the stream on the other side (and flow) + stream_id_in = quic_event.stream_id + if stream_id_in in lookup_in: + stream_id_out, flow = lookup_in[stream_id_in] + else: + stream_id_out = quic_out.get_next_available_stream_id() + flow = tcp.TCPFlow( + client_conn=self.context.client, + server_conn=self.context.server, + live=True, + ) + lookup_in[stream_id_in] = (stream_id_out, flow) + lookup_out[stream_id_out] = (stream_id_in, flow) + yield tcp_layer.TcpStartHook(flow) + + # forward the message allowing addons to change it + message = tcp.TCPMessage( + from_client=from_client, content=quic_event.data + ) + flow.messages.append(message) + yield tcp_layer.TcpMessageHook(flow) + quic_out.send_stream_data( + stream_id=stream_id_out, + data=message.content, + end_stream=quic_event.end_stream, + ) + + # end the flow and remove the lookup if the stream ended + if quic_event.end_stream: + yield tcp_layer.TcpEndHook(flow) + flow.live = False + del lookup_in[stream_id_in] + del lookup_out[stream_id_out] + + # forward resets to peer streams + elif isinstance(quic_event, quic_events.StreamReset): + stream_id_in = quic_event.stream_id + if stream_id_in in lookup_in: + stream_id_out, flow = lookup_in[stream_id_in] + quic_out.stop_stream( + stream_id=stream_id_out, error_code=quic_event.error_code + ) + + # try to get a name describing the reset reason + try: + err = QuicErrorCode(quic_event.error_code).name + except ValueError: + err = str(quic_event.error_code) + + # report the error to addons and delete the stream + flow.error = mitm_flow.Error(str(err)) + yield tcp_layer.TcpErrorHook(flow) + flow.live = False + del lookup_in[stream_id_in] + del lookup_out[stream_id_out] + + def done(self, _) -> layer.CommandGenerator[None]: + yield from () + + class _QuicLayer(layer.Layer): child_layer: layer.Layer conn: connection.Connection @@ -316,7 +457,7 @@ def process_events(self) -> layer.CommandGenerator[None]: yield from self.shutdown_connection( reason=event.reason_phrase or str(event.error_code), level=( - "info" if event.error_code is QuicErrorCode.NO_ERROR else "warn" + "info" if event.error_code == QuicErrorCode.NO_ERROR else "warn" ), ) From 0b5afe54c46eca45ef70bb1ac82828735fe8d7ef Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Wed, 22 Jun 2022 03:45:39 +0200 Subject: [PATCH 021/529] [quic] bugfixes and improvements - next_layer decisions - don't forward obsolete wakeups - remove excessive named arguments --- mitmproxy/addons/next_layer.py | 20 +- mitmproxy/proxy/layers/http/__init__.py | 19 +- mitmproxy/proxy/layers/http/_http3.py | 8 +- mitmproxy/proxy/layers/quic.py | 275 ++++++++++++------------ 4 files changed, 177 insertions(+), 145 deletions(-) diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index e2e57e96b9..b9908b0a24 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -22,7 +22,7 @@ from mitmproxy.net.tls import is_tls_record_magic from mitmproxy.proxy.layers.http import HTTPMode from mitmproxy.proxy import context, layer, layers -from mitmproxy.proxy.layers import modes +from mitmproxy.proxy.layers import modes, quic from mitmproxy.proxy.layers.tls import HTTP_ALPNS, parse_client_hello LayerCls = type[layer.Layer] @@ -117,6 +117,24 @@ def next_layer(self, nextlayer: layer.NextLayer): def _next_layer( self, context: context.Context, data_client: bytes, data_server: bytes ) -> Optional[layer.Layer]: + if isinstance(context.layers[0], quic.QuicLayer): + if context.client.alpn is None: + return None # should never happen, as ask is called after handshake + if context.client.alpn == b"h3" or context.client.alpn.startswith(b"h3-"): + if ctx.options.mode == "regular": + mode = HTTPMode.regular + elif ctx.options.mode == "transparent" or ctx.options.mode.startswith("reverse:"): + mode = HTTPMode.transparent + elif ctx.options.mode.startswith("upstream:"): + mode = HTTPMode.upstream + else: + return None + return layers.HttpLayer(context=context, mode=mode) + else: + if context.server.address is None: + return None + return quic.QuicRelayLayer(context) + if len(context.layers) == 0: return self.make_top_layer(context) diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 36e0969a9d..81ac1c9189 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -11,7 +11,7 @@ from mitmproxy.net.http import status_codes, url from mitmproxy.net.http.http1 import expected_http_body_size from mitmproxy.proxy import commands, events, layer, tunnel -from mitmproxy.proxy.layers import tcp, tls, websocket +from mitmproxy.proxy.layers import quic, tcp, tls, websocket from mitmproxy.proxy.layers.http import _upstream_proxy from mitmproxy.proxy.utils import expect from mitmproxy.utils import human @@ -62,6 +62,10 @@ def validate_request(mode: HTTPMode, request: http.Request) -> Optional[str]: return None +def is_h3_alpn(alpn: Optional[bytes]) -> bool: + return alpn == b"h3" or (alpn is not None and alpn.startswith(b"h3-")) + + @dataclass class GetHttpConnection(HttpCommand): """ @@ -822,7 +826,7 @@ def __init__(self, context: Context, mode: HTTPMode): self.command_sources = {} http_conn: HttpConnection - if self.context.client.alpn == b"h3": + if is_h3_alpn(self.context.client.alpn): http_conn = Http3Server(context.fork()) elif self.context.client.alpn == b"h2": http_conn = Http2Server(context.fork()) @@ -879,7 +883,7 @@ def _handle_event(self, event: events.Event): elif isinstance(event, events.DataReceived): # The peer has sent data. This can happen with HTTP/2 servers that already send a settings frame. child_layer: HttpConnection - if self.context.server.alpn == b"h3": + if is_h3_alpn(self.context.server.alpn): child_layer = Http3Client(self.context.fork()) elif self.context.server.alpn == b"h2": child_layer = Http2Client(self.context.fork()) @@ -997,7 +1001,7 @@ def get_connection( if not can_use_context_connection: - context.server = Server(event.address) + context.server = Server(event.address, transport_protocol=context.client.transport_protocol) if event.via: context.server.via = event.via @@ -1015,7 +1019,10 @@ def get_connection( context.server.sni = self.context.client.sni or event.address[0] else: context.server.sni = event.address[0] - stack /= tls.ServerTLSLayer(context) + if context.server.transport_protocol == "udp": + stack /= quic.ServerQuicLayer(context) + else: + stack /= tls.ServerTLSLayer(context) stack /= HttpClient(context) @@ -1065,7 +1072,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: else: err = yield commands.OpenConnection(self.context.server) if not err: - if self.context.server.alpn == b"h3": + if is_h3_alpn(self.context.server.alpn): self.child_layer = Http3Client(self.context) elif self.context.server.alpn == b"h2": self.child_layer = Http2Client(self.context) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index dfca3658ff..2c8e0b0b16 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -64,7 +64,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: self.quic = quic self.h3_conn = H3Connection(quic, enable_webtransport=False) - if isinstance(event, events.ConnectionClosed): + elif isinstance(event, events.ConnectionClosed): self._handle_event = self.done # send mitmproxy HTTP events over the H3 connection @@ -165,7 +165,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: ) if h3_event.stream_ended: yield ReceiveHttp( - self.ReceiveEndOfMessage(stream_id=event.stream_id) + self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) ) # report headers and trailers @@ -193,7 +193,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: yield ReceiveHttp(receive_event) if h3_event.stream_ended: yield ReceiveHttp( - self.ReceiveEndOfMessage(stream_id=event.stream_id) + self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) ) # we don't support push, web transport, etc. @@ -266,7 +266,7 @@ def headers_received( authority, path, headers, - ) = parse_h2_request_headers(event) + ) = parse_h2_request_headers(event.headers) request = http.Request( host=host, port=port, diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 3ada9da3be..28db9c1a40 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass, field from ssl import VerifyMode -from typing import Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events @@ -90,15 +90,17 @@ class QuicConnectionEvent(events.ConnectionEvent): event: quic_events.QuicEvent -@dataclass class QuicGetConnection(commands.ConnectionCommand): # -> QuicConnection blocking = True -@dataclass class QuicTransmit(commands.Command): connection: QuicConnection + def __init__(self, connection: QuicConnection) -> None: + super().__init__() + self.connection = connection + @dataclass(repr=False) class QuicGetConnectionCompleted(events.CommandCompleted): @@ -159,7 +161,7 @@ def server_handle_hello_replacement( for b in input_buf.pull_bytes(3): length = (length << 8) | b offset = input_buf.tell() - raise QuicClientHello(data=input_buf.data_slice(offset, offset + length)) + raise QuicClientHello(input_buf.data_slice(offset, offset + length)) def initialize_replacement(peer_cid: bytes) -> None: try: @@ -193,7 +195,9 @@ def __init__(self, context: context.Context) -> None: self.lookup_server = {} self.lookup_client = {} - def end_flow(self, flow: tcp.TCPFlow, event: quic_events.ConnectionTerminated) -> layer.CommandGenerator[None]: + def end_flow( + self, flow: tcp.TCPFlow, event: quic_events.ConnectionTerminated + ) -> layer.CommandGenerator[None]: if event.error_code == QuicErrorCode.NO_ERROR: yield tcp_layer.TcpEndHook(flow) else: @@ -227,36 +231,34 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # forward close and end all flows if isinstance(quic_event, quic_events.ConnectionTerminated): quic_out.close( - error_code=quic_event.error_code, - frame_type=quic_event.frame_type, - reason_phrase=quic_event.reason_phrase, + quic_event.error_code, + quic_event.frame_type, + quic_event.reason_phrase, ) while lookup_in: stream_id_in = next(iter(lookup_in)) stream_id_out, flow = lookup_in[stream_id_in] - yield from self.end_flow(flow=flow, event=quic_event) + yield from self.end_flow(flow, quic_event) del lookup_in[stream_id_in] del lookup_out[stream_id_out] if self.datagram_flow is not None: - yield from self.end_flow(flow=flow, event=quic_event) + yield from self.end_flow(flow, quic_event) self.datagram_flow = None # forward datagrams (that are not stream-bound) elif isinstance(quic_event, quic_events.DatagramFrameReceived): if self.datagram_flow is None: self.datagram_flow = tcp.TCPFlow( - client_conn=self.context.client, - server_conn=self.context.server, + self.context.client, + self.context.server, live=True, ) yield tcp_layer.TcpStartHook(self.datagram_flow) - message = tcp.TCPMessage( - from_client=from_client, content=quic_event.data - ) + message = tcp.TCPMessage(from_client, quic_event.data) self.datagram_flow.messages.append(message) yield tcp_layer.TcpMessageHook(self.datagram_flow) - quic_out.send_datagram_frame(data=message.content) + quic_out.send_datagram_frame(message.content) # forward stream data elif isinstance(quic_event, quic_events.StreamDataReceived): @@ -267,8 +269,8 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: else: stream_id_out = quic_out.get_next_available_stream_id() flow = tcp.TCPFlow( - client_conn=self.context.client, - server_conn=self.context.server, + self.context.client, + self.context.server, live=True, ) lookup_in[stream_id_in] = (stream_id_out, flow) @@ -276,15 +278,13 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: yield tcp_layer.TcpStartHook(flow) # forward the message allowing addons to change it - message = tcp.TCPMessage( - from_client=from_client, content=quic_event.data - ) + message = tcp.TCPMessage(from_client, quic_event.data) flow.messages.append(message) yield tcp_layer.TcpMessageHook(flow) quic_out.send_stream_data( - stream_id=stream_id_out, - data=message.content, - end_stream=quic_event.end_stream, + stream_id_out, + message.content, + quic_event.end_stream, ) # end the flow and remove the lookup if the stream ended @@ -299,9 +299,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: stream_id_in = quic_event.stream_id if stream_id_in in lookup_in: stream_id_out, flow = lookup_in[stream_id_in] - quic_out.stop_stream( - stream_id=stream_id_out, error_code=quic_event.error_code - ) + quic_out.stop_stream(stream_id_out, quic_event.error_code) # try to get a name describing the reset reason try: @@ -316,9 +314,6 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: del lookup_in[stream_id_in] del lookup_out[stream_id_out] - def done(self, _) -> layer.CommandGenerator[None]: - yield from () - class _QuicLayer(layer.Layer): child_layer: layer.Layer @@ -338,10 +333,11 @@ def __init__( self.child_layer = layer.NextLayer(context) self.conn = conn self._loop = asyncio.get_event_loop() - self._get_connection_commands: List[QuicGetConnection] = [] + self._get_connection_commands: List[QuicGetConnection] = list() self._request_wakeup_command_and_timer: Optional[ Tuple[commands.RequestWakeup, float] ] = None + self._obsolete_wakeup_commands: Set[commands.RequestWakeup] = set() def build_configuration(self) -> QuicConfiguration: assert self.tls is not None @@ -364,8 +360,13 @@ def build_configuration(self) -> QuicConfiguration: ) def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: + yield from self.handle_child_commands(self.child_layer.handle_event(event)) + + def handle_child_commands( + self, child_commands: layer.CommandGenerator[None] + ) -> layer.CommandGenerator[None]: # filter commands coming from the child layer - for command in self.child_layer.handle_event(event): + for command in child_commands: # answer or queue requests for the aioquic connection instance if ( @@ -375,11 +376,8 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: if self.quic is None: self._get_connection_commands.append(command) else: - yield from self.child_layer.handle_event( - QuicGetConnectionCompleted( - command=command, - reply=self.quic, - ) + yield from self.event_to_child( + QuicGetConnectionCompleted(command, self.quic) ) # transmit buffered data and re-arm timer @@ -393,7 +391,7 @@ def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: ): reason = "CloseConnection command received." if self.quic is None: - yield from self.shutdown_connection(reason=reason, level="info") + yield from self.shutdown_connection(reason, level="info") else: self.quic.close(reason_phrase=reason) yield from self.process_events() @@ -413,7 +411,7 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: yield QuicTlsStartServerHook(tls_data) if tls_data.settings is None: yield from self.shutdown_connection( - reason="No TLS settings were provided, failing connection.", + "No TLS settings were provided, failing connection.", level="error", ) return @@ -431,10 +429,9 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: # let the waiters know about the available connection while self._get_connection_commands: assert self.quic is not None - yield from self.child_layer.handle_event( + yield from self.event_to_child( QuicGetConnectionCompleted( - command=self._get_connection_commands.pop(), - reply=self.quic, + self._get_connection_commands.pop(), self.quic ) ) @@ -455,7 +452,7 @@ def process_events(self) -> layer.CommandGenerator[None]: elif isinstance(event, quic_events.ConnectionTerminated): yield from self.shutdown_connection( - reason=event.reason_phrase or str(event.error_code), + event.reason_phrase or str(event.error_code), level=( "info" if event.error_code == QuicErrorCode.NO_ERROR else "warn" ), @@ -479,9 +476,7 @@ def process_events(self) -> layer.CommandGenerator[None]: self.conn.tls_version = "QUIC" # report the success to addons - tls_data = QuicTlsData( - conn=self.conn, context=self.context, settings=self.tls - ) + tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) if self.conn is self.context.client: yield layers.tls.TlsEstablishedClientHook(tls_data) else: @@ -489,12 +484,10 @@ def process_events(self) -> layer.CommandGenerator[None]: # perform next layer decisions now if isinstance(self.child_layer, layer.NextLayer): - yield from self.child_layer._ask() + yield from self.handle_child_commands(self.child_layer._ask()) # forward the event as a QuicConnectionEvent to the child layer - yield from self.event_to_child( - QuicConnectionEvent(connection=self.conn, event=event) - ) + yield from self.event_to_child(QuicConnectionEvent(self.conn, event)) # handle the next event event = self.quic.next_event() @@ -513,18 +506,14 @@ def shutdown_connection( # report as TLS failure if the termination happened before the handshake if not self.conn.tls_established and self.tls is not None: self.conn.error = reason - tls_data = QuicTlsData( - conn=self.conn, context=self.context, settings=self.tls - ) + tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) if self.conn is self.context.client: yield layers.tls.TlsFailedClientHook(tls_data) else: yield layers.tls.TlsFailedServerHook(tls_data) # log the reason, ensure the connection is closed and no longer handle events - yield commands.Log( - message=f"Connection {self.conn} closed: {reason}", level=level - ) + yield commands.Log(f"Connection {self.conn} closed: {reason}", level=level) if self.conn.connected: yield commands.CloseConnection(self.conn) self._handle_event = self.state_done @@ -538,7 +527,7 @@ def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: # start this layer and the child layer yield from self.start() - yield from self.child_layer.handle_event(event) + yield from self.event_to_child(event) def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None @@ -547,7 +536,7 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, events.DataReceived) and event.connection is self.conn: assert event.remote_addr is not None self.quic.receive_datagram( - data=event.data, addr=event.remote_addr, now=self._loop.time() + event.data, event.remote_addr, now=self._loop.time() ) yield from self.process_events() return @@ -563,24 +552,32 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: self.quic._set_state(QuicConnectionState.TERMINATED) yield from self.event_to_child( QuicConnectionEvent( - connection=self.conn, - event=quic_events.ConnectionTerminated(reason_phrase=reason), + self.conn, + quic_events.ConnectionTerminated( + error_code=QuicErrorCode.APPLICATION_ERROR, + frame_type=None, + reason_phrase=reason, + ), ) ) - yield from self.shutdown_connection(reason=reason, level="info") + yield from self.shutdown_connection(reason, level="info") # intercept wakeup events for aioquic - elif ( - isinstance(event, events.Wakeup) - and self._request_wakeup_command_and_timer is not None - ): - command, timer = self._request_wakeup_command_and_timer - if event.command is command: - self._request_wakeup_command_and_timer = None - self.quic.handle_timer(now=max(timer, self._loop.time())) - yield from self.process_events() + elif isinstance(event, events.Wakeup): + # swallow obsolete wakeups + if event.command in self._obsolete_wakeup_commands: + self._obsolete_wakeup_commands.remove(event.command) return + # handle active wakeup + elif self._request_wakeup_command_and_timer is not None: + command, timer = self._request_wakeup_command_and_timer + if event.command is command: + self._request_wakeup_command_and_timer = None + self.quic.handle_timer(now=max(timer, self._loop.time())) + yield from self.process_events() + return + # forward other events to the child layer yield from self.event_to_child(event) @@ -593,17 +590,21 @@ def transmit(self) -> layer.CommandGenerator[None]: # send all queued datagrams for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): - yield commands.SendData(connection=self.conn, data=data, remote_addr=addr) + yield commands.SendData(self.conn, data, addr) # ensure the wakeup is set and still correct timer = self.quic.get_timer() if timer is None: - self._request_wakeup_command_and_timer = None + if self._request_wakeup_command_and_timer is not None: + command, _ = self._request_wakeup_command_and_timer + self._obsolete_wakeup_commands.add(command) + self._request_wakeup_command_and_timer = None else: if self._request_wakeup_command_and_timer is not None: - _, existing_timer = self._request_wakeup_command_and_timer + command, existing_timer = self._request_wakeup_command_and_timer if existing_timer == timer: return + self._obsolete_wakeup_commands.add(command) command = commands.RequestWakeup(timer - self._loop.time()) self._request_wakeup_command_and_timer = (command, timer) yield command @@ -620,20 +621,10 @@ def __init__(self, context: context.Context) -> None: super().__init__(context, context.server) def start(self) -> layer.CommandGenerator[None]: - # ensure there is an UDP connection - if not self.conn.connected: - err = yield commands.OpenConnection(self.conn) - if err is not None: - self.shutdown_connection( - reason=f"Failed to connect: {err}", - level="warn", - ) - return - # try to connect yield from self.initialize_connection() if self.quic is not None: - self.quic.connect(addr=self.conn.peername, now=self._loop.time()) + self.quic.connect(self.conn.peername, now=self._loop.time()) yield from self.process_events() @@ -658,11 +649,7 @@ def initialize_connection_and_flush_buffer(self) -> layer.CommandGenerator[None] yield from self.initialize_connection() if self.quic is not None: for data, addr, now in self.buffered_packets: - self.quic.receive_datagram( - data=data, - addr=addr, - now=now, - ) + self.quic.receive_datagram(data, addr, now) yield from self.process_events() def start(self) -> layer.CommandGenerator[None]: @@ -688,7 +675,7 @@ def state_wait_for_upstream( elif isinstance(event, events.ConnectionClosed): if event.connection is self.conn: yield from self.shutdown_connection( - reason="Client UDP connection timeout out before upstream server handshake completed.", + "Client UDP connection timeout out before upstream server handshake completed.", level="info", ) elif event.connection is self.context.server: @@ -728,60 +715,80 @@ def __init__( def build_client_layer( self, connection_id: bytes, wait_for_upstream: bool ) -> ClientQuicLayer: - layer = ClientQuicLayer( - context=self.context, wait_for_upstream=wait_for_upstream - ) + layer = ClientQuicLayer(self.context, wait_for_upstream) layer.original_destination_connection_id = connection_id layer.issue_connection_id_callback = self._issue_cid layer.retire_connection_id_callback = self._retire_cid return layer + def done(self, event: events.Event) -> layer.CommandGenerator[None]: + yield from () + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + if isinstance(event, events.Start): + pass + # only handle the first packet from the client - if ( - not isinstance(event, events.DataReceived) - or event.connection is not self.context.client + elif ( + isinstance(event, events.DataReceived) + and event.connection is self.context.client ): - return - - # extract the client hello - try: - client_hello, connection_id = pull_client_hello_and_connection_id( - event.data - ) - except ValueError as e: - yield commands.Log( - f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})" - ) - yield commands.CloseConnection(self.context.client) - return - - # copy the information - self.context.client.sni = client_hello.sni - self.context.client.alpn_offers = client_hello.alpn_protocols + # extract the client hello + try: + client_hello, connection_id = pull_client_hello_and_connection_id( + event.data + ) + except ValueError as e: + yield commands.Log( + f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})" + ) + yield commands.CloseConnection(self.context.client) + self._handle_event = self.done + else: - # check with addons what we shall do - next_layer: layer.Layer - hook_data = ClientHelloData(self.context, client_hello) - yield layers.tls.TlsClienthelloHook(hook_data) + # copy the information + self.context.client.sni = client_hello.sni + self.context.client.alpn_offers = client_hello.alpn_protocols + + # check with addons what we shall do + next_layer: layer.Layer + hook_data = ClientHelloData(self.context, client_hello) + yield layers.tls.TlsClienthelloHook(hook_data) + + # simply relay everything + if hook_data.ignore_connection: + next_layer = layers.TCPLayer(self.context, ignore=True) + + # contact the upstream server first + elif hook_data.establish_server_tls_first: + err = yield commands.OpenConnection(self.context.server) + if err is None: + next_layer = ServerQuicLayer(self.context) + next_layer.child_layer = self.build_client_layer( + connection_id, + wait_for_upstream=True, + ) + else: + yield commands.Log( + f"Failed to connect to upstream first (will continue with client anyway): {err}" + ) + next_layer = self.build_client_layer( + connection_id, + wait_for_upstream=False, + ) - # simply relay everything - if hook_data.ignore_connection: - next_layer = layers.TCPLayer(self.context, ignore=True) + # perform the client handshake immediately + else: + next_layer = self.build_client_layer( + connection_id, + wait_for_upstream=False, + ) - # contact the upstream server first - elif hook_data.establish_server_tls_first: - next_layer = ServerQuicLayer(self.context) - next_layer.child_layer = self.build_client_layer( - connection_id, wait_for_upstream=True - ) + # replace this layer and start the next one + self.handle_event = next_layer.handle_event + self._handle_event = next_layer._handle_event + yield from next_layer.handle_event(events.Start()) + yield from next_layer.handle_event(event) - # perform the client handshake immediately else: - next_layer = self.build_client_layer(connection_id, wait_for_upstream=False) - - # replace this layer and start the next one - self.handle_event = next_layer.handle_event - self._handle_event = next_layer._handle_event - yield from next_layer.handle_event(events.Start()) - yield from next_layer.handle_event(event) + raise AssertionError(f"Unexpected event: {event}") From 2426d3d03847e0273707436268d79c24616b3e74 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Wed, 22 Jun 2022 06:16:10 +0200 Subject: [PATCH 022/529] [quic] bugfixes and simplified connection opening --- mitmproxy/proxy/layers/http/__init__.py | 10 +- mitmproxy/proxy/layers/quic.py | 231 +++++++++++------------- 2 files changed, 108 insertions(+), 133 deletions(-) diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 81ac1c9189..6b887269f1 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -1001,7 +1001,10 @@ def get_connection( if not can_use_context_connection: - context.server = Server(event.address, transport_protocol=context.client.transport_protocol) + context.server = Server(event.address) + if isinstance(context.layers[0], quic.QuicLayer): + context.server.transport_protocol = "udp" + stack /= quic.ServerQuicLayer(context) if event.via: context.server.via = event.via @@ -1019,10 +1022,7 @@ def get_connection( context.server.sni = self.context.client.sni or event.address[0] else: context.server.sni = event.address[0] - if context.server.transport_protocol == "udp": - stack /= quic.ServerQuicLayer(context) - else: - stack /= tls.ServerTLSLayer(context) + stack /= tls.ServerTLSLayer(context) stack /= HttpClient(context) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 28db9c1a40..7922d8ca6c 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,4 +1,3 @@ -from abc import abstractmethod import asyncio from dataclasses import dataclass, field from ssl import VerifyMode @@ -21,6 +20,7 @@ from mitmproxy.net import tls from mitmproxy.proxy import commands, context, events, layer, layers from mitmproxy.proxy.layers import tcp as tcp_layer +from mitmproxy.proxy.utils import expect from mitmproxy.tls import ClientHello, ClientHelloData, TlsData @@ -90,10 +90,16 @@ class QuicConnectionEvent(events.ConnectionEvent): event: quic_events.QuicEvent -class QuicGetConnection(commands.ConnectionCommand): # -> QuicConnection +class QuicGetConnection(commands.ConnectionCommand): # -> Optional[QuicConnection] blocking = True +@dataclass(repr=False) +class QuicGetConnectionCompleted(events.CommandCompleted): + command: QuicGetConnection + reply: Optional[QuicConnection] + + class QuicTransmit(commands.Command): connection: QuicConnection @@ -102,12 +108,6 @@ def __init__(self, connection: QuicConnection) -> None: self.connection = connection -@dataclass(repr=False) -class QuicGetConnectionCompleted(events.CommandCompleted): - command: QuicGetConnection - reply: QuicConnection - - class QuicSecretsLogger: logger: tls.MasterSecretLogger @@ -333,7 +333,7 @@ def __init__( self.child_layer = layer.NextLayer(context) self.conn = conn self._loop = asyncio.get_event_loop() - self._get_connection_commands: List[QuicGetConnection] = list() + self._pending_open_command: Optional[commands.OpenConnection] = None self._request_wakeup_command_and_timer: Optional[ Tuple[commands.RequestWakeup, float] ] = None @@ -368,22 +368,35 @@ def handle_child_commands( # filter commands coming from the child layer for command in child_commands: - # answer or queue requests for the aioquic connection instance + # answer with the aioquic connection instance if ( isinstance(command, QuicGetConnection) and command.connection is self.conn ): - if self.quic is None: - self._get_connection_commands.append(command) - else: - yield from self.event_to_child( - QuicGetConnectionCompleted(command, self.quic) - ) + yield from self.event_to_child( + QuicGetConnectionCompleted(command, self.quic) + ) # transmit buffered data and re-arm timer elif isinstance(command, QuicTransmit) and command.connection is self.quic: yield from self.transmit() + # open the QUIC connection + elif ( + isinstance(command, commands.OpenConnection) + and command.connection is self.conn + ): + assert self._pending_open_command is None + self._pending_open_command = command + err = yield commands.OpenConnection(self.conn) + if err is None: + yield from self.initialize_connection() + if self.quic is not None: + self.quic.connect(self.conn.peername, now=self._loop.time()) + yield from self.process_events() + else: + yield from self.shutdown_connection(err, level="warn") + # properly close QUIC connections elif ( isinstance(command, commands.CloseConnection) @@ -426,15 +439,6 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: self.issue_connection_id_callback(self.quic.host_cid) self._handle_event = self.state_ready - # let the waiters know about the available connection - while self._get_connection_commands: - assert self.quic is not None - yield from self.event_to_child( - QuicGetConnectionCompleted( - self._get_connection_commands.pop(), self.quic - ) - ) - def process_events(self) -> layer.CommandGenerator[None]: assert self.quic is not None assert self.tls is not None @@ -482,6 +486,15 @@ def process_events(self) -> layer.CommandGenerator[None]: else: yield layers.tls.TlsEstablishedServerHook(tls_data) + # let the child layer know + if self._pending_open_command is not None: + yield from self.event_to_child( + events.OpenConnectionCompleted( + self._pending_open_command, reply=None + ) + ) + self._pending_open_command = None + # perform next layer decisions now if isinstance(self.child_layer, layer.NextLayer): yield from self.handle_child_commands(self.child_layer._ask()) @@ -512,22 +525,28 @@ def shutdown_connection( else: yield layers.tls.TlsFailedServerHook(tls_data) - # log the reason, ensure the connection is closed and no longer handle events + # log the reason and ensure the connection gets closed yield commands.Log(f"Connection {self.conn} closed: {reason}", level=level) if self.conn.connected: yield commands.CloseConnection(self.conn) - self._handle_event = self.state_done - - @abstractmethod - def start(self) -> layer.CommandGenerator[None]: - yield from () # pragma: no cover - def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: - assert isinstance(event, events.Start) + # let the child layer know and stop handling events + if self._pending_open_command is not None: + yield from self.event_to_child( + events.OpenConnectionCompleted(self._pending_open_command, reason) + ) + self._pending_open_command = None + self._handle_event = self.state_done - # start this layer and the child layer - yield from self.start() - yield from self.event_to_child(event) + def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: + # when done, just forward the event to the child layer (except for obsolete wakeups) + if ( + isinstance(event, events.Wakeup) + and event.command in self._obsolete_wakeup_commands + ): + self._obsolete_wakeup_commands.remove(event.command) + else: + yield from self.child_layer.handle_event(event) def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None @@ -570,7 +589,7 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: return # handle active wakeup - elif self._request_wakeup_command_and_timer is not None: + if self._request_wakeup_command_and_timer is not None: command, timer = self._request_wakeup_command_and_timer if event.command is command: self._request_wakeup_command_and_timer = None @@ -581,30 +600,27 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: # forward other events to the child layer yield from self.event_to_child(event) - def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: - # when done, just forward the event - yield from self.child_layer.handle_event(event) + def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: + # wait for the state to change and inspect the child layer's commands + yield from self.event_to_child(event) def transmit(self) -> layer.CommandGenerator[None]: - assert self.quic + assert self.quic is not None # send all queued datagrams for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): yield commands.SendData(self.conn, data, addr) - # ensure the wakeup is set and still correct + # mark an existing wakeup command as obsolete if it now longer matches the time timer = self.quic.get_timer() - if timer is None: - if self._request_wakeup_command_and_timer is not None: - command, _ = self._request_wakeup_command_and_timer + if self._request_wakeup_command_and_timer is not None: + command, existing_timer = self._request_wakeup_command_and_timer + if existing_timer != timer: self._obsolete_wakeup_commands.add(command) self._request_wakeup_command_and_timer = None - else: - if self._request_wakeup_command_and_timer is not None: - command, existing_timer = self._request_wakeup_command_and_timer - if existing_timer == timer: - return - self._obsolete_wakeup_commands.add(command) + + # request a new wakeup if necessary + if timer is not None and self._request_wakeup_command_and_timer is None: command = commands.RequestWakeup(timer - self._loop.time()) self._request_wakeup_command_and_timer = (command, timer) yield command @@ -620,20 +636,13 @@ class ServerQuicLayer(_QuicLayer): def __init__(self, context: context.Context) -> None: super().__init__(context, context.server) - def start(self) -> layer.CommandGenerator[None]: - # try to connect - yield from self.initialize_connection() - if self.quic is not None: - self.quic.connect(self.conn.peername, now=self._loop.time()) - yield from self.process_events() - class ClientQuicLayer(_QuicLayer): """ This layer establishes QUIC on a single client connection. """ - buffered_packets: Optional[List[Tuple[bytes, connection.Address, float]]] + wait_for_upstream: bool def __init__( self, @@ -641,60 +650,26 @@ def __init__( wait_for_upstream: bool, ) -> None: super().__init__(context, context.client) - self.buffered_packets = [] if wait_for_upstream else None - - def initialize_connection_and_flush_buffer(self) -> layer.CommandGenerator[None]: - assert self.buffered_packets is not None - - yield from self.initialize_connection() - if self.quic is not None: - for data, addr, now in self.buffered_packets: - self.quic.receive_datagram(data, addr, now) - yield from self.process_events() - - def start(self) -> layer.CommandGenerator[None]: - if self.buffered_packets is None: - yield from self.initialize_connection() - else: - self._handle_event = self.state_wait_for_upstream + self.wait_for_upstream = wait_for_upstream + @expect(events.Start) def state_wait_for_upstream( self, event: events.Event ) -> layer.CommandGenerator[None]: - assert self.buffered_packets is not None + self._handle_event = self.state_start - # buffer incoming packets until the upstream handshake completed - if isinstance(event, events.DataReceived) and event.connection is self.conn: - assert event.remote_addr is not None - self.buffered_packets.append( - (event.data, event.remote_addr, self._loop.time()) - ) - return - - # watch for closed connections on both legs - elif isinstance(event, events.ConnectionClosed): - if event.connection is self.conn: - yield from self.shutdown_connection( - "Client UDP connection timeout out before upstream server handshake completed.", - level="info", - ) - elif event.connection is self.context.server: + # open the upstream connection if possible, but always initialize the client connection + if self.wait_for_upstream: + err = yield commands.OpenConnection(self.context.server) + if err is not None: yield commands.Log( - f"Unable to establish QUIC connection with server ({self.context.server.error or 'Connection closed.'}). " + f"Unable to establish QUIC connection with server ({err}). " f"Trying to establish QUIC with client anyway." ) - yield from self.initialize_connection_and_flush_buffer() + yield from self.initialize_connection() + yield from self.state_start(event) - # continue if upstream completed the handshake - elif ( - isinstance(event, QuicConnectionEvent) - and event.connection is self.context.server - and isinstance(event.event, quic_events.HandshakeCompleted) - ): - yield from self.initialize_connection_and_flush_buffer() - - # forward other events to the child layer - yield from self.event_to_child(event) + _handle_event = state_wait_for_upstream class QuicLayer(layer.Layer): @@ -721,18 +696,22 @@ def build_client_layer( layer.retire_connection_id_callback = self._retire_cid return layer - def done(self, event: events.Event) -> layer.CommandGenerator[None]: + @expect(events.DataReceived, events.ConnectionClosed) + def state_done(self, _) -> layer.CommandGenerator[None]: yield from () - def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, events.Start): - pass + @expect(events.Start) + def state_start(self, _) -> layer.CommandGenerator[None]: + self._handle_event = self.state_wait_for_hello + yield from () + + @expect(events.DataReceived, events.ConnectionClosed) + def state_wait_for_hello(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, events.ConnectionEvent) + assert event.connection is self.context.client # only handle the first packet from the client - elif ( - isinstance(event, events.DataReceived) - and event.connection is self.context.client - ): + if isinstance(event, events.DataReceived): # extract the client hello try: client_hello, connection_id = pull_client_hello_and_connection_id( @@ -743,7 +722,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: f"Cannot parse ClientHello: {str(e)} ({event.data.hex()})" ) yield commands.CloseConnection(self.context.client) - self._handle_event = self.done + self._handle_event = self.state_done else: # copy the information @@ -761,21 +740,11 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # contact the upstream server first elif hook_data.establish_server_tls_first: - err = yield commands.OpenConnection(self.context.server) - if err is None: - next_layer = ServerQuicLayer(self.context) - next_layer.child_layer = self.build_client_layer( - connection_id, - wait_for_upstream=True, - ) - else: - yield commands.Log( - f"Failed to connect to upstream first (will continue with client anyway): {err}" - ) - next_layer = self.build_client_layer( - connection_id, - wait_for_upstream=False, - ) + next_layer = ServerQuicLayer(self.context) + next_layer.child_layer = self.build_client_layer( + connection_id, + wait_for_upstream=True, + ) # perform the client handshake immediately else: @@ -790,5 +759,11 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: yield from next_layer.handle_event(events.Start()) yield from next_layer.handle_event(event) + # stop if the connection was closed (usually we will always get one packet) + elif isinstance(event, events.ConnectionClosed): + self._handle_event = self.state_done + else: raise AssertionError(f"Unexpected event: {event}") + + _handle_event = state_start From 1426dec45e360f790367a9ce851a4b75779cdee6 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Wed, 22 Jun 2022 17:34:17 +0200 Subject: [PATCH 023/529] [quic] use state machine expect --- mitmproxy/proxy/layers/quic.py | 194 +++++++++++++++++++++------------ 1 file changed, 127 insertions(+), 67 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 7922d8ca6c..daae51e05c 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -386,35 +386,47 @@ def handle_child_commands( isinstance(command, commands.OpenConnection) and command.connection is self.conn ): + # ensure only one open command at a time for uninitialized connections + assert self.quic is None assert self._pending_open_command is None self._pending_open_command = command + + # try to open the underlying UDP connection err = yield commands.OpenConnection(self.conn) if err is None: - yield from self.initialize_connection() - if self.quic is not None: + # open succeeded, now try to initialize QUIC and connect + if (yield from self.initialize_connection()): + assert self.quic is not None self.quic.connect(self.conn.peername, now=self._loop.time()) yield from self.process_events() - else: - yield from self.shutdown_connection(err, level="warn") + else: + err = "initialize QUIC failed" + yield commands.CloseConnection(self.conn) + if err is not None: + # notify the child immediately + self._pending_open_command = None + yield from self.event_to_child( + events.OpenConnectionCompleted(command, err) + ) # properly close QUIC connections elif ( isinstance(command, commands.CloseConnection) and command.connection is self.conn ): - reason = "CloseConnection command received." if self.quic is None: - yield from self.shutdown_connection(reason, level="info") + yield command else: - self.quic.close(reason_phrase=reason) + self.quic.close(reason_phrase="CloseConnection command received.") yield from self.process_events() # return other commands else: yield command - def initialize_connection(self) -> layer.CommandGenerator[None]: + def initialize_connection(self) -> layer.CommandGenerator[bool]: assert self.quic is None + assert self.tls is None # query addons to provide the necessary TLS settings tls_data = QuicTlsData(self.conn, self.context) @@ -423,21 +435,18 @@ def initialize_connection(self) -> layer.CommandGenerator[None]: else: yield QuicTlsStartServerHook(tls_data) if tls_data.settings is None: - yield from self.shutdown_connection( - "No TLS settings were provided, failing connection.", - level="error", - ) - return - self.tls = tls_data.settings + return False # create the aioquic connection + self.tls = tls_data.settings self.quic = QuicConnection( configuration=self.build_configuration(), original_destination_connection_id=self.original_destination_connection_id, ) if self.issue_connection_id_callback is not None: self.issue_connection_id_callback(self.quic.host_cid) - self._handle_event = self.state_ready + self._handle_event = self.state_connected + return True def process_events(self) -> layer.CommandGenerator[None]: assert self.quic is not None @@ -455,12 +464,24 @@ def process_events(self) -> layer.CommandGenerator[None]: self.retire_connection_id_callback(event.connection_id) elif isinstance(event, quic_events.ConnectionTerminated): + # only forward the event if the connection has been properly initialized + if self.conn.tls_established: + yield from self.event_to_child( + QuicConnectionEvent(self.conn, event) + ) + + # shutdown and close the connection yield from self.shutdown_connection( event.reason_phrase or str(event.error_code), level=( "info" if event.error_code == QuicErrorCode.NO_ERROR else "warn" ), ) + yield commands.CloseConnection(self.conn) + + elif isinstance(event, quic_events.ProtocolNegotiated): + # too early, we act on HandshakeCompleted + pass elif isinstance(event, quic_events.HandshakeCompleted): # concatenate all peer certificates @@ -488,19 +509,19 @@ def process_events(self) -> layer.CommandGenerator[None]: # let the child layer know if self._pending_open_command is not None: + command = self._pending_open_command + self._pending_open_command = None yield from self.event_to_child( - events.OpenConnectionCompleted( - self._pending_open_command, reply=None - ) + events.OpenConnectionCompleted(command, reply=None) ) - self._pending_open_command = None # perform next layer decisions now if isinstance(self.child_layer, layer.NextLayer): yield from self.handle_child_commands(self.child_layer._ask()) - # forward the event as a QuicConnectionEvent to the child layer - yield from self.event_to_child(QuicConnectionEvent(self.conn, event)) + else: + # forward the event as a QuicConnectionEvent to the child layer + yield from self.event_to_child(QuicConnectionEvent(self.conn, event)) # handle the next event event = self.quic.next_event() @@ -514,10 +535,18 @@ def shutdown_connection( level: Literal["error", "warn", "info", "alert", "debug"], ) -> layer.CommandGenerator[None]: # ensure QUIC has been properly shut down - assert self.quic is None or self.quic._state is QuicConnectionState.TERMINATED + assert self.quic is not None + assert self.tls is not None + assert self.quic._state is QuicConnectionState.TERMINATED + + # obsolete any current timer + if self._request_wakeup_command_and_timer is not None: + command, _ = self._request_wakeup_command_and_timer + self._obsolete_wakeup_commands.add(command) + self._request_wakeup_command_and_timer = None # report as TLS failure if the termination happened before the handshake - if not self.conn.tls_established and self.tls is not None: + if not self.conn.tls_established: self.conn.error = reason tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) if self.conn is self.context.client: @@ -525,30 +554,27 @@ def shutdown_connection( else: yield layers.tls.TlsFailedServerHook(tls_data) - # log the reason and ensure the connection gets closed - yield commands.Log(f"Connection {self.conn} closed: {reason}", level=level) - if self.conn.connected: - yield commands.CloseConnection(self.conn) + # make a log entry directly + yield commands.Log( + f"QUIC connection {self.conn} shutdown: {reason}", level=level + ) # let the child layer know and stop handling events + # we also don't handle any commands from the child at this point anymore if self._pending_open_command is not None: - yield from self.event_to_child( + yield from self.child_layer.handle_event( events.OpenConnectionCompleted(self._pending_open_command, reason) ) self._pending_open_command = None self._handle_event = self.state_done - def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: - # when done, just forward the event to the child layer (except for obsolete wakeups) - if ( - isinstance(event, events.Wakeup) - and event.command in self._obsolete_wakeup_commands - ): - self._obsolete_wakeup_commands.remove(event.command) - else: - yield from self.child_layer.handle_event(event) + def start(self) -> layer.CommandGenerator[None]: + yield from self.event_to_child(events.Start()) - def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: + @expect( + events.ConnectionClosed, events.DataReceived, events.Wakeup, QuicConnectionEvent + ) + def state_connected(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None # forward incoming data only to aioquic @@ -558,17 +584,16 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: event.data, event.remote_addr, now=self._loop.time() ) yield from self.process_events() - return # handle connections closed by peer elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): reason = "Peer UDP connection timed out." - if self.quic is not None: - # there is no point in calling quic.close, as it cannot send packets anymore - # so we simply set the state and simulate a ConnectionTerminated event - self.quic._set_state(QuicConnectionState.TERMINATED) + # there is no point in calling quic.close, as it cannot send packets anymore + # so we simply set the state and simulate a ConnectionTerminated event + self.quic._set_state(QuicConnectionState.TERMINATED) + if self.conn.tls_established: yield from self.event_to_child( QuicConnectionEvent( self.conn, @@ -579,6 +604,10 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: ), ) ) + + # forward the event only if no open command is pending and shutdown + if self._pending_open_command is None: + yield from self.event_to_child(event) yield from self.shutdown_connection(reason, level="info") # intercept wakeup events for aioquic @@ -586,23 +615,52 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: # swallow obsolete wakeups if event.command in self._obsolete_wakeup_commands: self._obsolete_wakeup_commands.remove(event.command) - return - - # handle active wakeup - if self._request_wakeup_command_and_timer is not None: - command, timer = self._request_wakeup_command_and_timer - if event.command is command: - self._request_wakeup_command_and_timer = None - self.quic.handle_timer(now=max(timer, self._loop.time())) - yield from self.process_events() - return + else: + # handle active wakeup and forward others to child layer + if self._request_wakeup_command_and_timer is not None: + command, timer = self._request_wakeup_command_and_timer + if event.command is command: + self._request_wakeup_command_and_timer = None + self.quic.handle_timer(now=max(timer, self._loop.time())) + yield from self.process_events() + else: + yield from self.event_to_child(event) + else: + yield from self.event_to_child(event) + + else: + # forward other events to the child layer + yield from self.event_to_child(event) - # forward other events to the child layer - yield from self.event_to_child(event) + @expect( + events.ConnectionClosed, events.DataReceived, events.Wakeup, QuicConnectionEvent + ) + def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: + # filter out obsolete wakeups + if ( + isinstance(event, events.Wakeup) + and event.command in self._obsolete_wakeup_commands + ): + self._obsolete_wakeup_commands.remove(event.command) - def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: - # wait for the state to change and inspect the child layer's commands - yield from self.event_to_child(event) + # ignore any further received data + elif isinstance(event, events.DataReceived) and event.connection is self.conn: + pass + + # forward all other events + else: + yield from self.child_layer.handle_event(event) + + @expect( + events.ConnectionClosed, events.DataReceived, events.Wakeup, QuicConnectionEvent + ) + def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: + yield from () + + @expect(events.Start) + def state_start(self, _) -> layer.CommandGenerator[None]: + self._handle_event = self.state_ready + yield from self.start() def transmit(self) -> layer.CommandGenerator[None]: assert self.quic is not None @@ -652,13 +710,10 @@ def __init__( super().__init__(context, context.client) self.wait_for_upstream = wait_for_upstream - @expect(events.Start) - def state_wait_for_upstream( - self, event: events.Event - ) -> layer.CommandGenerator[None]: - self._handle_event = self.state_start + def start(self) -> layer.CommandGenerator[None]: + yield from super().start() - # open the upstream connection if possible, but always initialize the client connection + # try to open the upstream connection if self.wait_for_upstream: err = yield commands.OpenConnection(self.context.server) if err is not None: @@ -666,10 +721,15 @@ def state_wait_for_upstream( f"Unable to establish QUIC connection with server ({err}). " f"Trying to establish QUIC with client anyway." ) - yield from self.initialize_connection() - yield from self.state_start(event) - _handle_event = state_wait_for_upstream + # is still connected then initialize, close on failure + if not self.conn.connected and not (yield from self.initialize_connection()): + yield commands.CloseConnection(self.conn) + self._handle_event = self.state_failed + + @expect(events.ConnectionClosed, events.DataReceived, QuicConnectionEvent) + def state_failed(self, _) -> layer.CommandGenerator[None]: + yield from () class QuicLayer(layer.Layer): From b9afc502a699bfc66e14fe439dd1902727ebf36c Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 23 Jun 2022 01:18:06 +0200 Subject: [PATCH 024/529] [quic] rework states and add more comments --- mitmproxy/proxy/layers/quic.py | 349 +++++++++++++++++++-------------- 1 file changed, 204 insertions(+), 145 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index daae51e05c..0a6225ed13 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -359,6 +359,117 @@ def build_configuration(self) -> QuicConfiguration: verify_mode=self.tls.verify_mode, ) + def create_quic(self) -> layer.CommandGenerator[bool]: + # must only be called if QUIC is uninitialized + assert self.quic is None + assert self.tls is None + + # in case the connection is being reused, clear all handshake data + self.conn.timestamp_tls_setup = None + self.conn.certificate_list = () + self.conn.alpn = None + self.conn.cipher = None + self.conn.tls_version = None + + # query addons to provide the necessary TLS settings + tls_data = QuicTlsData(self.conn, self.context) + if self.conn is self.context.client: + yield QuicTlsStartClientHook(tls_data) + else: + yield QuicTlsStartServerHook(tls_data) + if tls_data.settings is None: + yield commands.Log( + f"{self.conn}: No QUIC TLS settings provided by addon(s).", + level="error", + ) + return False + + # create the aioquic connection + self.tls = tls_data.settings + self.quic = QuicConnection( + configuration=self.build_configuration(), + original_destination_connection_id=self.original_destination_connection_id, + ) + self._handle_event = self.state_has_quic + + # issue the host connection ID right away + if self.issue_connection_id_callback is not None: + self.issue_connection_id_callback(self.quic.host_cid) + + # record an entry in the log + yield commands.Log(f"{self.conn}: QUIC connection created.", level="info") + return True + + def destroy_quic( + self, + reason: str, + level: Literal["error", "warn", "info", "alert", "debug"], + ) -> layer.CommandGenerator[None]: + # ensure QUIC has been properly shut down + assert self.quic is not None + assert self.tls is not None + assert self.quic._state is QuicConnectionState.TERMINATED + + # report as TLS failure if the termination happened before the handshake + if not self.conn.tls_established: + self.conn.error = reason + tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) + if self.conn is self.context.client: + yield layers.tls.TlsFailedClientHook(tls_data) + else: + yield layers.tls.TlsFailedServerHook(tls_data) + + # clear the quic fields + self.quic = None + self.tls = None + self._handle_event = self.state_no_quic + + # obsolete any current timer + if self._request_wakeup_command_and_timer is not None: + command, _ = self._request_wakeup_command_and_timer + self._obsolete_wakeup_commands.add(command) + self._request_wakeup_command_and_timer = None + + # record an entry in the log + yield commands.Log( + f"{self.conn}: QUIC connection destroyed: {reason}", level=level + ) + + def establish_quic( + self, event: quic_events.HandshakeCompleted + ) -> layer.CommandGenerator[None]: + # must only be called if QUIC is initialized + assert self.quic is not None + assert self.tls is not None + + # concatenate all peer certificates + all_certs = [] + if self.quic.tls._peer_certificate is not None: + all_certs.append(self.quic.tls._peer_certificate) + if self.quic.tls._peer_certificate_chain is not None: + all_certs.extend(self.quic.tls._peer_certificate_chain) + + # set the connection's TLS properties + self.conn.timestamp_tls_setup = self._loop.time() + self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs] + self.conn.alpn = event.alpn_protocol.encode("ascii") + self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name + self.conn.tls_version = "QUIC" + + # report the success to addons + tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) + if self.conn is self.context.client: + yield layers.tls.TlsEstablishedClientHook(tls_data) + else: + yield layers.tls.TlsEstablishedServerHook(tls_data) + + # record an entry in the log + yield commands.Log( + f"{self.conn}: QUIC connection established. " + f"(early_data={event.early_data_accepted}, resumed={event.session_resumed})", + level="info", + ) + def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: yield from self.handle_child_commands(self.child_layer.handle_event(event)) @@ -386,34 +497,17 @@ def handle_child_commands( isinstance(command, commands.OpenConnection) and command.connection is self.conn ): - # ensure only one open command at a time for uninitialized connections - assert self.quic is None - assert self._pending_open_command is None - self._pending_open_command = command - - # try to open the underlying UDP connection - err = yield commands.OpenConnection(self.conn) - if err is None: - # open succeeded, now try to initialize QUIC and connect - if (yield from self.initialize_connection()): - assert self.quic is not None - self.quic.connect(self.conn.peername, now=self._loop.time()) - yield from self.process_events() - else: - err = "initialize QUIC failed" - yield commands.CloseConnection(self.conn) - if err is not None: - # notify the child immediately - self._pending_open_command = None - yield from self.event_to_child( - events.OpenConnectionCompleted(command, err) - ) + yield from self.open_connection_begin(command) # properly close QUIC connections elif ( isinstance(command, commands.CloseConnection) and command.connection is self.conn ): + # CloseConnection during pending OpenConnection is not allowed + assert self._pending_open_command is None + + # without QUIC simply close the connection, otherwise close QUIC first if self.quic is None: yield command else: @@ -424,28 +518,38 @@ def handle_child_commands( else: yield command - def initialize_connection(self) -> layer.CommandGenerator[bool]: + def open_connection_begin( + self, command: commands.OpenConnection + ) -> layer.CommandGenerator[None]: + # ensure only one OpenConnection at a time is called for uninitialized connections assert self.quic is None - assert self.tls is None - - # query addons to provide the necessary TLS settings - tls_data = QuicTlsData(self.conn, self.context) - if self.conn is self.context.client: - yield QuicTlsStartClientHook(tls_data) + assert self._pending_open_command is None + self._pending_open_command = command + + # try to open the underlying UDP connection + err = yield commands.OpenConnection(self.conn) + if not err: + # initialize QUIC and connect (notify the child layer after handshake) + if (yield from self.create_quic()): + assert self.quic is not None + self.quic.connect(self.conn.peername, now=self._loop.time()) + yield from self.process_events() + else: + # TLS failed, close the connection (notify child layer once closed) + yield commands.CloseConnection(self.conn) else: - yield QuicTlsStartServerHook(tls_data) - if tls_data.settings is None: + # notify the child layer immediately about the error + self._pending_open_command = None + yield from self.event_to_child(events.OpenConnectionCompleted(command, err)) + + def open_connection_end(self, reply: Optional[str]) -> layer.CommandGenerator[bool]: + if self._pending_open_command is None: return False - # create the aioquic connection - self.tls = tls_data.settings - self.quic = QuicConnection( - configuration=self.build_configuration(), - original_destination_connection_id=self.original_destination_connection_id, - ) - if self.issue_connection_id_callback is not None: - self.issue_connection_id_callback(self.quic.host_cid) - self._handle_event = self.state_connected + # let the child layer know that the connection is now open (or failed to open) + command = self._pending_open_command + self._pending_open_command = None + yield from self.event_to_child(events.OpenConnectionCompleted(command, reply)) return True def process_events(self) -> layer.CommandGenerator[None]: @@ -471,7 +575,7 @@ def process_events(self) -> layer.CommandGenerator[None]: ) # shutdown and close the connection - yield from self.shutdown_connection( + yield from self.destroy_quic( event.reason_phrase or str(event.error_code), level=( "info" if event.error_code == QuicErrorCode.NO_ERROR else "warn" @@ -479,119 +583,69 @@ def process_events(self) -> layer.CommandGenerator[None]: ) yield commands.CloseConnection(self.conn) - elif isinstance(event, quic_events.ProtocolNegotiated): - # too early, we act on HandshakeCompleted - pass + # we don't handle any further events, nor do/can we transmit data, so exit + return elif isinstance(event, quic_events.HandshakeCompleted): - # concatenate all peer certificates - all_certs = [] - if self.quic.tls._peer_certificate is not None: - all_certs.append(self.quic.tls._peer_certificate) - if self.quic.tls._peer_certificate_chain is not None: - all_certs.extend(self.quic.tls._peer_certificate_chain) - - # set the connection's TLS properties - self.conn.timestamp_tls_setup = self._loop.time() - self.conn.certificate_list = [ - certs.Cert.from_pyopenssl(x) for x in all_certs - ] - self.conn.alpn = event.alpn_protocol.encode("ascii") - self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name - self.conn.tls_version = "QUIC" - - # report the success to addons - tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) - if self.conn is self.context.client: - yield layers.tls.TlsEstablishedClientHook(tls_data) - else: - yield layers.tls.TlsEstablishedServerHook(tls_data) - - # let the child layer know - if self._pending_open_command is not None: - command = self._pending_open_command - self._pending_open_command = None - yield from self.event_to_child( - events.OpenConnectionCompleted(command, reply=None) - ) + # set all TLS fields and notify the child layer + yield from self.establish_quic(event) + yield from self.open_connection_end(None) # perform next layer decisions now if isinstance(self.child_layer, layer.NextLayer): yield from self.handle_child_commands(self.child_layer._ask()) - else: - # forward the event as a QuicConnectionEvent to the child layer + elif isinstance(event, quic_events.PingAcknowledged): + # we let aioquic do it's thing but don't really care ourselves + pass + + elif isinstance(event, quic_events.ProtocolNegotiated): + # too early, we act on HandshakeCompleted + pass + + elif isinstance( + event, + ( + quic_events.DatagramFrameReceived, + quic_events.StreamDataReceived, + quic_events.StreamReset, + ), + ): + # post-handshake event, forward as QuicConnectionEvent to the child layer + assert self.conn.tls_established yield from self.event_to_child(QuicConnectionEvent(self.conn, event)) + else: + raise AssertionError(f"Unexpected event: {event}") + # handle the next event event = self.quic.next_event() # transmit buffered data and re-arm timer yield from self.transmit() - def shutdown_connection( - self, - reason: str, - level: Literal["error", "warn", "info", "alert", "debug"], - ) -> layer.CommandGenerator[None]: - # ensure QUIC has been properly shut down - assert self.quic is not None - assert self.tls is not None - assert self.quic._state is QuicConnectionState.TERMINATED - - # obsolete any current timer - if self._request_wakeup_command_and_timer is not None: - command, _ = self._request_wakeup_command_and_timer - self._obsolete_wakeup_commands.add(command) - self._request_wakeup_command_and_timer = None - - # report as TLS failure if the termination happened before the handshake - if not self.conn.tls_established: - self.conn.error = reason - tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) - if self.conn is self.context.client: - yield layers.tls.TlsFailedClientHook(tls_data) - else: - yield layers.tls.TlsFailedServerHook(tls_data) - - # make a log entry directly - yield commands.Log( - f"QUIC connection {self.conn} shutdown: {reason}", level=level - ) - - # let the child layer know and stop handling events - # we also don't handle any commands from the child at this point anymore - if self._pending_open_command is not None: - yield from self.child_layer.handle_event( - events.OpenConnectionCompleted(self._pending_open_command, reason) - ) - self._pending_open_command = None - self._handle_event = self.state_done - def start(self) -> layer.CommandGenerator[None]: yield from self.event_to_child(events.Start()) - @expect( - events.ConnectionClosed, events.DataReceived, events.Wakeup, QuicConnectionEvent - ) - def state_connected(self, event: events.Event) -> layer.CommandGenerator[None]: + def state_has_quic(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None - # forward incoming data only to aioquic if isinstance(event, events.DataReceived) and event.connection is self.conn: + # forward incoming data only to aioquic assert event.remote_addr is not None self.quic.receive_datagram( event.data, event.remote_addr, now=self._loop.time() ) yield from self.process_events() - # handle connections closed by peer elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): + # handle connections closed by peer, which in UDP's case is a timeout reason = "Peer UDP connection timed out." + # there is no point in calling quic.close, as it cannot send packets anymore - # so we simply set the state and simulate a ConnectionTerminated event + # set the new connection state and simulate a ConnectionTerminated event (if established) self.quic._set_state(QuicConnectionState.TERMINATED) if self.conn.tls_established: yield from self.event_to_child( @@ -605,14 +659,14 @@ def state_connected(self, event: events.Event) -> layer.CommandGenerator[None]: ) ) - # forward the event only if no open command is pending and shutdown - if self._pending_open_command is None: + # shutdown QUIC and handle the ConnectionClosed event + yield from self.destroy_quic(reason, level="info") + if not (yield from self.open_connection_end(reason)): + # connection was opened before QUIC layer, report to the child layer yield from self.event_to_child(event) - yield from self.shutdown_connection(reason, level="info") - # intercept wakeup events for aioquic elif isinstance(event, events.Wakeup): - # swallow obsolete wakeups + # swallow obsolete wakeup events if event.command in self._obsolete_wakeup_commands: self._obsolete_wakeup_commands.remove(event.command) else: @@ -632,34 +686,37 @@ def state_connected(self, event: events.Event) -> layer.CommandGenerator[None]: # forward other events to the child layer yield from self.event_to_child(event) - @expect( - events.ConnectionClosed, events.DataReceived, events.Wakeup, QuicConnectionEvent - ) - def state_done(self, event: events.Event) -> layer.CommandGenerator[None]: - # filter out obsolete wakeups + def state_no_quic(self, event: events.Event) -> layer.CommandGenerator[None]: + assert self.quic is None + if ( isinstance(event, events.Wakeup) and event.command in self._obsolete_wakeup_commands ): + # filter out obsolete wakeups self._obsolete_wakeup_commands.remove(event.command) - # ignore any further received data + elif ( + isinstance(event, events.ConnectionClosed) and event.connection is self.conn + ): + # if there is was an OpenConnection command, then create_quic failed + # otherwise the connection was opened before the QUIC layer, so forward the event + if not (yield from self.open_connection_end("QUIC initialization failed")): + yield from self.event_to_child(event) + elif isinstance(event, events.DataReceived) and event.connection is self.conn: + # ignore received data events + # this either happens after QUIC is closed or if the underlying UDP connection is opened + # before the QUIC layer and missing initialization during the child layer's start event pass - # forward all other events else: - yield from self.child_layer.handle_event(event) - - @expect( - events.ConnectionClosed, events.DataReceived, events.Wakeup, QuicConnectionEvent - ) - def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: - yield from () + # forward all other events to the child layer + yield from self.event_to_child(event) @expect(events.Start) def state_start(self, _) -> layer.CommandGenerator[None]: - self._handle_event = self.state_ready + self._handle_event = self.state_no_quic yield from self.start() def transmit(self) -> layer.CommandGenerator[None]: @@ -716,15 +773,17 @@ def start(self) -> layer.CommandGenerator[None]: # try to open the upstream connection if self.wait_for_upstream: err = yield commands.OpenConnection(self.context.server) - if err is not None: + if err: yield commands.Log( f"Unable to establish QUIC connection with server ({err}). " f"Trying to establish QUIC with client anyway." ) - # is still connected then initialize, close on failure - if not self.conn.connected and not (yield from self.initialize_connection()): + # if (still) connected then initialize, close on failure + if self.conn.connected and not (yield from self.create_quic()): yield commands.CloseConnection(self.conn) + if self.wait_for_upstream and err is not None: + yield commands.CloseConnection(self.context.server) self._handle_event = self.state_failed @expect(events.ConnectionClosed, events.DataReceived, QuicConnectionEvent) From f8b7b6e173ad0167da3de9482bd04a484ce069f1 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 23 Jun 2022 01:57:01 +0200 Subject: [PATCH 025/529] [quic] fix cert issue --- mitmproxy/addons/tlsconfig.py | 8 +++++--- mitmproxy/proxy/layers/quic.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 7a7c7fbffc..3b6df62600 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -322,6 +322,7 @@ def quic_tls_start_client(self, tls_start: quic.QuicTlsData) -> None: tls_start.settings.cipher_suites = [ CipherSuite(cipher) for cipher in client.cipher_list ] + if ctx.options.add_upstream_certs_to_client_chain: tls_start.settings.certificate_chain.extend(cert._cert for cert in server.certificate_list) @@ -348,9 +349,10 @@ def quic_tls_start_server(self, tls_start: quic.QuicTlsData) -> None: if not server.cipher_list and ctx.options.ciphers_server: server.cipher_list = ctx.options.ciphers_server.split(":") - tls_start.settings.cipher_suites = [ - CipherSuite(cipher) for cipher in server.cipher_list - ] + if server.cipher_list: + tls_start.settings.cipher_suites = [ + CipherSuite(cipher) for cipher in server.cipher_list + ] client_cert = self.get_client_cert(server) if client_cert: diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 0a6225ed13..c4a57de9cd 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -443,7 +443,7 @@ def establish_quic( assert self.tls is not None # concatenate all peer certificates - all_certs = [] + all_certs: List[x509.Certificate] = [] if self.quic.tls._peer_certificate is not None: all_certs.append(self.quic.tls._peer_certificate) if self.quic.tls._peer_certificate_chain is not None: @@ -451,7 +451,7 @@ def establish_quic( # set the connection's TLS properties self.conn.timestamp_tls_setup = self._loop.time() - self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs] + self.conn.certificate_list = [certs.Cert(cert) for cert in all_certs] self.conn.alpn = event.alpn_protocol.encode("ascii") self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name self.conn.tls_version = "QUIC" From db9f4b5a2df6cea4d8f6f455ead3a10b55eba122 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 23 Jun 2022 02:46:52 +0200 Subject: [PATCH 026/529] [quic] add asserts --- mitmproxy/proxy/layers/http/_http3.py | 1 + mitmproxy/proxy/layers/quic.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 2c8e0b0b16..18bff41a64 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -60,6 +60,7 @@ class Http3Connection(HttpConnection): def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, events.Start): quic = yield QuicGetConnection(self.conn) + assert quic is not None assert isinstance(quic, QuicConnection) self.quic = quic self.h3_conn = H3Connection(quic, enable_webtransport=False) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index c4a57de9cd..d57649ead6 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -484,6 +484,7 @@ def handle_child_commands( isinstance(command, QuicGetConnection) and command.connection is self.conn ): + assert self.quic is not None yield from self.event_to_child( QuicGetConnectionCompleted(command, self.quic) ) @@ -641,8 +642,8 @@ def state_has_quic(self, event: events.Event) -> layer.CommandGenerator[None]: elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): - # handle connections closed by peer, which in UDP's case is a timeout - reason = "Peer UDP connection timed out." + # handle connections closed by peer (which in UDP's case is usually a timeout) + reason = "Peer UDP connection closed or timed out." # there is no point in calling quic.close, as it cannot send packets anymore # set the new connection state and simulate a ConnectionTerminated event (if established) From 5902c7b0fa9d354172a3d438ec408d88c959857a Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Thu, 23 Jun 2022 15:13:13 +0200 Subject: [PATCH 027/529] [quic] temp workaround for QuicGetConnection issue --- mitmproxy/proxy/layers/http/_http3.py | 41 +++++++++++++++++++-------- mitmproxy/proxy/layers/quic.py | 33 ++++++++++----------- 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 18bff41a64..18fafb8d9f 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -17,10 +17,12 @@ from mitmproxy.net.http import status_codes from mitmproxy.proxy import commands, context, events, layer from mitmproxy.proxy.layers.quic import ( + _QuicLayer, QuicConnectionEvent, - QuicGetConnection, + # QuicGetConnection, QuicTransmit, ) +from mitmproxy.proxy.utils import expect from . import ( RequestData, @@ -59,14 +61,20 @@ class Http3Connection(HttpConnection): def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, events.Start): - quic = yield QuicGetConnection(self.conn) - assert quic is not None - assert isinstance(quic, QuicConnection) - self.quic = quic - self.h3_conn = H3Connection(quic, enable_webtransport=False) + # this doesn't always work: + # quic = yield QuicGetConnection(self.conn) + # assert isinstance(quic, QuicConnection) + # self.quic = quic + # + # temporary workaround: + for layer_ in self.context.layers: + if isinstance(layer_, _QuicLayer) and layer_.conn is self.conn: + self.quic = layer_.quic + assert self.quic is not None + self.h3_conn = H3Connection(self.quic, enable_webtransport=False) elif isinstance(event, events.ConnectionClosed): - self._handle_event = self.done + self._handle_event = self.done # type: ignore # send mitmproxy HTTP events over the H3 connection elif isinstance(event, HttpEvent): @@ -90,6 +98,9 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: ), end_stream=event.end_stream, ) + if event.end_stream: + # this will prevent any further headers or data from being sent + self.h3_conn._stream[event.stream_id].headers_send_state = H3HeadersState.AFTER_TRAILERS elif isinstance(event, (RequestTrailers, ResponseTrailers)): self.h3_conn.send_headers( stream_id=event.stream_id, @@ -122,6 +133,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # report abrupt stream resets if isinstance(event, quic_events.StreamReset): if event.stream_id in self.h3_conn._stream: + # try to get a name for the error from its code try: reason = H3ErrorCode(event.error_code).name except ValueError: @@ -129,6 +141,8 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: reason = QuicErrorCode(event.error_code).name except ValueError: reason = str(event.error_code) + + # report the protocol error (doing the same error code mingling as H2) code = ( status_codes.CLIENT_CLOSED_REQUEST if event.error_code == H3ErrorCode.H3_REQUEST_CANCELLED @@ -192,10 +206,12 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: ) else: yield ReceiveHttp(receive_event) - if h3_event.stream_ended: - yield ReceiveHttp( - self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) - ) + + # always report an EndOfMessage if the stream has ended + if h3_event.stream_ended: + yield ReceiveHttp( + self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) + ) # we don't support push, web transport, etc. else: @@ -206,7 +222,8 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: else: raise AssertionError(f"Unexpected event: {event!r}") - def done(self, event: events.Event) -> layer.CommandGenerator[None]: + @expect(events.DataReceived, HttpEvent, events.ConnectionClosed) + def done(self, _) -> layer.CommandGenerator[None]: yield from () @abstractmethod diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index d57649ead6..e373e40bf2 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -364,6 +364,10 @@ def create_quic(self) -> layer.CommandGenerator[bool]: assert self.quic is None assert self.tls is None + # cannot initialize QUIC on a closed connection + if not self.conn.connected: + return False + # in case the connection is being reused, clear all handshake data self.conn.timestamp_tls_setup = None self.conn.certificate_list = () @@ -419,7 +423,7 @@ def destroy_quic( else: yield layers.tls.TlsFailedServerHook(tls_data) - # clear the quic fields + # clear the QUIC fields self.quic = None self.tls = None self._handle_event = self.state_no_quic @@ -479,28 +483,27 @@ def handle_child_commands( # filter commands coming from the child layer for command in child_commands: - # answer with the aioquic connection instance if ( isinstance(command, QuicGetConnection) and command.connection is self.conn ): + # answer with the aioquic connection instance assert self.quic is not None yield from self.event_to_child( QuicGetConnectionCompleted(command, self.quic) ) - # transmit buffered data and re-arm timer elif isinstance(command, QuicTransmit) and command.connection is self.quic: + # transmit buffered data and re-arm timer yield from self.transmit() - # open the QUIC connection elif ( isinstance(command, commands.OpenConnection) and command.connection is self.conn ): + # try to open the QUIC connection and report OpenConnectionCompleted later yield from self.open_connection_begin(command) - # properly close QUIC connections elif ( isinstance(command, commands.CloseConnection) and command.connection is self.conn @@ -515,14 +518,14 @@ def handle_child_commands( self.quic.close(reason_phrase="CloseConnection command received.") yield from self.process_events() - # return other commands else: + # return other commands yield command def open_connection_begin( self, command: commands.OpenConnection ) -> layer.CommandGenerator[None]: - # ensure only one OpenConnection at a time is called for uninitialized connections + # ensure only one OpenConnection is called at a time and only for uninitialized connections assert self.quic is None assert self._pending_open_command is None self._pending_open_command = command @@ -700,15 +703,14 @@ def state_no_quic(self, event: events.Event) -> layer.CommandGenerator[None]: elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): - # if there is was an OpenConnection command, then create_quic failed + # if there was an OpenConnection command, then create_quic failed # otherwise the connection was opened before the QUIC layer, so forward the event if not (yield from self.open_connection_end("QUIC initialization failed")): yield from self.event_to_child(event) elif isinstance(event, events.DataReceived) and event.connection is self.conn: - # ignore received data events - # this either happens after QUIC is closed or if the underlying UDP connection is opened - # before the QUIC layer and missing initialization during the child layer's start event + # ignore received data, which either happens after QUIC is closed or if the underlying + # UDP connection is already opened and no QUIC initialization is being performed pass else: @@ -727,7 +729,7 @@ def transmit(self) -> layer.CommandGenerator[None]: for data, addr in self.quic.datagrams_to_send(now=self._loop.time()): yield commands.SendData(self.conn, data, addr) - # mark an existing wakeup command as obsolete if it now longer matches the time + # mark an existing wakeup command as obsolete if it no longer matches the timer timer = self.quic.get_timer() if self._request_wakeup_command_and_timer is not None: command, existing_timer = self._request_wakeup_command_and_timer @@ -780,14 +782,13 @@ def start(self) -> layer.CommandGenerator[None]: f"Trying to establish QUIC with client anyway." ) - # if (still) connected then initialize, close on failure - if self.conn.connected and not (yield from self.create_quic()): + # initialize QUIC, shutdown on failure + if not (yield from self.create_quic()): yield commands.CloseConnection(self.conn) if self.wait_for_upstream and err is not None: yield commands.CloseConnection(self.context.server) self._handle_event = self.state_failed - @expect(events.ConnectionClosed, events.DataReceived, QuicConnectionEvent) def state_failed(self, _) -> layer.CommandGenerator[None]: yield from () @@ -874,7 +875,7 @@ def state_wait_for_hello(self, event: events.Event) -> layer.CommandGenerator[No ) # replace this layer and start the next one - self.handle_event = next_layer.handle_event + self.handle_event = next_layer.handle_event # type: ignore self._handle_event = next_layer._handle_event yield from next_layer.handle_event(events.Start()) yield from next_layer.handle_event(event) From bd213b4a25d6da2692aab6f4cf9ec6d645e22b1d Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Fri, 24 Jun 2022 15:08:45 +0200 Subject: [PATCH 028/529] [quic] unified error handling --- mitmproxy/proxy/layers/http/_http3.py | 27 ++++++++------------------ mitmproxy/proxy/layers/quic.py | 28 ++++++++++++++++++--------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 18fafb8d9f..348fae241e 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -11,7 +11,6 @@ from aioquic.h3 import events as h3_events from aioquic.quic import events as quic_events from aioquic.quic.connection import QuicConnection -from aioquic.quic.packet import QuicErrorCode from mitmproxy import http, version from mitmproxy.net.http import status_codes @@ -21,6 +20,7 @@ QuicConnectionEvent, # QuicGetConnection, QuicTransmit, + error_code_to_str, ) from mitmproxy.proxy.utils import expect @@ -100,7 +100,9 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: ) if event.end_stream: # this will prevent any further headers or data from being sent - self.h3_conn._stream[event.stream_id].headers_send_state = H3HeadersState.AFTER_TRAILERS + self.h3_conn._stream[ + event.stream_id + ].headers_send_state = H3HeadersState.AFTER_TRAILERS elif isinstance(event, (RequestTrailers, ResponseTrailers)): self.h3_conn.send_headers( stream_id=event.stream_id, @@ -111,12 +113,10 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: self.h3_conn.send_data( stream_id=event.stream_id, data=b"", end_stream=True ) - elif isinstance( - event, (RequestProtocolError, ResponseProtocolError) - ): + elif isinstance(event, (RequestProtocolError, ResponseProtocolError)): self.protocol_error(event) else: - raise AssertionError(f"Unexpected event: {event}") + raise AssertionError(f"Unexpected event: {event!r}") except FrameUnexpected: # Http2Connection also ignores HttpEvents that violate the current stream state @@ -133,15 +133,6 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # report abrupt stream resets if isinstance(event, quic_events.StreamReset): if event.stream_id in self.h3_conn._stream: - # try to get a name for the error from its code - try: - reason = H3ErrorCode(event.error_code).name - except ValueError: - try: - reason = QuicErrorCode(event.error_code).name - except ValueError: - reason = str(event.error_code) - # report the protocol error (doing the same error code mingling as H2) code = ( status_codes.CLIENT_CLOSED_REQUEST @@ -151,7 +142,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: yield ReceiveHttp( self.ReceiveProtocolError( stream_id=event.stream_id, - message=f"stream reset by client ({reason})", + message=f"stream reset by client ({error_code_to_str(event.error_code)})", code=code, ) ) @@ -215,9 +206,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # we don't support push, web transport, etc. else: - yield commands.Log( - f"Ignored unsupported H3 event: {h3_event!r}" - ) + yield commands.Log(f"Ignored unsupported H3 event: {h3_event!r}") else: raise AssertionError(f"Unexpected event: {event!r}") diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index e373e40bf2..ac4111a7b7 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union from aioquic.buffer import Buffer as QuicBuffer +from aioquic.h3.connection import ErrorCode as H3ErrorCode from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.connection import ( @@ -132,6 +133,16 @@ class QuicClientHello(Exception): data: bytes +def error_code_to_str(error_code: int) -> str: + try: + return H3ErrorCode(error_code).name + except ValueError: + try: + return QuicErrorCode(error_code).name + except ValueError: + return f"unknown error (0x{error_code:x})" + + def pull_client_hello_and_connection_id(data: bytes) -> Tuple[ClientHello, bytes]: # ensure the first packet is indeed the initial one buffer = QuicBuffer(data=data) @@ -301,14 +312,10 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: stream_id_out, flow = lookup_in[stream_id_in] quic_out.stop_stream(stream_id_out, quic_event.error_code) - # try to get a name describing the reset reason - try: - err = QuicErrorCode(quic_event.error_code).name - except ValueError: - err = str(quic_event.error_code) - # report the error to addons and delete the stream - flow.error = mitm_flow.Error(str(err)) + flow.error = mitm_flow.Error( + error_code_to_str(quic_event.error_code) + ) yield tcp_layer.TcpErrorHook(flow) flow.live = False del lookup_in[stream_id_in] @@ -580,9 +587,12 @@ def process_events(self) -> layer.CommandGenerator[None]: # shutdown and close the connection yield from self.destroy_quic( - event.reason_phrase or str(event.error_code), + event.reason_phrase or error_code_to_str(event.error_code), level=( - "info" if event.error_code == QuicErrorCode.NO_ERROR else "warn" + "info" + if event.error_code + in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR) + else "warn" ), ) yield commands.CloseConnection(self.conn) From 46096e6af9b8431301897a4b3f411dcf401c09ea Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Sun, 26 Jun 2022 23:47:03 +0200 Subject: [PATCH 029/529] [quic] fix context and stream ended handling --- mitmproxy/proxy/layers/http/_http2.py | 10 +++--- mitmproxy/proxy/layers/http/_http3.py | 45 +++++++++++++++++---------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index ec4aed6b72..cfe2b53de9 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -314,6 +314,7 @@ def normalize_h2_headers(headers: list[tuple[bytes, bytes]]) -> CommandGenerator def format_h2_request_headers( + context: Context, event: RequestHeaders, ) -> CommandGenerator[list[tuple[bytes, bytes]]]: pseudo_headers = [ @@ -326,7 +327,7 @@ def format_h2_request_headers( if event.request.is_http2 or event.request.is_http3: hdrs = list(event.request.headers.fields) - if ctx.options.normalize_outbound_headers: + if context.options.normalize_outbound_headers: yield from normalize_h2_headers(hdrs) else: headers = event.request.headers @@ -339,6 +340,7 @@ def format_h2_request_headers( def format_h2_response_headers( + context: Context, event: ResponseHeaders, ) -> CommandGenerator[list[tuple[bytes, bytes]]]: headers = [ @@ -346,7 +348,7 @@ def format_h2_response_headers( *event.response.headers.fields, ] if event.response.is_http2: - if ctx.options.normalize_outbound_headers: + if context.options.normalize_outbound_headers: yield from normalize_h2_headers(headers) else: headers = normalize_h1_headers(headers, False) @@ -372,7 +374,7 @@ def _handle_event(self, event: Event) -> CommandGenerator[None]: if self.is_open_for_us(event.stream_id): self.h2_conn.send_headers( event.stream_id, - headers=(yield from format_h2_response_headers(event)), + headers=(yield from format_h2_response_headers(self.context, event)), end_stream=event.end_stream, ) yield SendData(self.conn, self.h2_conn.data_to_send()) @@ -517,7 +519,7 @@ def _handle_event2(self, event: Event) -> CommandGenerator[None]: elif isinstance(event, RequestHeaders): self.h2_conn.send_headers( event.stream_id, - headers=(yield from format_h2_request_headers(event)), + headers=(yield from format_h2_request_headers(self.context, event)), end_stream=event.end_stream, ) self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 348fae241e..26916b1800 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -10,7 +10,7 @@ ) from aioquic.h3 import events as h3_events from aioquic.quic import events as quic_events -from aioquic.quic.connection import QuicConnection +from aioquic.quic.connection import QuicConnection, stream_is_unidirectional from mitmproxy import http, version from mitmproxy.net.http import status_codes @@ -91,9 +91,9 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: stream_id=event.stream_id, headers=( yield from ( - format_h2_request_headers(event) + format_h2_request_headers(self.context, event) if isinstance(event, RequestHeaders) - else format_h2_response_headers(event) + else format_h2_response_headers(self.context, event) ) ), end_stream=event.end_stream, @@ -133,24 +133,33 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # report abrupt stream resets if isinstance(event, quic_events.StreamReset): if event.stream_id in self.h3_conn._stream: - # report the protocol error (doing the same error code mingling as H2) - code = ( - status_codes.CLIENT_CLOSED_REQUEST - if event.error_code == H3ErrorCode.H3_REQUEST_CANCELLED - else self.ReceiveProtocolError.code - ) - yield ReceiveHttp( - self.ReceiveProtocolError( - stream_id=event.stream_id, - message=f"stream reset by client ({error_code_to_str(event.error_code)})", - code=code, + stream = self.h3_conn._stream[event.stream_id] + if not stream.ended: + # mark the receiving part of the stream as ended + # (H3Connection alas doesn't handle StreamReset) + stream.ended = True + + # report the protocol error (doing the same error code mingling as H2) + code = ( + status_codes.CLIENT_CLOSED_REQUEST + if event.error_code == H3ErrorCode.H3_REQUEST_CANCELLED + else self.ReceiveProtocolError.code + ) + yield ReceiveHttp( + self.ReceiveProtocolError( + stream_id=event.stream_id, + message=f"stream reset by client ({error_code_to_str(event.error_code)})", + code=code, + ) ) - ) # report a protocol error for all remaining open streams when a connection is terminated elif isinstance(event, quic_events.ConnectionTerminated): for stream in self.h3_conn._stream.values(): - if not stream.ended: + if ( + self.quic._stream_can_receive(stream.stream_id) + and not stream.ended + ): yield ReceiveHttp( self.ReceiveProtocolError( stream_id=stream.stream_id, @@ -344,7 +353,9 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.quic is not None ours = self.our_stream_id.get(event.stream_id, None) if ours is None: - ours = self.quic.get_next_available_stream_id() + ours = self.quic.get_next_available_stream_id( + is_unidirectional=stream_is_unidirectional(event.stream_id) + ) self.our_stream_id[event.stream_id] = ours self.their_stream_id[ours] = event.stream_id event.stream_id = ours From 534bc598337c29d3c004b028bd1c975c4a449715 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 27 Jun 2022 05:14:09 +0200 Subject: [PATCH 030/529] [quic] improve relay stream layer --- mitmproxy/addons/next_layer.py | 7 +- mitmproxy/proxy/layers/http/_http2.py | 2 +- mitmproxy/proxy/layers/http/_http3.py | 16 +- mitmproxy/proxy/layers/quic.py | 410 +++++++++++++++++--------- 4 files changed, 283 insertions(+), 152 deletions(-) diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index b9908b0a24..9b78292253 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -129,11 +129,14 @@ def _next_layer( mode = HTTPMode.upstream else: return None - return layers.HttpLayer(context=context, mode=mode) + return layers.HttpLayer(context, mode) else: if context.server.address is None: return None - return quic.QuicRelayLayer(context) + if isinstance(context.layers[1], quic.ServerQuicLayer): + return quic.QuicRelayLayer(context) + else: + return quic.ServerQuicLayer(context, quic.QuicRelayLayer(context)) if len(context.layers) == 0: return self.make_top_layer(context) diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index cfe2b53de9..785b0e1e36 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -13,7 +13,7 @@ import h2.stream import h2.utilities -from mitmproxy import ctx, http, version +from mitmproxy import http, version from mitmproxy.connection import Connection from mitmproxy.net.http import status_codes, url from mitmproxy.utils import human diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 26916b1800..cb37086591 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -16,11 +16,10 @@ from mitmproxy.net.http import status_codes from mitmproxy.proxy import commands, context, events, layer from mitmproxy.proxy.layers.quic import ( - _QuicLayer, QuicConnectionEvent, - # QuicGetConnection, QuicTransmit, error_code_to_str, + get_quic_connection, ) from mitmproxy.proxy.utils import expect @@ -61,16 +60,7 @@ class Http3Connection(HttpConnection): def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, events.Start): - # this doesn't always work: - # quic = yield QuicGetConnection(self.conn) - # assert isinstance(quic, QuicConnection) - # self.quic = quic - # - # temporary workaround: - for layer_ in self.context.layers: - if isinstance(layer_, _QuicLayer) and layer_.conn is self.conn: - self.quic = layer_.quic - assert self.quic is not None + self.quic = get_quic_connection(self.context, self.conn) self.h3_conn = H3Connection(self.quic, enable_webtransport=False) elif isinstance(event, events.ConnectionClosed): @@ -123,7 +113,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: return # transmit buffered data and re-arm timer - yield QuicTransmit(self.quic) + yield QuicTransmit(self.conn, self.quic) # handle events from the underlying QUIC connection elif isinstance(event, QuicConnectionEvent): diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index ac4111a7b7..e9bc9fe886 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -12,6 +12,7 @@ QuicConnectionError, QuicConnectionState, QuicErrorCode, + stream_is_unidirectional, ) from aioquic.tls import CipherSuite, HandshakeType from aioquic.quic.packet import PACKET_TYPE_INITIAL, pull_quic_header @@ -88,25 +89,27 @@ class QuicTlsStartServerHook(commands.StartHook): @dataclass class QuicConnectionEvent(events.ConnectionEvent): - event: quic_events.QuicEvent - + """ + Connection-based event that is triggered whenever a new event from QUIC is received. + Established connections are guaranteed to receive a ConnectionTerminated event at the end. -class QuicGetConnection(commands.ConnectionCommand): # -> Optional[QuicConnection] - blocking = True + Note: + 'Established' means that an OpenConnection command called in a child layer returned no error. + Without a predefined child layer, the QUIC layer uses NextLayer mechanics to select the child + layer. The moment is asks addons for the child layer, the connection is considered established. + """ + event: quic_events.QuicEvent -@dataclass(repr=False) -class QuicGetConnectionCompleted(events.CommandCompleted): - command: QuicGetConnection - reply: Optional[QuicConnection] +class QuicTransmit(commands.ConnectionCommand): + """Command that will transmit buffered data and re-arm the given QUIC connection's timer.""" -class QuicTransmit(commands.Command): - connection: QuicConnection + quic: QuicConnection - def __init__(self, connection: QuicConnection) -> None: - super().__init__() - self.connection = connection + def __init__(self, connection: connection.Connection, quic: QuicConnection) -> None: + super().__init__(connection) + self.quic = quic class QuicSecretsLogger: @@ -128,12 +131,9 @@ def flush(self) -> None: pass -@dataclass -class QuicClientHello(Exception): - data: bytes - - def error_code_to_str(error_code: int) -> str: + """Returns the corresponding name of the given error code or a string containing its numeric value.""" + try: return H3ErrorCode(error_code).name except ValueError: @@ -143,7 +143,37 @@ def error_code_to_str(error_code: int) -> str: return f"unknown error (0x{error_code:x})" +def get_quic_connection( + context: context.Context, connection: connection.Connection +) -> QuicConnection: + """Retrieve the QUIC connection associated with the given connection in the given context.""" + + for quic_layer in context.layers: + if isinstance(quic_layer, _QuicLayer) and quic_layer.conn is connection: + if not quic_layer.conn.tls_established: + raise ValueError( + f"QUIC on connection {connection} has not been established yet." + ) + return quic_layer.quic + raise ValueError(f"Connection {connection} has no QUIC.") + + +def is_success_error_code(error_code: int) -> bool: + """Returns whether the given error code actually indicates no error.""" + + return error_code in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR) + + +@dataclass +class QuicClientHello(Exception): + """Helper error only used in `pull_client_hello_and_connection_id`.""" + + data: bytes + + def pull_client_hello_and_connection_id(data: bytes) -> Tuple[ClientHello, bytes]: + """Helper function that parses a client hello packet.""" + # ensure the first packet is indeed the initial one buffer = QuicBuffer(data=data) header = pull_quic_header(buffer) @@ -193,133 +223,249 @@ def initialize_replacement(peer_cid: bytes) -> None: raise ValueError("No ClientHello returned.") +@dataclass +class QuicRelayStream: + client_ended: bool + client_id: int + flow: tcp.TCPFlow + server_ended: bool + server_id: int + + def stream_id(self, client: bool) -> int: + return self.client_id if client else self.server_id + + def has_ended(self, client: bool) -> bool: + stream_ended = self.client_ended if client else self.server_ended + return stream_ended or not self.flow.live + + class QuicRelayLayer(layer.Layer): + """ + Layer on top of `ClientQuicLayer` and `ServerQuicLayer`, that simply relays all QUIC streams and datagrams. + This layer is chosen by the default NextLayer addon if ALPN yields no known protocol. + """ + # for now we're (ab)using the TCPFlow until https://github.com/mitmproxy/mitmproxy/pull/5414 is resolved - datagram_flow: Optional[tcp.TCPFlow] = None - lookup_server: Dict[int, Tuple[int, tcp.TCPFlow]] - lookup_client: Dict[int, Tuple[int, tcp.TCPFlow]] - quic_server: Optional[QuicConnection] = None + flow: tcp.TCPFlow # used for datagrams and to signal general connection issues + streams_by_flow: Dict[tcp.TCPFlow, QuicRelayStream] + streams_by_client_id: Dict[int, QuicRelayStream] + streams_by_server_id: Dict[int, QuicRelayStream] quic_client: Optional[QuicConnection] = None + quic_server: Optional[QuicConnection] = None def __init__(self, context: context.Context) -> None: super().__init__(context) - self.lookup_server = {} - self.lookup_client = {} - - def end_flow( - self, flow: tcp.TCPFlow, event: quic_events.ConnectionTerminated - ) -> layer.CommandGenerator[None]: - if event.error_code == QuicErrorCode.NO_ERROR: - yield tcp_layer.TcpEndHook(flow) + self.flow = tcp.TCPFlow( + self.context.client, + self.context.server, + live=True, + ) + self.streams_by_flow = {} + self.streams_by_client_id = {} + self.streams_by_server_id = {} + + def get_or_create_stream( + self, stream_id: int, from_client: bool + ) -> layer.CommandGenerator[QuicRelayStream]: + streams_by_id = ( + self.streams_by_client_id if from_client else self.streams_by_server_id + ) + if stream_id in streams_by_id: + return streams_by_id[stream_id] else: - flow.error = mitm_flow.Error(event.reason_phrase) - yield tcp_layer.TcpErrorHook(flow) - flow.live = False - - def get_quic( - self, conn: connection.Connection - ) -> layer.CommandGenerator[QuicConnection]: - quic = yield QuicGetConnection(conn) - assert isinstance(quic, QuicConnection) - return quic - - def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, events.Start): - self.quic_server = yield from self.get_quic(self.context.server) - self.quic_client = yield from self.get_quic(self.context.client) - - elif isinstance(event, QuicConnectionEvent): - assert self.quic_server is not None - assert self.quic_client is not None + # reserve the peer stream id + is_unidirectional = stream_is_unidirectional(stream_id) + peer_quic = self.quic_server if from_client else self.quic_client + assert peer_quic + peer_stream_id = peer_quic.get_next_available_stream_id(is_unidirectional) + + # create the instance and make sure unidirectional streams are marked as ended + stream = QuicRelayStream( + flow=tcp.TCPFlow( + self.context.client, + self.context.server, + live=True, + ), + client_ended=is_unidirectional and not from_client, + server_ended=is_unidirectional and from_client, + client_id=stream_id if from_client else peer_stream_id, + server_id=peer_stream_id if from_client else stream_id, + ) + + # register the stream and start the flow + self.streams_by_flow[stream.flow] = stream + self.streams_by_client_id[stream.client_id] = stream + self.streams_by_server_id[stream.server_id] = stream + yield tcp_layer.TcpStartHook(stream.flow) + return stream + + @expect(events.Start) + def state_start(self, _) -> layer.CommandGenerator[None]: + # retrieve the client QUIC connection and mark the main flow as started + self.quic_client = get_quic_connection(self.context, self.context.client) + yield tcp_layer.TcpStartHook(self.flow) + + # open the upstream connection if necessary + if self.context.server.timestamp_start is None: + err = yield commands.OpenConnection(self.context.server) + if err: + self.flow.error = mitm_flow.Error(str(err)) + yield tcp_layer.TcpErrorHook(self.flow) + self.flow.live = False + yield commands.CloseConnection(self.context.client) + self._handle_event = self.state_done + return + self.quic_server = get_quic_connection(self.context, self.context.server) + self._handle_event = self.state_ready + + @expect(QuicConnectionEvent, tcp_layer.TcpMessageInjected) + def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: + + if isinstance(event, tcp_layer.TcpMessageInjected): + # translate injected messages into QUIC events + flow = event.flow + assert isinstance(flow, tcp.TCPFlow) + connection = ( + self.context.client + if event.message.from_client + else self.context.server + ) + if flow is self.flow: + event = QuicConnectionEvent( + connection, quic_events.DatagramFrameReceived(event.message.content) + ) + elif flow in self.streams_by_flow: + stream = self.streams_by_flow[flow] + event = QuicConnectionEvent( + connection, + quic_events.StreamDataReceived( + stream_id=( + stream.client_id + if event.message.from_client + else stream.server_id + ), + data=event.message.content, + end_stream=False, + ), + ) + else: + # only handle messages of known flows + return + if isinstance(event, QuicConnectionEvent): + # define helper variables quic_event = event.event from_client = event.connection is self.context.client - lookup_in = self.lookup_client if from_client else self.lookup_server - lookup_out = self.lookup_server if from_client else self.lookup_client - # quic_in = self.quic_client if from_client else self.quic_server - quic_out = self.quic_server if from_client else self.quic_client + peer_connection = ( + self.context.server if from_client else self.context.client + ) + peer_quic = self.quic_server if from_client else self.quic_client + assert peer_quic is not None - # forward close and end all flows if isinstance(quic_event, quic_events.ConnectionTerminated): - quic_out.close( + # report the termination as error to all non-ended streams + for flow in self.streams_by_flow: + if flow.live: + self.flow.error = mitm_flow.Error( + "Connection terminated " + f" (code={quic_event.error_code}, reason={quic_event.reason_phrase})." + ) + yield tcp_layer.TcpErrorHook(flow) + flow.live = False + + # end the main flow + if self.flow.live: + if is_success_error_code(quic_event.error_code): + yield tcp_layer.TcpEndHook(flow) + else: + self.flow.error = mitm_flow.Error( + quic_event.reason_phrase + or error_code_to_str(quic_event.error_code) + ) + yield tcp_layer.TcpErrorHook(flow) + self.flow.live = False + + # close the peer as well and don't handle further events + peer_quic.close( quic_event.error_code, quic_event.frame_type, quic_event.reason_phrase, ) - while lookup_in: - stream_id_in = next(iter(lookup_in)) - stream_id_out, flow = lookup_in[stream_id_in] - yield from self.end_flow(flow, quic_event) - del lookup_in[stream_id_in] - del lookup_out[stream_id_out] - - if self.datagram_flow is not None: - yield from self.end_flow(flow, quic_event) - self.datagram_flow = None - - # forward datagrams (that are not stream-bound) + self._handle_event = self.state_done + elif isinstance(quic_event, quic_events.DatagramFrameReceived): - if self.datagram_flow is None: - self.datagram_flow = tcp.TCPFlow( - self.context.client, - self.context.server, - live=True, - ) - yield tcp_layer.TcpStartHook(self.datagram_flow) + # forward datagrams (that are not stream-bound) + if not self.flow.live: + return message = tcp.TCPMessage(from_client, quic_event.data) - self.datagram_flow.messages.append(message) - yield tcp_layer.TcpMessageHook(self.datagram_flow) - quic_out.send_datagram_frame(message.content) + self.flow.messages.append(message) + yield tcp_layer.TcpMessageHook(self.flow) + peer_quic.send_datagram_frame(message.content) - # forward stream data elif isinstance(quic_event, quic_events.StreamDataReceived): - # get or create the stream on the other side (and flow) - stream_id_in = quic_event.stream_id - if stream_id_in in lookup_in: - stream_id_out, flow = lookup_in[stream_id_in] - else: - stream_id_out = quic_out.get_next_available_stream_id() - flow = tcp.TCPFlow( - self.context.client, - self.context.server, - live=True, - ) - lookup_in[stream_id_in] = (stream_id_out, flow) - lookup_out[stream_id_out] = (stream_id_in, flow) - yield tcp_layer.TcpStartHook(flow) + # ignore data received from already ended streams + stream = yield from self.get_or_create_stream( + quic_event.stream_id, from_client + ) + if stream.has_ended(from_client): + return # forward the message allowing addons to change it message = tcp.TCPMessage(from_client, quic_event.data) - flow.messages.append(message) - yield tcp_layer.TcpMessageHook(flow) - quic_out.send_stream_data( - stream_id_out, - message.content, - quic_event.end_stream, + stream.flow.messages.append(message) + yield tcp_layer.TcpMessageHook(stream.flow) + peer_quic.send_stream_data( + stream_id=stream.stream_id(not from_client), + data=message.content, + end_stream=quic_event.end_stream, ) - # end the flow and remove the lookup if the stream ended + # mark the stream as ended if needed if quic_event.end_stream: - yield tcp_layer.TcpEndHook(flow) - flow.live = False - del lookup_in[stream_id_in] - del lookup_out[stream_id_out] + if from_client: + stream.client_ended = True + else: + stream.server_ended = True + + # end the flow if both legs ended + if stream.client_ended and stream.server_ended: + yield tcp_layer.TcpEndHook(stream.flow) + stream.flow.live = False - # forward resets to peer streams elif isinstance(quic_event, quic_events.StreamReset): - stream_id_in = quic_event.stream_id - if stream_id_in in lookup_in: - stream_id_out, flow = lookup_in[stream_id_in] - quic_out.stop_stream(stream_id_out, quic_event.error_code) - - # report the error to addons and delete the stream - flow.error = mitm_flow.Error( - error_code_to_str(quic_event.error_code) - ) - yield tcp_layer.TcpErrorHook(flow) - flow.live = False - del lookup_in[stream_id_in] - del lookup_out[stream_id_out] + # ignore resets from already ended streams + stream = yield from self.get_or_create_stream( + quic_event.stream_id, from_client + ) + if stream.has_ended(from_client): + return + + # forward resets to peer streams and report them to addons + peer_quic.reset_stream( + stream_id=stream.stream_id(not from_client), + error_code=quic_event.error_code, + ) + stream.flow.error = mitm_flow.Error( + error_code_to_str(quic_event.error_code) + ) + yield tcp_layer.TcpErrorHook(stream.flow) + stream.flow.live = False + + else: + # ignore other QUIC events + return + + # transmit data to the peer + yield QuicTransmit(peer_connection, peer_quic) + + else: + raise AssertionError(f"Unexpected event: {event!r}") + + @expect(QuicConnectionEvent, tcp_layer.TcpMessageInjected, events.ConnectionClosed) + def state_done(self, _) -> layer.CommandGenerator[None]: + yield from () + + _handle_event = state_start class _QuicLayer(layer.Layer): @@ -490,19 +636,10 @@ def handle_child_commands( # filter commands coming from the child layer for command in child_commands: - if ( - isinstance(command, QuicGetConnection) - and command.connection is self.conn - ): - # answer with the aioquic connection instance - assert self.quic is not None - yield from self.event_to_child( - QuicGetConnectionCompleted(command, self.quic) - ) - - elif isinstance(command, QuicTransmit) and command.connection is self.quic: + if isinstance(command, QuicTransmit) and command.connection is self.conn: # transmit buffered data and re-arm timer - yield from self.transmit() + if command.quic is self.quic: + yield from self.transmit() elif ( isinstance(command, commands.OpenConnection) @@ -589,10 +726,7 @@ def process_events(self) -> layer.CommandGenerator[None]: yield from self.destroy_quic( event.reason_phrase or error_code_to_str(event.error_code), level=( - "info" - if event.error_code - in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR) - else "warn" + "info" if is_success_error_code(event.error_code) else "warn" ), ) yield commands.CloseConnection(self.conn) @@ -630,7 +764,7 @@ def process_events(self) -> layer.CommandGenerator[None]: yield from self.event_to_child(QuicConnectionEvent(self.conn, event)) else: - raise AssertionError(f"Unexpected event: {event}") + raise AssertionError(f"Unexpected event: {event!r}") # handle the next event event = self.quic.next_event() @@ -761,8 +895,12 @@ class ServerQuicLayer(_QuicLayer): This layer establishes QUIC for a single server connection. """ - def __init__(self, context: context.Context) -> None: + def __init__( + self, context: context.Context, child_layer: Optional[layer.Layer] = None + ) -> None: super().__init__(context, context.server) + if child_layer is not None: + self.child_layer = child_layer class ClientQuicLayer(_QuicLayer): @@ -895,6 +1033,6 @@ def state_wait_for_hello(self, event: events.Event) -> layer.CommandGenerator[No self._handle_event = self.state_done else: - raise AssertionError(f"Unexpected event: {event}") + raise AssertionError(f"Unexpected event: {event!r}") _handle_event = state_start From 569faf41d0f20257a96bdf1c05d3e4b132b9e883 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 27 Jun 2022 18:09:14 +0200 Subject: [PATCH 031/529] [quic] introduce QuicStart event --- mitmproxy/proxy/layers/http/_http3.py | 7 +- mitmproxy/proxy/layers/quic.py | 376 +++++++++++++++----------- 2 files changed, 219 insertions(+), 164 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index cb37086591..86689f726d 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -17,9 +17,9 @@ from mitmproxy.proxy import commands, context, events, layer from mitmproxy.proxy.layers.quic import ( QuicConnectionEvent, + QuicStart, QuicTransmit, error_code_to_str, - get_quic_connection, ) from mitmproxy.proxy.utils import expect @@ -60,7 +60,10 @@ class Http3Connection(HttpConnection): def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, events.Start): - self.quic = get_quic_connection(self.context, self.conn) + pass + + elif isinstance(event, QuicStart): + self.quic = event.quic self.h3_conn = H3Connection(self.quic, enable_webtransport=False) elif isinstance(event, events.ConnectionClosed): diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index e9bc9fe886..699cbbfdfe 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -102,6 +102,19 @@ class QuicConnectionEvent(events.ConnectionEvent): event: quic_events.QuicEvent +class QuicStart(events.DataReceived): + """ + Event that indicates that QUIC has been established on a given connection. + This inherits from `DataReceived` in order to trigger next layer behavior and initialize HTTP clients. + """ + + quic: QuicConnection + + def __init__(self, connection: connection.Connection, quic: QuicConnection) -> None: + super().__init__(connection, data=b"") + self.quic = quic + + class QuicTransmit(commands.ConnectionCommand): """Command that will transmit buffered data and re-arm the given QUIC connection's timer.""" @@ -143,21 +156,6 @@ def error_code_to_str(error_code: int) -> str: return f"unknown error (0x{error_code:x})" -def get_quic_connection( - context: context.Context, connection: connection.Connection -) -> QuicConnection: - """Retrieve the QUIC connection associated with the given connection in the given context.""" - - for quic_layer in context.layers: - if isinstance(quic_layer, _QuicLayer) and quic_layer.conn is connection: - if not quic_layer.conn.tls_established: - raise ValueError( - f"QUIC on connection {connection} has not been established yet." - ) - return quic_layer.quic - raise ValueError(f"Connection {connection} has no QUIC.") - - def is_success_error_code(error_code: int) -> bool: """Returns whether the given error code actually indicates no error.""" @@ -245,16 +243,21 @@ class QuicRelayLayer(layer.Layer): This layer is chosen by the default NextLayer addon if ALPN yields no known protocol. """ - # for now we're (ab)using the TCPFlow until https://github.com/mitmproxy/mitmproxy/pull/5414 is resolved + # NOTE: for now we're (ab)using the TCPFlow until https://github.com/mitmproxy/mitmproxy/pull/5414 is resolved + + buffer_from_client: List[quic_events.QuicEvent] + buffer_from_server: List[quic_events.QuicEvent] flow: tcp.TCPFlow # used for datagrams and to signal general connection issues + quic_client: Optional[QuicConnection] = None + quic_server: Optional[QuicConnection] = None streams_by_flow: Dict[tcp.TCPFlow, QuicRelayStream] streams_by_client_id: Dict[int, QuicRelayStream] streams_by_server_id: Dict[int, QuicRelayStream] - quic_client: Optional[QuicConnection] = None - quic_server: Optional[QuicConnection] = None def __init__(self, context: context.Context) -> None: super().__init__(context) + self.buffer_from_client = [] + self.buffer_from_server = [] self.flow = tcp.TCPFlow( self.context.client, self.context.server, @@ -299,10 +302,119 @@ def get_or_create_stream( yield tcp_layer.TcpStartHook(stream.flow) return stream + def handle_quic_event( + self, + event: quic_events.QuicEvent, + from_client: bool, + allow_buffering: bool, + ) -> layer.CommandGenerator[None]: + # get the peer connections + peer_connection = self.context.server if from_client else self.context.client + peer_quic = self.quic_server if from_client else self.quic_client + if peer_quic is None: + # buffer events since the peer is not ready yet + if not allow_buffering: + raise AssertionError( + f"Cannot buffer event from {'client' if from_client else 'server'}." + ) + if from_client: + self.buffer_from_client.append(event) + else: + self.buffer_from_server.append(event) + return + + if isinstance(event, quic_events.ConnectionTerminated): + # report the termination as error to all non-ended streams + for flow in self.streams_by_flow: + if flow.live: + self.flow.error = mitm_flow.Error( + "Connection terminated " + f" (code={event.error_code}, reason={event.reason_phrase})." + ) + yield tcp_layer.TcpErrorHook(flow) + flow.live = False + + # end the main flow + if self.flow.live: + if is_success_error_code(event.error_code): + yield tcp_layer.TcpEndHook(flow) + else: + self.flow.error = mitm_flow.Error( + event.reason_phrase or error_code_to_str(event.error_code) + ) + yield tcp_layer.TcpErrorHook(flow) + self.flow.live = False + + # close the peer as well and don't handle further events + peer_quic.close( + event.error_code, + event.frame_type, + event.reason_phrase, + ) + self._handle_event = self.state_done + + elif isinstance(event, quic_events.DatagramFrameReceived): + # forward datagrams (that are not stream-bound) + if not self.flow.live: + return + message = tcp.TCPMessage(from_client, event.data) + self.flow.messages.append(message) + yield tcp_layer.TcpMessageHook(self.flow) + peer_quic.send_datagram_frame(message.content) + + elif isinstance(event, quic_events.StreamDataReceived): + # ignore data received from already ended streams + stream = yield from self.get_or_create_stream(event.stream_id, from_client) + if stream.has_ended(from_client): + return + + # forward the message allowing addons to change it + message = tcp.TCPMessage(from_client, event.data) + stream.flow.messages.append(message) + yield tcp_layer.TcpMessageHook(stream.flow) + peer_quic.send_stream_data( + stream_id=stream.stream_id(not from_client), + data=message.content, + end_stream=event.end_stream, + ) + + # mark the stream as ended if needed + if event.end_stream: + if from_client: + stream.client_ended = True + else: + stream.server_ended = True + + # end the flow if both legs ended + if stream.client_ended and stream.server_ended: + yield tcp_layer.TcpEndHook(stream.flow) + stream.flow.live = False + + elif isinstance(event, quic_events.StreamReset): + # ignore resets from already ended streams + stream = yield from self.get_or_create_stream(event.stream_id, from_client) + if stream.has_ended(from_client): + return + + # forward resets to peer streams and report them to addons + peer_quic.reset_stream( + stream_id=stream.stream_id(not from_client), + error_code=event.error_code, + ) + stream.flow.error = mitm_flow.Error(error_code_to_str(event.error_code)) + yield tcp_layer.TcpErrorHook(stream.flow) + stream.flow.live = False + + else: + # ignore other QUIC events + return + + # transmit data to the peer + yield QuicTransmit(peer_connection, peer_quic) + @expect(events.Start) def state_start(self, _) -> layer.CommandGenerator[None]: - # retrieve the client QUIC connection and mark the main flow as started - self.quic_client = get_quic_connection(self.context, self.context.client) + # mark the main flow as started yield tcp_layer.TcpStartHook(self.flow) # open the upstream connection if necessary @@ -315,29 +427,49 @@ def state_start(self, _) -> layer.CommandGenerator[None]: yield commands.CloseConnection(self.context.client) self._handle_event = self.state_done return - self.quic_server = get_quic_connection(self.context, self.context.server) self._handle_event = self.state_ready - @expect(QuicConnectionEvent, tcp_layer.TcpMessageInjected) + @expect(QuicStart, QuicConnectionEvent, tcp_layer.TcpMessageInjected) def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, tcp_layer.TcpMessageInjected): + if isinstance(event, QuicStart): + # QUIC connection has been established, store it and flush buffered events + if event.connection is self.context.client: + assert self.quic_client is None + self.quic_client = event.quic + for quic_event in self.buffer_from_server: + yield from self.handle_quic_event( + quic_event, + from_client=False, + allow_buffering=False, + ) + elif event.connection is self.context.server: + assert self.quic_server is None + self.quic_server = event.quic + for quic_event in self.buffer_from_client: + yield from self.handle_quic_event( + quic_event, + from_client=True, + allow_buffering=False, + ) + else: + raise AssertionError( + f"Connection {event.connection} not associated with layer." + ) + + elif isinstance(event, tcp_layer.TcpMessageInjected): # translate injected messages into QUIC events flow = event.flow assert isinstance(flow, tcp.TCPFlow) - connection = ( - self.context.client - if event.message.from_client - else self.context.server - ) if flow is self.flow: - event = QuicConnectionEvent( - connection, quic_events.DatagramFrameReceived(event.message.content) + yield from self.handle_quic_event( + quic_events.DatagramFrameReceived(event.message.content), + event.message.from_client, + allow_buffering=True, ) elif flow in self.streams_by_flow: stream = self.streams_by_flow[flow] - event = QuicConnectionEvent( - connection, + yield from self.handle_quic_event( quic_events.StreamDataReceived( stream_id=( stream.client_id @@ -347,121 +479,31 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: data=event.message.content, end_stream=False, ), + event.message.from_client, + allow_buffering=True, ) else: - # only handle messages of known flows - return - - if isinstance(event, QuicConnectionEvent): - # define helper variables - quic_event = event.event - from_client = event.connection is self.context.client - peer_connection = ( - self.context.server if from_client else self.context.client - ) - peer_quic = self.quic_server if from_client else self.quic_client - assert peer_quic is not None - - if isinstance(quic_event, quic_events.ConnectionTerminated): - # report the termination as error to all non-ended streams - for flow in self.streams_by_flow: - if flow.live: - self.flow.error = mitm_flow.Error( - "Connection terminated " - f" (code={quic_event.error_code}, reason={quic_event.reason_phrase})." - ) - yield tcp_layer.TcpErrorHook(flow) - flow.live = False - - # end the main flow - if self.flow.live: - if is_success_error_code(quic_event.error_code): - yield tcp_layer.TcpEndHook(flow) - else: - self.flow.error = mitm_flow.Error( - quic_event.reason_phrase - or error_code_to_str(quic_event.error_code) - ) - yield tcp_layer.TcpErrorHook(flow) - self.flow.live = False - - # close the peer as well and don't handle further events - peer_quic.close( - quic_event.error_code, - quic_event.frame_type, - quic_event.reason_phrase, + raise AssertionError( + f"Flow {event.flow} not associated with the current layer." ) - self._handle_event = self.state_done - elif isinstance(quic_event, quic_events.DatagramFrameReceived): - # forward datagrams (that are not stream-bound) - if not self.flow.live: - return - message = tcp.TCPMessage(from_client, quic_event.data) - self.flow.messages.append(message) - yield tcp_layer.TcpMessageHook(self.flow) - peer_quic.send_datagram_frame(message.content) - - elif isinstance(quic_event, quic_events.StreamDataReceived): - # ignore data received from already ended streams - stream = yield from self.get_or_create_stream( - quic_event.stream_id, from_client - ) - if stream.has_ended(from_client): - return - - # forward the message allowing addons to change it - message = tcp.TCPMessage(from_client, quic_event.data) - stream.flow.messages.append(message) - yield tcp_layer.TcpMessageHook(stream.flow) - peer_quic.send_stream_data( - stream_id=stream.stream_id(not from_client), - data=message.content, - end_stream=quic_event.end_stream, - ) - - # mark the stream as ended if needed - if quic_event.end_stream: - if from_client: - stream.client_ended = True - else: - stream.server_ended = True - - # end the flow if both legs ended - if stream.client_ended and stream.server_ended: - yield tcp_layer.TcpEndHook(stream.flow) - stream.flow.live = False - - elif isinstance(quic_event, quic_events.StreamReset): - # ignore resets from already ended streams - stream = yield from self.get_or_create_stream( - quic_event.stream_id, from_client - ) - if stream.has_ended(from_client): - return - - # forward resets to peer streams and report them to addons - peer_quic.reset_stream( - stream_id=stream.stream_id(not from_client), - error_code=quic_event.error_code, - ) - stream.flow.error = mitm_flow.Error( - error_code_to_str(quic_event.error_code) - ) - yield tcp_layer.TcpErrorHook(stream.flow) - stream.flow.live = False - - else: - # ignore other QUIC events - return - - # transmit data to the peer - yield QuicTransmit(peer_connection, peer_quic) + elif isinstance(event, QuicConnectionEvent): + # handle or buffer QUIC events + yield from self.handle_quic_event( + event.event, + from_client=event.connection is self.context.client, + allow_buffering=True, + ) else: raise AssertionError(f"Unexpected event: {event!r}") - @expect(QuicConnectionEvent, tcp_layer.TcpMessageInjected, events.ConnectionClosed) + @expect( + QuicStart, + QuicConnectionEvent, + tcp_layer.TcpMessageInjected, + events.ConnectionClosed, + ) def state_done(self, _) -> layer.CommandGenerator[None]: yield from () @@ -491,6 +533,7 @@ def __init__( Tuple[commands.RequestWakeup, float] ] = None self._obsolete_wakeup_commands: Set[commands.RequestWakeup] = set() + self.conn.tls = True def build_configuration(self) -> QuicConfiguration: assert self.tls is not None @@ -628,13 +671,8 @@ def establish_quic( ) def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: - yield from self.handle_child_commands(self.child_layer.handle_event(event)) - - def handle_child_commands( - self, child_commands: layer.CommandGenerator[None] - ) -> layer.CommandGenerator[None]: # filter commands coming from the child layer - for command in child_commands: + for command in self.child_layer.handle_event(event): if isinstance(command, QuicTransmit) and command.connection is self.conn: # transmit buffered data and re-arm timer @@ -729,6 +767,21 @@ def process_events(self) -> layer.CommandGenerator[None]: "info" if is_success_error_code(event.error_code) else "warn" ), ) + if self.conn is self.context.client: + # once the client connection is closed, all servers are terminated immediately + # use this opportunity to properly shutdown QUIC connections + for quic_layer in self.context.layers: + if ( + isinstance(quic_layer, _QuicLayer) + and quic_layer.conn is not self.context.client + and quic_layer.quic is not None + ): + quic_layer.quic.close( + event.error_code, + event.frame_type, + event.reason_phrase, + ) + yield from quic_layer.transmit() yield commands.CloseConnection(self.conn) # we don't handle any further events, nor do/can we transmit data, so exit @@ -738,10 +791,7 @@ def process_events(self) -> layer.CommandGenerator[None]: # set all TLS fields and notify the child layer yield from self.establish_quic(event) yield from self.open_connection_end(None) - - # perform next layer decisions now - if isinstance(self.child_layer, layer.NextLayer): - yield from self.handle_child_commands(self.child_layer._ask()) + yield from self.event_to_child(QuicStart(self.conn, self.quic)) elif isinstance(event, quic_events.PingAcknowledged): # we let aioquic do it's thing but don't really care ourselves @@ -789,27 +839,27 @@ def state_has_quic(self, event: events.Event) -> layer.CommandGenerator[None]: elif ( isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): - # handle connections closed by peer (which in UDP's case is usually a timeout) - reason = "Peer UDP connection closed or timed out." - # there is no point in calling quic.close, as it cannot send packets anymore # set the new connection state and simulate a ConnectionTerminated event (if established) + close_event = self.quic._close_event + if close_event is None: + close_event = quic_events.ConnectionTerminated( + error_code=QuicErrorCode.APPLICATION_ERROR, + frame_type=None, + reason_phrase="Peer UDP connection closed or timed out.", + ) self.quic._set_state(QuicConnectionState.TERMINATED) if self.conn.tls_established: yield from self.event_to_child( - QuicConnectionEvent( - self.conn, - quic_events.ConnectionTerminated( - error_code=QuicErrorCode.APPLICATION_ERROR, - frame_type=None, - reason_phrase=reason, - ), - ) + QuicConnectionEvent(self.conn, close_event) ) # shutdown QUIC and handle the ConnectionClosed event - yield from self.destroy_quic(reason, level="info") - if not (yield from self.open_connection_end(reason)): + yield from self.destroy_quic( + close_event.reason_phrase or error_code_to_str(close_event.error_code), + level="info", + ) + if not (yield from self.open_connection_end(close_event.reason_phrase)): # connection was opened before QUIC layer, report to the child layer yield from self.event_to_child(event) @@ -955,6 +1005,8 @@ def __init__( super().__init__(context) self._issue_cid = issue_cid self._retire_cid = retire_cid + self.context.client.tls = True + self.context.server.tls = True def build_client_layer( self, connection_id: bytes, wait_for_upstream: bool From 8e71b0331b8de95c4204d5cc26fb07e967883972 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 27 Jun 2022 19:30:59 +0200 Subject: [PATCH 032/529] [quic] add is_http3 where necessary --- mitmproxy/addons/dumper.py | 2 +- mitmproxy/addons/next_layer.py | 12 +++++------- mitmproxy/http.py | 8 ++++---- mitmproxy/proxy/layers/http/__init__.py | 2 +- mitmproxy/proxy/layers/http/_http1.py | 6 +++--- mitmproxy/proxy/layers/http/_http2.py | 2 +- 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 7da08ecf4c..a35a350565 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -200,7 +200,7 @@ def _echo_response_line(self, flow: http.HTTPFlow) -> None: blink=(code_int == 418), ) - if not flow.response.is_http2: + if not (flow.response.is_http2 or flow.response.is_http3): reason = flow.response.reason else: reason = http.status_codes.RESPONSES.get(flow.response.status_code, "") diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index 9b78292253..f067a48ec7 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -130,13 +130,11 @@ def _next_layer( else: return None return layers.HttpLayer(context, mode) - else: - if context.server.address is None: - return None - if isinstance(context.layers[1], quic.ServerQuicLayer): - return quic.QuicRelayLayer(context) - else: - return quic.ServerQuicLayer(context, quic.QuicRelayLayer(context)) + if context.server.address is None: + return None # not H3 and no predefined destination, nothing we can do + if isinstance(context.layers[1], quic.ServerQuicLayer): + return quic.QuicRelayLayer(context) # server layer already present + return quic.ServerQuicLayer(context, quic.QuicRelayLayer(context)) if len(context.layers) == 0: return self.make_top_layer(context) diff --git a/mitmproxy/http.py b/mitmproxy/http.py index 394cddbdd2..151dcdccb9 100644 --- a/mitmproxy/http.py +++ b/mitmproxy/http.py @@ -767,7 +767,7 @@ def host_header(self) -> Optional[str]: *See also:* `Request.authority`,`Request.host`, `Request.pretty_host` """ - if self.is_http2: + if self.is_http2 or self.is_http3: return self.authority or self.data.headers.get("Host", None) else: return self.data.headers.get("Host", None) @@ -775,13 +775,13 @@ def host_header(self) -> Optional[str]: @host_header.setter def host_header(self, val: Union[None, str, bytes]) -> None: if val is None: - if self.is_http2: + if self.is_http2 or self.is_http3: self.data.authority = b"" self.headers.pop("Host", None) else: - if self.is_http2: + if self.is_http2 or self.is_http3: self.authority = val # type: ignore - if not self.is_http2 or "Host" in self.headers: + if not (self.is_http2 or self.is_http3) or "Host" in self.headers: # For h2, we only overwrite, but not create, as :authority is the h2 host header. self.headers["Host"] = val diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 6b887269f1..69d98255a4 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -220,7 +220,7 @@ def state_wait_for_request_headers( "https" if self.context.client.tls else "http" ) - if self.mode is HTTPMode.regular and not self.flow.request.is_http2: + if self.mode is HTTPMode.regular and not (self.flow.request.is_http2 or self.flow.request.is_http3): # Set the request target to origin-form for HTTP/1, some servers don't support absolute-form requests. # see https://github.com/mitmproxy/mitmproxy/issues/1759 self.flow.request.authority = "" diff --git a/mitmproxy/proxy/layers/http/_http1.py b/mitmproxy/proxy/layers/http/_http1.py index 4cab5bd9b6..fa49fc3f23 100644 --- a/mitmproxy/proxy/layers/http/_http1.py +++ b/mitmproxy/proxy/layers/http/_http1.py @@ -189,7 +189,7 @@ def mark_done( # If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request. # This simplifies our connection management quite a bit as we can rely on # the proxyserver's max-connection-per-server throttling. - or (self.request.is_http2 and isinstance(self, Http1Client)) + or ((self.request.is_http2 or self.request.is_http3) and isinstance(self, Http1Client)) ) if connection_done: yield commands.CloseConnection(self.conn) @@ -223,7 +223,7 @@ def send(self, event: HttpEvent) -> layer.CommandGenerator[None]: if isinstance(event, ResponseHeaders): self.response = response = event.response - if response.is_http2: + if response.is_http2 or response.is_http3: response = response.copy() # Convert to an HTTP/1 response. response.http_version = "HTTP/1.1" @@ -340,7 +340,7 @@ def send(self, event: HttpEvent) -> layer.CommandGenerator[None]: if isinstance(event, RequestHeaders): request = event.request - if request.is_http2: + if request.is_http2 or request.is_http3: # Convert to an HTTP/1 request. request = ( request.copy() diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index 785b0e1e36..3307d9cfb5 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -347,7 +347,7 @@ def format_h2_response_headers( (b":status", b"%d" % event.response.status_code), *event.response.headers.fields, ] - if event.response.is_http2: + if event.response.is_http2 or event.response.is_http3: if context.options.normalize_outbound_headers: yield from normalize_h2_headers(headers) else: From 2a221ad9b7c376135386af4eecf444182ae0322f Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 27 Jun 2022 19:54:48 +0200 Subject: [PATCH 033/529] [quic] reworked close handling --- mitmproxy/proxy/layers/http/_http3.py | 66 ++++++------ mitmproxy/proxy/layers/quic.py | 138 ++++++++++++-------------- 2 files changed, 100 insertions(+), 104 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 86689f726d..14b8c0b7f4 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -58,21 +58,26 @@ class Http3Connection(HttpConnection): ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]] - def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, events.Start): - pass + @expect(events.Start) + def state_start(self, _) -> layer.CommandGenerator[None]: + self._handle_event = self.state_wait_for_quic + yield from () - elif isinstance(event, QuicStart): - self.quic = event.quic - self.h3_conn = H3Connection(self.quic, enable_webtransport=False) + @expect(QuicStart) + def state_wait_for_quic(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, QuicStart) + self.quic = event.quic + self.h3_conn = H3Connection(self.quic, enable_webtransport=False) + self._handle_event = self.state_ready + yield from () - elif isinstance(event, events.ConnectionClosed): - self._handle_event = self.done # type: ignore + @expect(HttpEvent, QuicConnectionEvent, events.ConnectionClosed) + def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: + assert self.quic is not None + assert self.h3_conn is not None # send mitmproxy HTTP events over the H3 connection - elif isinstance(event, HttpEvent): - assert self.quic is not None - assert self.h3_conn is not None + if isinstance(event, HttpEvent): try: if isinstance(event, (RequestData, ResponseData)): @@ -120,8 +125,6 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: # handle events from the underlying QUIC connection elif isinstance(event, QuicConnectionEvent): - assert self.quic is not None - assert self.h3_conn is not None # report abrupt stream resets if isinstance(event, quic_events.StreamReset): @@ -146,21 +149,6 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: ) ) - # report a protocol error for all remaining open streams when a connection is terminated - elif isinstance(event, quic_events.ConnectionTerminated): - for stream in self.h3_conn._stream.values(): - if ( - self.quic._stream_can_receive(stream.stream_id) - and not stream.ended - ): - yield ReceiveHttp( - self.ReceiveProtocolError( - stream_id=stream.stream_id, - message=event.reason_phrase, - code=event.error_code, - ) - ) - # forward QUIC events to the H3 connection for h3_event in self.h3_conn.handle_event(event.event): @@ -197,6 +185,7 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR, reason_phrase=f"Invalid HTTP/3 request headers: {e}", ) + yield QuicTransmit(self.conn, self.quic) else: yield ReceiveHttp(receive_event) @@ -210,11 +199,26 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: else: yield commands.Log(f"Ignored unsupported H3 event: {h3_event!r}") + # report a protocol error for all remaining open streams when a connection is closed + elif isinstance(event, events.ConnectionClosed): + for stream in self.h3_conn._stream.values(): + if self.quic._stream_can_receive(stream.stream_id) and not stream.ended: + close_event = self.quic._close_event + assert close_event is not None + yield ReceiveHttp( + self.ReceiveProtocolError( + stream_id=stream.stream_id, + message=close_event.reason_phrase, + code=close_event.error_code, + ) + ) + self._handle_event = self.state_done + else: raise AssertionError(f"Unexpected event: {event!r}") - @expect(events.DataReceived, HttpEvent, events.ConnectionClosed) - def done(self, _) -> layer.CommandGenerator[None]: + @expect(HttpEvent, QuicConnectionEvent, events.ConnectionClosed) + def state_done(self, _) -> layer.CommandGenerator[None]: yield from () @abstractmethod @@ -229,6 +233,8 @@ def headers_received( ) -> Union[RequestHeaders, ResponseHeaders]: pass # pragma: no cover + _handle_event = state_start + class Http3Server(Http3Connection): ReceiveData = RequestData diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 699cbbfdfe..d35963896f 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -91,7 +91,6 @@ class QuicTlsStartServerHook(commands.StartHook): class QuicConnectionEvent(events.ConnectionEvent): """ Connection-based event that is triggered whenever a new event from QUIC is received. - Established connections are guaranteed to receive a ConnectionTerminated event at the end. Note: 'Established' means that an OpenConnection command called in a child layer returned no error. @@ -308,11 +307,9 @@ def handle_quic_event( from_client: bool, allow_buffering: bool, ) -> layer.CommandGenerator[None]: - # get the peer connections - peer_connection = self.context.server if from_client else self.context.client + # buffer events if the peer is not ready yet peer_quic = self.quic_server if from_client else self.quic_client if peer_quic is None: - # buffer events since the peer is not ready yet if not allow_buffering: raise AssertionError( f"Cannot buffer event from {'client' if from_client else 'server'}." @@ -322,38 +319,9 @@ def handle_quic_event( else: self.buffer_from_server.append(event) return + peer_connection = self.context.server if from_client else self.context.client - if isinstance(event, quic_events.ConnectionTerminated): - # report the termination as error to all non-ended streams - for flow in self.streams_by_flow: - if flow.live: - self.flow.error = mitm_flow.Error( - "Connection terminated " - f" (code={event.error_code}, reason={event.reason_phrase})." - ) - yield tcp_layer.TcpErrorHook(flow) - flow.live = False - - # end the main flow - if self.flow.live: - if is_success_error_code(event.error_code): - yield tcp_layer.TcpEndHook(flow) - else: - self.flow.error = mitm_flow.Error( - event.reason_phrase or error_code_to_str(event.error_code) - ) - yield tcp_layer.TcpErrorHook(flow) - self.flow.live = False - - # close the peer as well and don't handle further events - peer_quic.close( - event.error_code, - event.frame_type, - event.reason_phrase, - ) - self._handle_event = self.state_done - - elif isinstance(event, quic_events.DatagramFrameReceived): + if isinstance(event, quic_events.DatagramFrameReceived): # forward datagrams (that are not stream-bound) if not self.flow.live: return @@ -429,10 +397,56 @@ def state_start(self, _) -> layer.CommandGenerator[None]: return self._handle_event = self.state_ready - @expect(QuicStart, QuicConnectionEvent, tcp_layer.TcpMessageInjected) + @expect( + QuicStart, + QuicConnectionEvent, + tcp_layer.TcpMessageInjected, + events.ConnectionClosed, + ) def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, QuicStart): + if isinstance(event, events.ConnectionClosed): + # define helper variables + from_client = event.connection is self.context.client + peer_conn = self.context.server if from_client else self.context.client + local_quic = self.quic_client if from_client else self.quic_server + peer_quic = self.quic_server if from_client else self.quic_client + assert local_quic is not None + close_event = local_quic._close_event + assert close_event is not None + + # report the termination as error to all non-ended streams + for flow in self.streams_by_flow: + if flow.live: + self.flow.error = mitm_flow.Error("Connection closed.") + yield tcp_layer.TcpErrorHook(flow) + flow.live = False + + # end the main flow + if self.flow.live: + if is_success_error_code(close_event.error_code): + yield tcp_layer.TcpEndHook(flow) + else: + self.flow.error = mitm_flow.Error( + close_event.reason_phrase + or error_code_to_str(close_event.error_code) + ) + yield tcp_layer.TcpErrorHook(flow) + self.flow.live = False + + # close the peer as well + if peer_quic is not None: + peer_quic.close( + close_event.error_code, + close_event.frame_type, + close_event.reason_phrase, + ) + yield QuicTransmit(peer_conn, peer_quic) + else: + yield commands.CloseConnection(peer_conn) + self._handle_event = self.state_done + + elif isinstance(event, QuicStart): # QUIC connection has been established, store it and flush buffered events if event.connection is self.context.client: assert self.quic_client is None @@ -459,16 +473,15 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: elif isinstance(event, tcp_layer.TcpMessageInjected): # translate injected messages into QUIC events - flow = event.flow - assert isinstance(flow, tcp.TCPFlow) - if flow is self.flow: + assert isinstance(event.flow, tcp.TCPFlow) + if event.flow is self.flow: yield from self.handle_quic_event( quic_events.DatagramFrameReceived(event.message.content), event.message.from_client, allow_buffering=True, ) - elif flow in self.streams_by_flow: - stream = self.streams_by_flow[flow] + elif event.flow in self.streams_by_flow: + stream = self.streams_by_flow[event.flow] yield from self.handle_quic_event( quic_events.StreamDataReceived( stream_id=( @@ -754,12 +767,6 @@ def process_events(self) -> layer.CommandGenerator[None]: self.retire_connection_id_callback(event.connection_id) elif isinstance(event, quic_events.ConnectionTerminated): - # only forward the event if the connection has been properly initialized - if self.conn.tls_established: - yield from self.event_to_child( - QuicConnectionEvent(self.conn, event) - ) - # shutdown and close the connection yield from self.destroy_quic( event.reason_phrase or error_code_to_str(event.error_code), @@ -767,21 +774,6 @@ def process_events(self) -> layer.CommandGenerator[None]: "info" if is_success_error_code(event.error_code) else "warn" ), ) - if self.conn is self.context.client: - # once the client connection is closed, all servers are terminated immediately - # use this opportunity to properly shutdown QUIC connections - for quic_layer in self.context.layers: - if ( - isinstance(quic_layer, _QuicLayer) - and quic_layer.conn is not self.context.client - and quic_layer.quic is not None - ): - quic_layer.quic.close( - event.error_code, - event.frame_type, - event.reason_phrase, - ) - yield from quic_layer.transmit() yield commands.CloseConnection(self.conn) # we don't handle any further events, nor do/can we transmit data, so exit @@ -840,26 +832,24 @@ def state_has_quic(self, event: events.Event) -> layer.CommandGenerator[None]: isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): # there is no point in calling quic.close, as it cannot send packets anymore - # set the new connection state and simulate a ConnectionTerminated event (if established) - close_event = self.quic._close_event - if close_event is None: - close_event = quic_events.ConnectionTerminated( + # just set the new connection state and ensure there is exists a close event + self.quic._set_state(QuicConnectionState.TERMINATED) + if self.quic._close_event is None: + self.quic._close_event = quic_events.ConnectionTerminated( error_code=QuicErrorCode.APPLICATION_ERROR, frame_type=None, reason_phrase="Peer UDP connection closed or timed out.", ) - self.quic._set_state(QuicConnectionState.TERMINATED) - if self.conn.tls_established: - yield from self.event_to_child( - QuicConnectionEvent(self.conn, close_event) - ) # shutdown QUIC and handle the ConnectionClosed event + reason = self.quic._close_event.reason_phrase or error_code_to_str( + self.quic._close_event.error_code + ) yield from self.destroy_quic( - close_event.reason_phrase or error_code_to_str(close_event.error_code), + reason, level="info", ) - if not (yield from self.open_connection_end(close_event.reason_phrase)): + if not (yield from self.open_connection_end(reason)): # connection was opened before QUIC layer, report to the child layer yield from self.event_to_child(event) From 0f33e70a82f2b955b7521e9aaa1c08bc401221a0 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Mon, 27 Jun 2022 20:14:41 +0200 Subject: [PATCH 034/529] [quic] fix empty layers issue --- mitmproxy/addons/next_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index f067a48ec7..f8a6b00324 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -117,6 +117,10 @@ def next_layer(self, nextlayer: layer.NextLayer): def _next_layer( self, context: context.Context, data_client: bytes, data_server: bytes ) -> Optional[layer.Layer]: + if len(context.layers) == 0: + return self.make_top_layer(context) + + # handle QUIC connections if isinstance(context.layers[0], quic.QuicLayer): if context.client.alpn is None: return None # should never happen, as ask is called after handshake @@ -136,9 +140,6 @@ def _next_layer( return quic.QuicRelayLayer(context) # server layer already present return quic.ServerQuicLayer(context, quic.QuicRelayLayer(context)) - if len(context.layers) == 0: - return self.make_top_layer(context) - if len(data_client) < 3 and not data_server: return None # not enough data yet to make a decision From 46195ce4a6d0e803e3a0593325647cdf8990e518 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Tue, 28 Jun 2022 04:08:30 +0200 Subject: [PATCH 035/529] [quic] remove stream ID mapping --- mitmproxy/proxy/layers/http/_http3.py | 26 ++--------------------- test/mitmproxy/addons/test_proxyserver.py | 6 +++--- test/mitmproxy/net/test_udp.py | 4 ++-- 3 files changed, 7 insertions(+), 29 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 14b8c0b7f4..7ac4522f58 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -1,6 +1,6 @@ from abc import abstractmethod import time -from typing import Dict, Optional, Union +from typing import Optional, Union from aioquic.h3.connection import ( H3Connection, @@ -10,7 +10,7 @@ ) from aioquic.h3 import events as h3_events from aioquic.quic import events as quic_events -from aioquic.quic.connection import QuicConnection, stream_is_unidirectional +from aioquic.quic.connection import QuicConnection from mitmproxy import http, version from mitmproxy.net.http import status_codes @@ -307,13 +307,8 @@ class Http3Client(Http3Connection): ReceiveProtocolError = ResponseProtocolError ReceiveTrailers = ResponseTrailers - our_stream_id: Dict[int, int] - their_stream_id: Dict[int, int] - def __init__(self, context: context.Context): super().__init__(context, context.server) - self.our_stream_id = {} - self.their_stream_id = {} def protocol_error( self, event: Union[RequestProtocolError, ResponseProtocolError] @@ -346,23 +341,6 @@ def headers_received( stream_id=event.stream_id, response=response, end_stream=event.stream_ended ) - def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: - # translate stream IDs just like HTTP/2 client - if isinstance(event, HttpEvent): - assert self.quic is not None - ours = self.our_stream_id.get(event.stream_id, None) - if ours is None: - ours = self.quic.get_next_available_stream_id( - is_unidirectional=stream_is_unidirectional(event.stream_id) - ) - self.our_stream_id[event.stream_id] = ours - self.their_stream_id[ours] = event.stream_id - event.stream_id = ours - for cmd in super()._handle_event(event): - if isinstance(cmd, ReceiveHttp): - cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id] - yield cmd - __all__ = [ "Http3Client", diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index 9abeb983ec..5edc333f56 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -266,16 +266,16 @@ async def test_dns() -> None: await tctx.master.await_log("Invalid DNS datagram received", level="info") req = tdnsreq() w.write(req.packed) - resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) + resp = dns.Message.unpack((await r.read(udp.MAX_DATAGRAM_SIZE))[0]) assert req.id == resp.id and "8.8.8.8" in str(resp) assert len(ps._connections) == 1 w.write(req.packed) - resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) + resp = dns.Message.unpack((await r.read(udp.MAX_DATAGRAM_SIZE))[0]) assert req.id == resp.id and "8.8.8.8" in str(resp) assert len(ps._connections) == 1 req.id = req.id + 1 w.write(req.packed) - resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) + resp = dns.Message.unpack((await r.read(udp.MAX_DATAGRAM_SIZE))[0]) assert req.id == resp.id and "8.8.8.8" in str(resp) assert len(ps._connections) == 2 await ps.shutdown_server() diff --git a/test/mitmproxy/net/test_udp.py b/test/mitmproxy/net/test_udp.py index 1db5a60997..c7f0ef7338 100644 --- a/test/mitmproxy/net/test_udp.py +++ b/test/mitmproxy/net/test_udp.py @@ -12,8 +12,8 @@ async def test_reader(): reader.feed_data(bytearray(MAX_DATAGRAM_SIZE + 1), addr) reader.feed_data(b"Second message", addr) reader.feed_eof() - assert await reader.read(65535) == b"First message" + assert await reader.read(65535) == (b"First message", addr) with pytest.raises(AssertionError): await reader.read(MAX_DATAGRAM_SIZE - 1) - assert await reader.read(65535) == b"Second message" + assert await reader.read(65535) == (b"Second message", addr) assert not await reader.read(65535) From 8a1355e4a118a672bff0681ddfe351cf9d90c7e9 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Tue, 28 Jun 2022 15:26:56 +0200 Subject: [PATCH 036/529] [quic] preserve stream IDs in relay layer --- mitmproxy/proxy/layers/quic.py | 143 +++++++++++++-------------------- test/mitmproxy/net/test_udp.py | 2 +- 2 files changed, 57 insertions(+), 88 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index d35963896f..83d32ef819 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -12,6 +12,7 @@ QuicConnectionError, QuicConnectionState, QuicErrorCode, + stream_is_client_initiated, stream_is_unidirectional, ) from aioquic.tls import CipherSuite, HandshakeType @@ -220,21 +221,32 @@ def initialize_replacement(peer_cid: bytes) -> None: raise ValueError("No ClientHello returned.") -@dataclass class QuicRelayStream: - client_ended: bool - client_id: int flow: tcp.TCPFlow - server_ended: bool - server_id: int + stream_id: int - def stream_id(self, client: bool) -> int: - return self.client_id if client else self.server_id + def __init__(self, context: context.Context, stream_id: int) -> None: + self.flow = tcp.TCPFlow( + context.client, + context.server, + live=True, + ) + self.stream_id = stream_id + is_unidirectional = stream_is_unidirectional(stream_id) + from_client = stream_is_client_initiated(stream_id) + self._ended_client = is_unidirectional and not from_client + self._ended_server = is_unidirectional and from_client def has_ended(self, client: bool) -> bool: - stream_ended = self.client_ended if client else self.server_ended + stream_ended = self._ended_client if client else self._ended_server return stream_ended or not self.flow.live + def mark_ended(self, client: bool) -> None: + if client: + self._ended_client = True + else: + self._ended_server = True + class QuicRelayLayer(layer.Layer): """ @@ -244,14 +256,13 @@ class QuicRelayLayer(layer.Layer): # NOTE: for now we're (ab)using the TCPFlow until https://github.com/mitmproxy/mitmproxy/pull/5414 is resolved - buffer_from_client: List[quic_events.QuicEvent] - buffer_from_server: List[quic_events.QuicEvent] + buffer_from_client: Optional[List[quic_events.QuicEvent]] + buffer_from_server: Optional[List[quic_events.QuicEvent]] flow: tcp.TCPFlow # used for datagrams and to signal general connection issues quic_client: Optional[QuicConnection] = None quic_server: Optional[QuicConnection] = None streams_by_flow: Dict[tcp.TCPFlow, QuicRelayStream] - streams_by_client_id: Dict[int, QuicRelayStream] - streams_by_server_id: Dict[int, QuicRelayStream] + streams_by_id: Dict[int, QuicRelayStream] def __init__(self, context: context.Context) -> None: super().__init__(context) @@ -263,41 +274,18 @@ def __init__(self, context: context.Context) -> None: live=True, ) self.streams_by_flow = {} - self.streams_by_client_id = {} - self.streams_by_server_id = {} + self.streams_by_id = {} def get_or_create_stream( - self, stream_id: int, from_client: bool + self, stream_id: int ) -> layer.CommandGenerator[QuicRelayStream]: - streams_by_id = ( - self.streams_by_client_id if from_client else self.streams_by_server_id - ) - if stream_id in streams_by_id: - return streams_by_id[stream_id] + if stream_id in self.streams_by_id: + return self.streams_by_id[stream_id] else: - # reserve the peer stream id - is_unidirectional = stream_is_unidirectional(stream_id) - peer_quic = self.quic_server if from_client else self.quic_client - assert peer_quic - peer_stream_id = peer_quic.get_next_available_stream_id(is_unidirectional) - - # create the instance and make sure unidirectional streams are marked as ended - stream = QuicRelayStream( - flow=tcp.TCPFlow( - self.context.client, - self.context.server, - live=True, - ), - client_ended=is_unidirectional and not from_client, - server_ended=is_unidirectional and from_client, - client_id=stream_id if from_client else peer_stream_id, - server_id=peer_stream_id if from_client else stream_id, - ) - # register the stream and start the flow + stream = QuicRelayStream(self.context, stream_id) self.streams_by_flow[stream.flow] = stream - self.streams_by_client_id[stream.client_id] = stream - self.streams_by_server_id[stream.server_id] = stream + self.streams_by_id[stream.stream_id] = stream yield tcp_layer.TcpStartHook(stream.flow) return stream @@ -305,19 +293,13 @@ def handle_quic_event( self, event: quic_events.QuicEvent, from_client: bool, - allow_buffering: bool, ) -> layer.CommandGenerator[None]: # buffer events if the peer is not ready yet peer_quic = self.quic_server if from_client else self.quic_client if peer_quic is None: - if not allow_buffering: - raise AssertionError( - f"Cannot buffer event from {'client' if from_client else 'server'}." - ) - if from_client: - self.buffer_from_client.append(event) - else: - self.buffer_from_server.append(event) + buffer = self.buffer_from_client if from_client else self.buffer_from_server + assert buffer is not None + buffer.append(event) return peer_connection = self.context.server if from_client else self.context.client @@ -332,7 +314,7 @@ def handle_quic_event( elif isinstance(event, quic_events.StreamDataReceived): # ignore data received from already ended streams - stream = yield from self.get_or_create_stream(event.stream_id, from_client) + stream = yield from self.get_or_create_stream(event.stream_id) if stream.has_ended(from_client): return @@ -341,33 +323,30 @@ def handle_quic_event( stream.flow.messages.append(message) yield tcp_layer.TcpMessageHook(stream.flow) peer_quic.send_stream_data( - stream_id=stream.stream_id(not from_client), - data=message.content, - end_stream=event.end_stream, + stream.stream_id, + message.content, + event.end_stream, ) # mark the stream as ended if needed if event.end_stream: - if from_client: - stream.client_ended = True - else: - stream.server_ended = True + stream.mark_ended(from_client) - # end the flow if both legs ended - if stream.client_ended and stream.server_ended: + # end the flow if both sides ended + if stream.has_ended(not from_client): yield tcp_layer.TcpEndHook(stream.flow) stream.flow.live = False elif isinstance(event, quic_events.StreamReset): # ignore resets from already ended streams - stream = yield from self.get_or_create_stream(event.stream_id, from_client) + stream = yield from self.get_or_create_stream(event.stream_id) if stream.has_ended(from_client): return # forward resets to peer streams and report them to addons peer_quic.reset_stream( - stream_id=stream.stream_id(not from_client), - error_code=event.error_code, + stream.stream_id, + event.error_code, ) stream.flow.error = mitm_flow.Error(error_code_to_str(event.error_code)) yield tcp_layer.TcpErrorHook(stream.flow) @@ -447,53 +426,45 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: self._handle_event = self.state_done elif isinstance(event, QuicStart): - # QUIC connection has been established, store it and flush buffered events + # QUIC connection has been established, store it and get the peer's buffer if event.connection is self.context.client: assert self.quic_client is None self.quic_client = event.quic - for quic_event in self.buffer_from_server: - yield from self.handle_quic_event( - quic_event, - from_client=False, - allow_buffering=False, - ) + from_client = False + buffer = self.buffer_from_server + self.buffer_from_server = None elif event.connection is self.context.server: assert self.quic_server is None self.quic_server = event.quic - for quic_event in self.buffer_from_client: - yield from self.handle_quic_event( - quic_event, - from_client=True, - allow_buffering=False, - ) + from_client = True + buffer = self.buffer_from_client + self.buffer_from_client = None else: raise AssertionError( f"Connection {event.connection} not associated with layer." ) + # flush the buffer + for quic_event in buffer: + yield from self.handle_quic_event(quic_event, from_client) + elif isinstance(event, tcp_layer.TcpMessageInjected): # translate injected messages into QUIC events assert isinstance(event.flow, tcp.TCPFlow) if event.flow is self.flow: yield from self.handle_quic_event( - quic_events.DatagramFrameReceived(event.message.content), + quic_events.DatagramFrameReceived(data=event.message.content), event.message.from_client, - allow_buffering=True, ) elif event.flow in self.streams_by_flow: stream = self.streams_by_flow[event.flow] yield from self.handle_quic_event( quic_events.StreamDataReceived( - stream_id=( - stream.client_id - if event.message.from_client - else stream.server_id - ), + stream_id=stream.stream_id, data=event.message.content, end_stream=False, ), event.message.from_client, - allow_buffering=True, ) else: raise AssertionError( @@ -505,7 +476,6 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: yield from self.handle_quic_event( event.event, from_client=event.connection is self.context.client, - allow_buffering=True, ) else: @@ -753,7 +723,6 @@ def open_connection_end(self, reply: Optional[str]) -> layer.CommandGenerator[bo def process_events(self) -> layer.CommandGenerator[None]: assert self.quic is not None - assert self.tls is not None # handle all buffered aioquic connection events event = self.quic.next_event() @@ -832,7 +801,7 @@ def state_has_quic(self, event: events.Event) -> layer.CommandGenerator[None]: isinstance(event, events.ConnectionClosed) and event.connection is self.conn ): # there is no point in calling quic.close, as it cannot send packets anymore - # just set the new connection state and ensure there is exists a close event + # just set the new connection state and ensure there exists a close event self.quic._set_state(QuicConnectionState.TERMINATED) if self.quic._close_event is None: self.quic._close_event = quic_events.ConnectionTerminated( diff --git a/test/mitmproxy/net/test_udp.py b/test/mitmproxy/net/test_udp.py index c7f0ef7338..0180550e5e 100644 --- a/test/mitmproxy/net/test_udp.py +++ b/test/mitmproxy/net/test_udp.py @@ -16,4 +16,4 @@ async def test_reader(): with pytest.raises(AssertionError): await reader.read(MAX_DATAGRAM_SIZE - 1) assert await reader.read(65535) == (b"Second message", addr) - assert not await reader.read(65535) + assert not (await reader.read(65535))[0] From af5be0b92817eebb534e05fa0cc45127a70fa113 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jun 2022 16:16:19 +0200 Subject: [PATCH 037/529] reopen main for development --- mitmproxy/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mitmproxy/version.py b/mitmproxy/version.py index a047d86f61..b87cabc790 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -2,7 +2,7 @@ import subprocess import sys -VERSION = "8.1.1" +VERSION = "9.0.0.dev" MITMPROXY = "mitmproxy " + VERSION # Serialization format version. This is displayed nowhere, it just needs to be incremented by one From f92a20af4cda9d1b01102e9177eda5b53f65e818 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Tue, 28 Jun 2022 17:16:24 +0200 Subject: [PATCH 038/529] [quic] H2<->H3 stream ID mapping --- mitmproxy/proxy/layers/http/_http3.py | 261 +++++++++++++++----------- 1 file changed, 155 insertions(+), 106 deletions(-) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 7ac4522f58..29cc6e3585 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -1,16 +1,16 @@ from abc import abstractmethod import time -from typing import Optional, Union +from typing import Dict, Optional, Union from aioquic.h3.connection import ( H3Connection, - FrameUnexpected, ErrorCode as H3ErrorCode, + FrameUnexpected as H3FrameUnexpected, HeadersState as H3HeadersState, ) from aioquic.h3 import events as h3_events from aioquic.quic import events as quic_events -from aioquic.quic.connection import QuicConnection +from aioquic.quic.connection import QuicConnection, stream_is_client_initiated from mitmproxy import http, version from mitmproxy.net.http import status_codes @@ -58,17 +58,26 @@ class Http3Connection(HttpConnection): ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]] ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]] - @expect(events.Start) - def state_start(self, _) -> layer.CommandGenerator[None]: - self._handle_event = self.state_wait_for_quic - yield from () + @abstractmethod + def parse_headers( + self, event: h3_events.HeadersReceived + ) -> Union[RequestHeaders, ResponseHeaders]: + pass # pragma: no cover - @expect(QuicStart) - def state_wait_for_quic(self, event: events.Event) -> layer.CommandGenerator[None]: - assert isinstance(event, QuicStart) - self.quic = event.quic - self.h3_conn = H3Connection(self.quic, enable_webtransport=False) - self._handle_event = self.state_ready + def postprocess_outgoing_event(self, event: HttpEvent) -> HttpEvent: + return event + + def preprocess_incoming_event(self, event: HttpEvent) -> HttpEvent: + return event + + @abstractmethod + def send_protocol_error( + self, event: Union[RequestProtocolError, ResponseProtocolError] + ) -> None: + pass # pragma: no cover + + @expect(HttpEvent, QuicConnectionEvent, events.ConnectionClosed) + def state_done(self, _) -> layer.CommandGenerator[None]: yield from () @expect(HttpEvent, QuicConnectionEvent, events.ConnectionClosed) @@ -78,6 +87,7 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: # send mitmproxy HTTP events over the H3 connection if isinstance(event, HttpEvent): + event = self.preprocess_incoming_event(event) try: if isinstance(event, (RequestData, ResponseData)): @@ -112,87 +122,106 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: stream_id=event.stream_id, data=b"", end_stream=True ) elif isinstance(event, (RequestProtocolError, ResponseProtocolError)): - self.protocol_error(event) + self.send_protocol_error(event) else: raise AssertionError(f"Unexpected event: {event!r}") - except FrameUnexpected: + except H3FrameUnexpected: # Http2Connection also ignores HttpEvents that violate the current stream state - return + pass - # transmit buffered data and re-arm timer - yield QuicTransmit(self.conn, self.quic) + else: + # transmit buffered data and re-arm timer + yield QuicTransmit(self.conn, self.quic) # handle events from the underlying QUIC connection elif isinstance(event, QuicConnectionEvent): # report abrupt stream resets - if isinstance(event, quic_events.StreamReset): - if event.stream_id in self.h3_conn._stream: - stream = self.h3_conn._stream[event.stream_id] - if not stream.ended: - # mark the receiving part of the stream as ended - # (H3Connection alas doesn't handle StreamReset) - stream.ended = True - - # report the protocol error (doing the same error code mingling as H2) - code = ( - status_codes.CLIENT_CLOSED_REQUEST - if event.error_code == H3ErrorCode.H3_REQUEST_CANCELLED - else self.ReceiveProtocolError.code - ) - yield ReceiveHttp( - self.ReceiveProtocolError( - stream_id=event.stream_id, - message=f"stream reset by client ({error_code_to_str(event.error_code)})", - code=code, - ) + if ( + isinstance(event, quic_events.StreamReset) + and stream_is_client_initiated(event.stream_id) + and event.stream_id in self.h3_conn._stream + and not self.h3_conn._stream[event.stream_id].ended + ): + # mark the receiving part of the stream as ended + # (H3Connection alas doesn't handle StreamReset) + self.h3_conn._stream[event.stream_id] = True + + # report the protocol error (doing the same error code mingling as H2) + code = ( + status_codes.CLIENT_CLOSED_REQUEST + if event.error_code == H3ErrorCode.H3_REQUEST_CANCELLED + else self.ReceiveProtocolError.code + ) + yield ReceiveHttp( + self.postprocess_outgoing_event( + self.ReceiveProtocolError( + stream_id=event.stream_id, + message=f"stream reset by client ({error_code_to_str(event.error_code)})", + code=code, ) + ) + ) # forward QUIC events to the H3 connection for h3_event in self.h3_conn.handle_event(event.event): # report received data - if isinstance(h3_event, h3_events.DataReceived): + if isinstance( + h3_event, h3_events.DataReceived + ) and stream_is_client_initiated(h3_event.stream_id): yield ReceiveHttp( - self.ReceiveData( - stream_id=h3_event.stream_id, data=h3_event.data + self.postprocess_outgoing_event( + self.ReceiveData( + stream_id=h3_event.stream_id, data=h3_event.data + ) ) ) if h3_event.stream_ended: yield ReceiveHttp( - self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) + self.postprocess_outgoing_event( + self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) + ) ) # report headers and trailers - elif isinstance(h3_event, h3_events.HeadersReceived): + elif isinstance( + h3_event, h3_events.HeadersReceived + ) and stream_is_client_initiated(h3_event.stream_id): if ( self.h3_conn._stream[h3_event.stream_id].headers_recv_state is H3HeadersState.AFTER_TRAILERS ): yield ReceiveHttp( - self.ReceiveTrailers( - stream_id=h3_event.stream_id, - trailers=http.Headers(h3_event.headers), + self.postprocess_outgoing_event( + self.ReceiveTrailers( + stream_id=h3_event.stream_id, + trailers=http.Headers(h3_event.headers), + ) ) ) else: try: - receive_event = self.headers_received(h3_event) + receive_event = self.parse_headers(h3_event) except ValueError as e: - # this will result in a ConnectionTerminated event + # this will result in a ConnectionClosed event self.quic.close( error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR, reason_phrase=f"Invalid HTTP/3 request headers: {e}", ) yield QuicTransmit(self.conn, self.quic) else: - yield ReceiveHttp(receive_event) + yield ReceiveHttp( + self.postprocess_outgoing_event(receive_event) + ) # always report an EndOfMessage if the stream has ended if h3_event.stream_ended: yield ReceiveHttp( - self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) + self.postprocess_outgoing_event( + self.ReceiveEndOfMessage(stream_id=h3_event.stream_id) + ) ) # we don't support push, web transport, etc. @@ -202,14 +231,16 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: # report a protocol error for all remaining open streams when a connection is closed elif isinstance(event, events.ConnectionClosed): for stream in self.h3_conn._stream.values(): - if self.quic._stream_can_receive(stream.stream_id) and not stream.ended: + if stream_is_client_initiated(stream.stream_id) and not stream.ended: close_event = self.quic._close_event assert close_event is not None yield ReceiveHttp( - self.ReceiveProtocolError( - stream_id=stream.stream_id, - message=close_event.reason_phrase, - code=close_event.error_code, + self.postprocess_outgoing_event( + self.ReceiveProtocolError( + stream_id=stream.stream_id, + message=close_event.reason_phrase, + code=close_event.error_code, + ) ) ) self._handle_event = self.state_done @@ -217,21 +248,19 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: else: raise AssertionError(f"Unexpected event: {event!r}") - @expect(HttpEvent, QuicConnectionEvent, events.ConnectionClosed) - def state_done(self, _) -> layer.CommandGenerator[None]: + @expect(events.Start) + def state_start(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, events.Start) + self._handle_event = self.state_wait_for_quic yield from () - @abstractmethod - def protocol_error( - self, event: Union[RequestProtocolError, ResponseProtocolError] - ) -> None: - pass # pragma: no cover - - @abstractmethod - def headers_received( - self, event: h3_events.HeadersReceived - ) -> Union[RequestHeaders, ResponseHeaders]: - pass # pragma: no cover + @expect(QuicStart) + def state_wait_for_quic(self, event: events.Event) -> layer.CommandGenerator[None]: + assert isinstance(event, QuicStart) + self.quic = event.quic + self.h3_conn = H3Connection(self.quic, enable_webtransport=False) + self._handle_event = self.state_ready + yield from () _handle_event = state_start @@ -245,31 +274,7 @@ class Http3Server(Http3Connection): def __init__(self, context: context.Context): super().__init__(context, context.client) - def protocol_error( - self, event: Union[RequestProtocolError, ResponseProtocolError] - ) -> None: - assert self.h3_conn is not None - assert isinstance(event, ResponseProtocolError) - - # same as HTTP/2 - code = event.code - if code != status_codes.CLIENT_CLOSED_REQUEST: - code = status_codes.INTERNAL_SERVER_ERROR - self.h3_conn.send_headers( - stream_id=event.stream_id, - headers=[ - (b":status", b"%d" % code), - (b"server", version.MITMPROXY.encode()), - (b"content-type", b"text/html"), - ], - ) - self.h3_conn.send_data( - stream_id=event.stream_id, - data=format_error(code, event.message), - end_stream=True, - ) - - def headers_received( + def parse_headers( self, event: h3_events.HeadersReceived ) -> Union[RequestHeaders, ResponseHeaders]: # same as HTTP/2 @@ -300,6 +305,30 @@ def headers_received( stream_id=event.stream_id, request=request, end_stream=event.stream_ended ) + def send_protocol_error( + self, event: Union[RequestProtocolError, ResponseProtocolError] + ) -> None: + assert self.h3_conn is not None + assert isinstance(event, ResponseProtocolError) + + # same as HTTP/2 + code = event.code + if code != status_codes.CLIENT_CLOSED_REQUEST: + code = status_codes.INTERNAL_SERVER_ERROR + self.h3_conn.send_headers( + stream_id=event.stream_id, + headers=[ + (b":status", b"%d" % code), + (b"server", version.MITMPROXY.encode()), + (b"content-type", b"text/html"), + ], + ) + self.h3_conn.send_data( + stream_id=event.stream_id, + data=format_error(code, event.message), + end_stream=True, + ) + class Http3Client(Http3Connection): ReceiveData = ResponseData @@ -309,20 +338,10 @@ class Http3Client(Http3Connection): def __init__(self, context: context.Context): super().__init__(context, context.server) + self._event_to_quic: Dict[int, int] = {} + self._quic_to_event: Dict[int, int] = {} - def protocol_error( - self, event: Union[RequestProtocolError, ResponseProtocolError] - ) -> None: - assert isinstance(event, RequestProtocolError) - assert self.quic is not None - - # same as HTTP/2 - code = event.code - if code != H3ErrorCode.H3_REQUEST_CANCELLED: - code = H3ErrorCode.H3_INTERNAL_ERROR - self.quic.reset_stream(stream_id=event.stream_id, error_code=code) - - def headers_received( + def parse_headers( self, event: h3_events.HeadersReceived ) -> Union[RequestHeaders, ResponseHeaders]: # same as HTTP/2 @@ -341,6 +360,36 @@ def headers_received( stream_id=event.stream_id, response=response, end_stream=event.stream_ended ) + def postprocess_outgoing_event(self, event: HttpEvent) -> HttpEvent: + event.stream_id = self._quic_to_event[event.stream_id] + return event + + def preprocess_incoming_event(self, event: HttpEvent) -> HttpEvent: + if event.stream_id in self._event_to_quic: + event.stream_id = self._event_to_quic[event.stream_id] + else: + # QUIC and HTTP/3 would actually allow for direct stream ID mapping, but since we want + # to support H2<->H3, we need to translate IDs. + # NOTE: We always create bidirectional streams, as we can't safely infer unidirectionality. + assert self.quic is not None + stream_id = self.quic.get_next_available_stream_id() + self._event_to_quic[event.stream_id] = stream_id + self._quic_to_event[stream_id] = event.stream_id + event.stream_id = stream_id + return event + + def send_protocol_error( + self, event: Union[RequestProtocolError, ResponseProtocolError] + ) -> None: + assert isinstance(event, RequestProtocolError) + assert self.quic is not None + + # same as HTTP/2 + code = event.code + if code != H3ErrorCode.H3_REQUEST_CANCELLED: + code = H3ErrorCode.H3_INTERNAL_ERROR + self.quic.reset_stream(stream_id=event.stream_id, error_code=code) + __all__ = [ "Http3Client", From d94345b2f32aa04ed57dadac99c1ce9724ec15b2 Mon Sep 17 00:00:00 2001 From: Manuel Meitinger Date: Wed, 29 Jun 2022 11:48:24 +0200 Subject: [PATCH 039/529] [quic] properly forward disconnect reason --- mitmproxy/proxy/layers/quic.py | 62 ++++++++++++++++------------------ 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 83d32ef819..15dab790e0 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,7 +1,7 @@ import asyncio from dataclasses import dataclass, field from ssl import VerifyMode -from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union from aioquic.buffer import Buffer as QuicBuffer from aioquic.h3.connection import ErrorCode as H3ErrorCode @@ -394,6 +394,17 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: close_event = local_quic._close_event assert close_event is not None + # close the peer as well (needs to be before hooks) + if peer_quic is not None: + peer_quic.close( + close_event.error_code, + close_event.frame_type, + close_event.reason_phrase, + ) + yield QuicTransmit(peer_conn, peer_quic) + else: + yield commands.CloseConnection(peer_conn) + # report the termination as error to all non-ended streams for flow in self.streams_by_flow: if flow.live: @@ -413,16 +424,6 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: yield tcp_layer.TcpErrorHook(flow) self.flow.live = False - # close the peer as well - if peer_quic is not None: - peer_quic.close( - close_event.error_code, - close_event.frame_type, - close_event.reason_phrase, - ) - yield QuicTransmit(peer_conn, peer_quic) - else: - yield commands.CloseConnection(peer_conn) self._handle_event = self.state_done elif isinstance(event, QuicStart): @@ -445,6 +446,7 @@ def state_ready(self, event: events.Event) -> layer.CommandGenerator[None]: ) # flush the buffer + assert buffer is not None for quic_event in buffer: yield from self.handle_quic_event(quic_event, from_client) @@ -584,9 +586,7 @@ def create_quic(self) -> layer.CommandGenerator[bool]: return True def destroy_quic( - self, - reason: str, - level: Literal["error", "warn", "info", "alert", "debug"], + self, event: quic_events.ConnectionTerminated ) -> layer.CommandGenerator[None]: # ensure QUIC has been properly shut down assert self.quic is not None @@ -594,6 +594,7 @@ def destroy_quic( assert self.quic._state is QuicConnectionState.TERMINATED # report as TLS failure if the termination happened before the handshake + reason = event.reason_phrase or error_code_to_str(event.error_code) if not self.conn.tls_established: self.conn.error = reason tls_data = QuicTlsData(self.conn, self.context, settings=self.tls) @@ -615,15 +616,17 @@ def destroy_quic( # record an entry in the log yield commands.Log( - f"{self.conn}: QUIC connection destroyed: {reason}", level=level + f"{self.conn}: QUIC connection destroyed: {reason}", + level="info" if is_success_error_code(event.error_code) else "warn", ) def establish_quic( self, event: quic_events.HandshakeCompleted ) -> layer.CommandGenerator[None]: - # must only be called if QUIC is initialized + # must only be called if QUIC is initialized and not established assert self.quic is not None assert self.tls is not None + assert not self.conn.tls_established # concatenate all peer certificates all_certs: List[x509.Certificate] = [] @@ -737,12 +740,7 @@ def process_events(self) -> layer.CommandGenerator[None]: elif isinstance(event, quic_events.ConnectionTerminated): # shutdown and close the connection - yield from self.destroy_quic( - event.reason_phrase or error_code_to_str(event.error_code), - level=( - "info" if is_success_error_code(event.error_code) else "warn" - ), - ) + yield from self.destroy_quic(event) yield commands.CloseConnection(self.conn) # we don't handle any further events, nor do/can we transmit data, so exit @@ -803,22 +801,20 @@ def state_has_quic(self, event: events.Event) -> layer.CommandGenerator[None]: # there is no point in calling quic.close, as it cannot send packets anymore # just set the new connection state and ensure there exists a close event self.quic._set_state(QuicConnectionState.TERMINATED) - if self.quic._close_event is None: - self.quic._close_event = quic_events.ConnectionTerminated( + close_event = self.quic._close_event + if close_event is None: + close_event = quic_events.ConnectionTerminated( error_code=QuicErrorCode.APPLICATION_ERROR, frame_type=None, - reason_phrase="Peer UDP connection closed or timed out.", + reason_phrase="UDP connection closed or timed out.", ) + self.quic._close_event = close_event # shutdown QUIC and handle the ConnectionClosed event - reason = self.quic._close_event.reason_phrase or error_code_to_str( - self.quic._close_event.error_code - ) - yield from self.destroy_quic( - reason, - level="info", - ) - if not (yield from self.open_connection_end(reason)): + yield from self.destroy_quic(close_event) + if not ( + yield from self.open_connection_end("QUIC could not be established") + ): # connection was opened before QUIC layer, report to the child layer yield from self.event_to_child(event) From 821859cd4fccf754f6916a67b6dcb59bb4da2811 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 15 Jul 2022 23:41:02 +1200 Subject: [PATCH 040/529] [requires.io] dependency update on main branch (#5448) * [requires.io] dependency update * [requires.io] dependency update * [requires.io] dependency update * [requires.io] dependency update * [requires.io] dependency update * Update setup.py Co-authored-by: requires.io Co-authored-by: Maximilian Hils --- setup.py | 6 +++--- tox.ini | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 4e149c4d87..add21f15d4 100644 --- a/setup.py +++ b/setup.py @@ -81,7 +81,7 @@ "h11>=0.11,<0.14", "h2>=4.1,<5", "hyperframe>=6.0,<7", - "kaitaistruct>=0.7,<0.10", + "kaitaistruct>=0.10,<0.11", "ldap3>=2.8,<2.10", "msgpack>=1.0.0, <1.1.0", "passlib>=1.6.5, <1.8", @@ -106,8 +106,8 @@ "hypothesis>=5.8,<7", "parver>=0.1,<2.0", "pdoc>=4.0.0", - "pyinstaller==5.1", - "pytest-asyncio>=0.17.0,<0.19", + "pyinstaller==5.2", + "pytest-asyncio>=0.17,<0.20", "pytest-cov>=2.7.1,<3.1", "pytest-timeout>=1.3.3,<2.2", "pytest-xdist>=2.1.0,<3", diff --git a/tox.ini b/tox.ini index db3d7c6ec3..044532230f 100644 --- a/tox.ini +++ b/tox.ini @@ -33,7 +33,7 @@ deps = types-certifi==2021.10.8.3 types-Flask==1.1.6 types-Werkzeug==1.0.9 - types-requests==2.28.0 + types-requests==2.28.1 types-cryptography==3.3.21 types-pyOpenSSL==22.0.4 From 57e65d3d1f28031ca087fca47cd13341ee76f210 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 20 Jul 2022 01:30:35 +1200 Subject: [PATCH 041/529] [requires.io] dependency update on main branch (#5461) * [requires.io] dependency update * [requires.io] dependency update * [requires.io] dependency update * Update setup.py Co-authored-by: requires.io Co-authored-by: Maximilian Hils --- setup.py | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index add21f15d4..5708fbe295 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ # It is not considered best practice to use install_requires to pin dependencies to specific versions. install_requires=[ "asgiref>=3.2.10,<3.6", - "blinker>=1.4, <1.5", + "blinker>=1.4,<1.6", "Brotli>=1.0,<1.1", "certifi>=2019.9.11", # no semver here - this should always be on the last release! "cryptography>=36,<38", diff --git a/tox.ini b/tox.ini index 044532230f..36fde15d2c 100644 --- a/tox.ini +++ b/tox.ini @@ -33,7 +33,7 @@ deps = types-certifi==2021.10.8.3 types-Flask==1.1.6 types-Werkzeug==1.0.9 - types-requests==2.28.1 + types-requests==2.28.2 types-cryptography==3.3.21 types-pyOpenSSL==22.0.4 From 83e543c3e66654b952f1979c0adaa62df91b2832 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 3 Jun 2022 15:18:36 +0200 Subject: [PATCH 042/529] add multi proxy mode This commit makes it possible for mitmproxy to spawn multiple TCP/UDP proxy servers at the same time, see https://github.com/mitmproxy/mitmproxy/discussions/5288 --- mitmproxy/addons/clientplayback.py | 12 +- mitmproxy/addons/core.py | 16 - mitmproxy/addons/dns_resolver.py | 9 +- mitmproxy/addons/next_layer.py | 34 +- mitmproxy/addons/proxyauth.py | 111 +++-- mitmproxy/addons/proxyserver.py | 379 ++++++------------ mitmproxy/addons/upstream_auth.py | 5 +- mitmproxy/connection.py | 8 + mitmproxy/io/compat.py | 7 + mitmproxy/master.py | 19 +- mitmproxy/net/server_spec.py | 43 +- mitmproxy/net/udp.py | 9 +- mitmproxy/options.py | 22 +- mitmproxy/proxy/layers/http/__init__.py | 13 +- .../proxy/layers/http/_upstream_proxy.py | 12 +- mitmproxy/proxy/layers/modes.py | 13 +- mitmproxy/proxy/mode_servers.py | 260 ++++++++++++ mitmproxy/proxy/mode_specs.py | 202 ++++++++++ mitmproxy/proxy/server.py | 10 +- mitmproxy/test/taddons.py | 14 +- mitmproxy/test/tflow.py | 3 + mitmproxy/tools/cmdline.py | 2 - mitmproxy/tools/console/statusbar.py | 16 +- mitmproxy/tools/main.py | 11 +- mitmproxy/version.py | 2 +- test/mitmproxy/addons/test_asgiapp.py | 4 +- test/mitmproxy/addons/test_clientplayback.py | 4 +- test/mitmproxy/addons/test_core.py | 27 +- test/mitmproxy/addons/test_dns_resolver.py | 5 +- test/mitmproxy/addons/test_next_layer.py | 42 +- test/mitmproxy/addons/test_proxyauth.py | 170 ++++---- test/mitmproxy/addons/test_proxyserver.py | 118 +++--- test/mitmproxy/addons/test_script.py | 6 +- test/mitmproxy/addons/test_upstream_auth.py | 5 +- test/mitmproxy/net/test_server_spec.py | 40 +- test/mitmproxy/proxy/layers/http/test_http.py | 14 +- test/mitmproxy/proxy/layers/test_modes.py | 13 +- test/mitmproxy/proxy/test_mode_servers.py | 77 ++++ test/mitmproxy/proxy/test_mode_specs.py | 73 ++++ test/mitmproxy/test_addonmanager.py | 7 +- test/mitmproxy/test_flow.py | 4 +- .../mitmproxy/tools/console/test_statusbar.py | 2 +- test/mitmproxy/tools/web/test_app.py | 2 + .../Modal/__snapshots__/ModalSpec.tsx.snap | 9 + web/src/js/components/Footer.tsx | 9 +- web/src/js/ducks/_options_gen.ts | 16 +- 46 files changed, 1167 insertions(+), 712 deletions(-) create mode 100644 mitmproxy/proxy/mode_servers.py create mode 100644 mitmproxy/proxy/mode_specs.py create mode 100644 test/mitmproxy/proxy/test_mode_servers.py create mode 100644 test/mitmproxy/proxy/test_mode_specs.py diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index fb6dfad1d1..6aa8acb2a4 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -12,13 +12,13 @@ from mitmproxy import http from mitmproxy import io from mitmproxy.hooks import UpdateHook -from mitmproxy.net import server_spec from mitmproxy.options import Options from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.http import HTTPMode from mitmproxy.proxy import commands, events, layers, server from mitmproxy.connection import ConnectionState, Server from mitmproxy.proxy.layer import CommandGenerator +from mitmproxy.proxy.mode_specs import UpstreamMode from mitmproxy.utils import asyncio_utils @@ -81,14 +81,14 @@ def __init__(self, flow: http.HTTPFlow, options: Options) -> None: context = Context(client, options) context.server = Server((flow.request.host, flow.request.port)) context.server.tls = flow.request.scheme == "https" - if options.mode.startswith("upstream:"): - context.server.via = flow.server_conn.via = server_spec.parse_with_mode( - options.mode - )[1] + if options.mode and options.mode[0].startswith("upstream:"): + mode = UpstreamMode.parse(options.mode[0]) + assert isinstance(mode, UpstreamMode) # remove once mypy supports Self. + context.server.via = flow.server_conn.via = (mode.scheme, mode.address) super().__init__(context) - if options.mode.startswith("upstream:"): + if options.mode and options.mode[0].startswith("upstream:"): self.layer = layers.HttpLayer(context, HTTPMode.upstream) else: self.layer = layers.HttpLayer(context, HTTPMode.transparent) diff --git a/mitmproxy/addons/core.py b/mitmproxy/addons/core.py index 230afac4a0..7dadf56888 100644 --- a/mitmproxy/addons/core.py +++ b/mitmproxy/addons/core.py @@ -8,8 +8,6 @@ from mitmproxy import command from mitmproxy import flow from mitmproxy import optmanager -from mitmproxy import platform -from mitmproxy.net import server_spec from mitmproxy.net.http import status_codes import mitmproxy.types @@ -25,20 +23,6 @@ def configure(self, updated): raise exceptions.OptionsError( "add_upstream_certs_to_client_chain requires the upstream_cert option to be enabled." ) - if "mode" in updated: - mode = opts.mode - if mode.startswith("reverse:") or mode.startswith("upstream:"): - try: - server_spec.parse_with_mode(mode) - except ValueError as e: - raise exceptions.OptionsError(str(e)) from e - elif mode == "transparent": - if not platform.original_addr: - raise exceptions.OptionsError( - "Transparent mode not supported on this platform." - ) - elif mode not in ["regular", "socks5"]: - raise exceptions.OptionsError("Invalid mode specification: %s" % mode) if "client_certs" in updated: if opts.client_certs: client_certs = os.path.expanduser(opts.client_certs) diff --git a/mitmproxy/addons/dns_resolver.py b/mitmproxy/addons/dns_resolver.py index e1af088b3e..8d460d428a 100644 --- a/mitmproxy/addons/dns_resolver.py +++ b/mitmproxy/addons/dns_resolver.py @@ -2,7 +2,8 @@ import ipaddress import socket from typing import Callable, Iterable, Union -from mitmproxy import ctx, dns +from mitmproxy import dns +from mitmproxy.proxy import mode_specs IP4_PTR_SUFFIX = ".in-addr.arpa" IP6_PTR_SUFFIX = ".ip6.arpa" @@ -30,7 +31,7 @@ async def resolve_question_by_name( else: # NOTE might fail on Windows for IPv6 queries: # https://stackoverflow.com/questions/66755681/getaddrinfo-c-on-windows-not-handling-ipv6-correctly-returning-error-code-1 - raise ResolveError(dns.response_codes.SERVFAIL) + raise ResolveError(dns.response_codes.SERVFAIL) # pragma: no cover return map( lambda addrinfo: dns.ResourceRecord( name=question.name, @@ -137,11 +138,13 @@ async def resolve_message( class DnsResolver: async def dns_request(self, flow: dns.DNSFlow) -> None: + proxy_mode = flow.client_conn.proxy_mode + assert isinstance(proxy_mode, mode_specs.DnsMode) should_resolve = ( flow.live and not flow.response and not flow.error - and ctx.options.dns_mode == "regular" + and proxy_mode.resolve_local ) if should_resolve: flow.response = await resolve_message( diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index e2e57e96b9..4496895b7d 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -117,9 +117,7 @@ def next_layer(self, nextlayer: layer.NextLayer): def _next_layer( self, context: context.Context, data_client: bytes, data_server: bytes ) -> Optional[layer.Layer]: - if len(context.layers) == 0: - return self.make_top_layer(context) - + assert context.layers if len(data_client) < 3 and not data_server: return None # not enough data yet to make a decision @@ -153,17 +151,21 @@ def s(*layers): ret.child_layer = layers.ClientTLSLayer(context) return ret - # 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy. + # 3. Setup the HTTP layer for a regular HTTP proxy if ( s(modes.HttpProxy) or # or a "Secure Web Proxy", see https://www.chromium.org/developers/design-documents/secure-web-proxy s(modes.HttpProxy, layers.ClientTLSLayer) ): - if ctx.options.mode == "regular": - return layers.HttpLayer(context, HTTPMode.regular) - else: - return layers.HttpLayer(context, HTTPMode.upstream) + return layers.HttpLayer(context, HTTPMode.regular) + # 3b. ... or an upstream proxy. + if ( + s(modes.HttpUpstreamProxy) + or + s(modes.HttpUpstreamProxy, layers.ClientTLSLayer) + ): + return layers.HttpLayer(context, HTTPMode.upstream) # 4. Check for --tcp if any( @@ -186,19 +188,3 @@ def s(*layers): # 6. Assume HTTP by default. return layers.HttpLayer(context, HTTPMode.transparent) - - def make_top_layer(self, context: context.Context) -> layer.Layer: - if ctx.options.mode == "regular" or ctx.options.mode.startswith("upstream:"): - return layers.modes.HttpProxy(context) - - elif ctx.options.mode == "transparent": - return layers.modes.TransparentProxy(context) - - elif ctx.options.mode.startswith("reverse:"): - return layers.modes.ReverseProxy(context) - - elif ctx.options.mode == "socks5": - return layers.modes.Socks5Proxy(context) - - else: # pragma: no cover - raise AssertionError("Unknown mode.") diff --git a/mitmproxy/addons/proxyauth.py b/mitmproxy/addons/proxyauth.py index 494ecbe99b..96013d8d49 100644 --- a/mitmproxy/addons/proxyauth.py +++ b/mitmproxy/addons/proxyauth.py @@ -13,6 +13,7 @@ from mitmproxy import exceptions from mitmproxy import http from mitmproxy.net.http import status_codes +from mitmproxy.proxy import mode_specs from mitmproxy.proxy.layers import modes REALM = "mitmproxy" @@ -42,27 +43,21 @@ def load(self, loader): ) def configure(self, updated): - if "proxyauth" not in updated: - return - auth = ctx.options.proxyauth - if auth: - if ctx.options.mode == "transparent": - raise exceptions.OptionsError( - "Proxy Authentication not supported in transparent mode." - ) - - if auth == "any": - self.validator = AcceptAll() - elif auth.startswith("@"): - self.validator = Htpasswd(auth) - elif ctx.options.proxyauth.startswith("ldap"): - self.validator = Ldap(auth) - elif ":" in ctx.options.proxyauth: - self.validator = SingleUser(auth) + if "proxyauth" in updated: + auth = ctx.options.proxyauth + if auth: + if auth == "any": + self.validator = AcceptAll() + elif auth.startswith("@"): + self.validator = Htpasswd(auth) + elif ctx.options.proxyauth.startswith("ldap"): + self.validator = Ldap(auth) + elif ":" in ctx.options.proxyauth: + self.validator = SingleUser(auth) + else: + raise exceptions.OptionsError("Invalid proxyauth specification.") else: - raise exceptions.OptionsError("Invalid proxyauth specification.") - else: - self.validator = None + self.validator = None def socks5_auth(self, data: modes.Socks5AuthData) -> None: if self.validator and self.validator(data.username, data.password): @@ -93,8 +88,11 @@ def authenticate_http(self, f: http.HTTPFlow) -> bool: username = None password = None is_valid = False + + is_proxy = is_http_proxy(f) + auth_header = http_auth_header(is_proxy) try: - auth_value = f.request.headers.get(self.http_auth_header, "") + auth_value = f.request.headers.get(auth_header, "") scheme, username, password = parse_http_basic_auth(auth_value) is_valid = self.validator(username, password) except Exception: @@ -102,47 +100,48 @@ def authenticate_http(self, f: http.HTTPFlow) -> bool: if is_valid: f.metadata["proxyauth"] = (username, password) - del f.request.headers[self.http_auth_header] + del f.request.headers[auth_header] return True else: - f.response = self.make_auth_required_response() + f.response = make_auth_required_response(is_proxy) return False - def make_auth_required_response(self) -> http.Response: - if self.is_http_proxy: - status_code = status_codes.PROXY_AUTH_REQUIRED - headers = {"Proxy-Authenticate": f'Basic realm="{REALM}"'} - else: - status_code = status_codes.UNAUTHORIZED - headers = {"WWW-Authenticate": f'Basic realm="{REALM}"'} - - reason = http.status_codes.RESPONSES[status_code] - return http.Response.make( - status_code, - ( - f"" - f"{status_code} {reason}" - f"

{status_code} {reason}

" - f"" - ), - headers, - ) - @property - def http_auth_header(self) -> str: - if self.is_http_proxy: - return "Proxy-Authorization" - else: - return "Authorization" +def make_auth_required_response(is_proxy: bool) -> http.Response: + if is_proxy: + status_code = status_codes.PROXY_AUTH_REQUIRED + headers = {"Proxy-Authenticate": f'Basic realm="{REALM}"'} + else: + status_code = status_codes.UNAUTHORIZED + headers = {"WWW-Authenticate": f'Basic realm="{REALM}"'} - @property - def is_http_proxy(self) -> bool: - """ - Returns: - - True, if authentication is done as if mitmproxy is a proxy - - False, if authentication is done as if mitmproxy is an HTTP server - """ - return ctx.options.mode == "regular" or ctx.options.mode.startswith("upstream:") + reason = http.status_codes.RESPONSES[status_code] + return http.Response.make( + status_code, + ( + f"" + f"{status_code} {reason}" + f"

{status_code} {reason}

" + f"" + ), + headers, + ) + + +def http_auth_header(is_proxy: bool) -> str: + if is_proxy: + return "Proxy-Authorization" + else: + return "Authorization" + + +def is_http_proxy(f: http.HTTPFlow) -> bool: + """ + Returns: + - True, if authentication is done as if mitmproxy is a proxy + - False, if authentication is done as if mitmproxy is an HTTP server + """ + return isinstance(f.client_conn.proxy_mode, (mode_specs.RegularMode, mode_specs.UpstreamMode)) def mkauth(username: str, password: str, scheme: str = "basic") -> str: diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 70ae28855d..60f2760c54 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -1,112 +1,60 @@ +""" +This addon is responsible for starting/stopping the proxy server sockets/instances specified by the mode option. +""" +from __future__ import annotations + import asyncio -from asyncio import base_events +import collections import ipaddress -import re -import struct +from contextlib import contextmanager from typing import Optional +from wsproto.frame_protocol import Opcode + from mitmproxy import ( command, ctx, exceptions, - flow, http, - log, - master, - options, platform, tcp, websocket, ) from mitmproxy.connection import Address from mitmproxy.flow import Flow -from mitmproxy.net import udp -from mitmproxy.proxy import commands, events, layers, server_hooks -from mitmproxy.proxy import server +from mitmproxy.proxy import events, mode_specs, server_hooks from mitmproxy.proxy.layers.tcp import TcpMessageInjected from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected -from mitmproxy.utils import asyncio_utils, human -from wsproto.frame_protocol import Opcode +from mitmproxy.proxy.mode_servers import ProxyConnectionHandler, ServerInstance, ServerManager +from mitmproxy.utils import human -class ProxyConnectionHandler(server.LiveConnectionHandler): - master: master.Master - - def __init__(self, master, r, w, options, timeout=None): - self.master = master - super().__init__(r, w, options) - self.log_prefix = f"{human.format_address(self.client.peername)}: " - if timeout is not None: - self.timeout_watchdog.CONNECTION_TIMEOUT = timeout - - async def handle_hook(self, hook: commands.StartHook) -> None: - with self.timeout_watchdog.disarm(): - # We currently only support single-argument hooks. - (data,) = hook.args() - await self.master.addons.handle_lifecycle(hook) - if isinstance(data, flow.Flow): - await data.wait_for_resume() - - def log(self, message: str, level: str = "info") -> None: - x = log.LogEntry(self.log_prefix + message, level) - asyncio_utils.create_task( - self.master.addons.handle_lifecycle(log.AddLogHook(x)), - name="ProxyConnectionHandler.log", - ) - - -class Proxyserver: +class Proxyserver(ServerManager): """ This addon runs the actual proxy server. """ - - tcp_server: Optional[base_events.Server] - dns_server: Optional[udp.UdpServer] - connect_addr: Optional[Address] - listen_port: int - dns_reverse_addr: Optional[tuple[str, int]] - master: master.Master - options: options.Options + connections: dict[tuple, ProxyConnectionHandler] + servers: dict[str, ServerInstance] is_running: bool - _connections: dict[tuple, ProxyConnectionHandler] + _lock: asyncio.Lock + _connect_addr: Optional[Address] = None def __init__(self): - self._lock = asyncio.Lock() - self.tcp_server = None - self.dns_server = None - self.connect_addr = None - self.dns_reverse_addr = None + self.connections = {} + self.servers = {} self.is_running = False - self._connections = {} + self._lock = asyncio.Lock() def __repr__(self): - return f"ProxyServer({'running' if self.running_servers else 'stopped'}, {len(self._connections)} active conns)" + return f"Proxyserver({len(self.connections)} active conns)" - @property - def _server_desc(self): - yield "Proxy", self.tcp_server, lambda x: setattr( - self, "tcp_server", x - ), ctx.options.server, lambda: asyncio.start_server( - self.handle_tcp_connection, - self.options.listen_host, - self.options.listen_port, - ) - yield "DNS", self.dns_server, lambda x: setattr( - self, "dns_server", x - ), ctx.options.dns_server, lambda: udp.start_server( - self.handle_dns_datagram, - self.options.dns_listen_host or "127.0.0.1", - self.options.dns_listen_port, - transparent=self.options.dns_mode == "transparent", - ) - - @property - def running_servers(self): - return tuple( - instance - for _, instance, _, _, _ in self._server_desc - if instance is not None - ) + @contextmanager + def register_connection(self, connection_id: tuple, handler: ProxyConnectionHandler): + self.connections[connection_id] = handler + try: + yield + finally: + del self.connections[connection_id] def load(self, loader): loader.add_option( @@ -178,30 +126,11 @@ def load(self, loader): None, """Set the local IP address that mitmproxy should use when connecting to upstream servers.""", ) - loader.add_option( - "dns_server", bool, False, """Start a DNS server. Disabled by default.""" - ) - loader.add_option( - "dns_listen_host", str, "", """Address to bind DNS server to.""" - ) - loader.add_option("dns_listen_port", int, 53, """DNS server service port.""") - loader.add_option( - "dns_mode", - str, - "regular", - """ - One of "regular", "reverse:[:]" or "transparent". - regular....: requests will be resolved using the local resolver - reverse....: forward queries to another DNS server - transparent: transparent mode - """, - ) async def running(self): - self.master = ctx.master - self.options = ctx.options self.is_running = True - await self.refresh_server() + # TODO: Do this before running() + await self.setup_servers() def configure(self, updated): if "stream_large_bodies" in updated: @@ -222,146 +151,90 @@ def configure(self, updated): ) if "connect_addr" in updated: try: - self.connect_addr = (str(ipaddress.ip_address(ctx.options.connect_addr)), 0) if ctx.options.connect_addr else None + if ctx.options.connect_addr: + self._connect_addr = str(ipaddress.ip_address(ctx.options.connect_addr)), 0 + else: + self._connect_addr = None except ValueError: raise exceptions.OptionsError( - f"Invalid connection address {ctx.options.connect_addr!r}, specify a valid IP address." + f"Invalid value for connect_addr: {ctx.options.connect_addr!r}. Specify a valid IP address." ) - - if "dns_mode" in updated: - m = re.match( - r"^(regular|reverse:(?P[^:]+)(:(?P\d+))?|transparent)$", - ctx.options.dns_mode, - ) - if not m: - raise exceptions.OptionsError( - f"Invalid DNS mode {ctx.options.dns_mode!r}." - ) - if m["host"]: + if "mode" in updated or "server" in updated: + # Make sure that all modes are syntactically valid... + modes: list[mode_specs.ProxyMode] = [] + for mode in ctx.options.mode: try: - self.dns_reverse_addr = ( - str(ipaddress.ip_address(m["host"])), - int(m["port"]) if m["port"] is not None else 53, - ) - except ValueError: - raise exceptions.OptionsError( - f"Invalid DNS reverse mode, expected 'reverse:ip[:port]' got {ctx.options.dns_mode!r}." + modes.append( + mode_specs.ProxyMode.parse(mode) ) - else: - self.dns_reverse_addr = None - if "mode" in updated and ctx.options.mode == "transparent": # pragma: no cover - platform.init_transparent_mode() - if self.is_running and any( - x in updated - for x in [ - "server", - "listen_host", - "listen_port", - "dns_server", - "dns_mode", - "dns_listen_host", - "dns_listen_port", + except ValueError as e: + raise exceptions.OptionsError(f"Invalid proxy mode specification: {mode} ({e})") + + # ...and don't listen on the same address. + listen_addrs = [ + ( + m.listen_host(ctx.options.listen_host), + m.listen_port(ctx.options.listen_port), + m.transport_protocol + ) + for m in modes ] - ): - asyncio.create_task(self.refresh_server()) + if len(set(listen_addrs)) != len(listen_addrs): + (host, port, _) = collections.Counter(listen_addrs).most_common(1)[0][0] + dup_addr = human.format_address((host or "0.0.0.0", port)) + raise exceptions.OptionsError(f"Cannot spawn multiple servers on the same address: {dup_addr}") - async def refresh_server(self): - async with self._lock: - await self.shutdown_server() - if ctx.options.server and not ctx.master.addons.get("nextlayer"): + if ctx.options.mode and not ctx.master.addons.get("nextlayer"): ctx.log.warn("Warning: Running proxyserver without nextlayer addon!") - for name, instance, set_instance, enabled, start in self._server_desc: - if instance is None and enabled: - try: - instance = await start() - except OSError as e: - ctx.log.error(str(e)) - else: - set_instance(instance) - # TODO: This is a bit confusing currently for `-p 0`. - addrs = { - f"{human.format_address(s.getsockname())}" - for s in instance.sockets - } - ctx.log.info( - f"{name} server listening at {' and '.join(addrs)}" - ) + if any(isinstance(m, mode_specs.TransparentMode) for m in modes): + if platform.original_addr: + platform.init_transparent_mode() + else: + raise exceptions.OptionsError("Transparent mode not supported on this platform.") - async def shutdown_server(self): - for name, instance, set_instance, _, _ in self._server_desc: - if instance is not None: - ctx.log.info(f"Stopping {name} server...") - try: - instance.close() - await instance.wait_closed() - except OSError as e: - ctx.log.error(str(e)) + if self.is_running: + asyncio.create_task(self.setup_servers()) + + async def setup_servers(self) -> bool: + all_ok = True + async with self._lock: + new_servers: dict[str, ServerInstance] = dict.fromkeys(ctx.options.mode) # type: ignore + if not ctx.options.server: + new_servers.clear() + + # Shutdown modes that have been removed from the list. + shutdown_tasks = [ + s.stop() for spec, s in self.servers.items() + if spec not in new_servers + ] + for ret in await asyncio.gather(*shutdown_tasks, return_exceptions=True): + if ret: + all_ok = False + ctx.log.error(str(ret)) + + new_instances: list[ServerInstance] = [] + for spec in new_servers: + if existing := self.servers.get(spec, None): + new_servers[spec] = existing else: - set_instance(None) + instance: ServerInstance = ServerInstance.make(spec, self) + new_instances.append(instance) + new_servers[spec] = instance - async def handle_connection(self, connection_id: tuple): - handler = self._connections[connection_id] - task = asyncio.current_task() - assert task - asyncio_utils.set_task_debug_info( - task, - name=f"Proxyserver.handle_connection", - client=handler.client.peername, - ) - try: - await handler.handle_client() - finally: - del self._connections[connection_id] + for ret in await asyncio.gather(*[m.start() for m in new_instances], return_exceptions=True): + if ret: + all_ok = False + ctx.log.error(str(ret)) - async def handle_tcp_connection( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - connection_id = ( - "tcp", - writer.get_extra_info("peername"), - writer.get_extra_info("sockname"), - ) - self._connections[connection_id] = ProxyConnectionHandler( - self.master, reader, writer, self.options - ) - await self.handle_connection(connection_id) + self.servers = new_servers + return all_ok - def handle_dns_datagram( - self, - transport: asyncio.DatagramTransport, - data: bytes, - remote_addr: Address, - local_addr: Address, - ) -> None: - try: - dns_id = struct.unpack_from("!H", data, 0) - except struct.error: - ctx.log.info( - f"Invalid DNS datagram received from {human.format_address(remote_addr)}." - ) - return - connection_id = ("udp", dns_id, remote_addr, local_addr) - if connection_id not in self._connections: - reader = udp.DatagramReader() - writer = udp.DatagramWriter(transport, remote_addr, reader) - handler = ProxyConnectionHandler( - self.master, reader, writer, self.options, 20 - ) - handler.layer = layers.DNSLayer(handler.layer.context) - handler.layer.context.server.address = ( - local_addr - if self.options.dns_mode == "transparent" - else self.dns_reverse_addr - ) - handler.layer.context.server.transport_protocol = "udp" - self._connections[connection_id] = handler - asyncio.create_task(self.handle_connection(connection_id)) - else: - handler = self._connections[connection_id] - client_reader = handler.transports[handler.client].reader - assert isinstance(client_reader, udp.DatagramReader) - reader = client_reader - reader.feed_data(data, remote_addr) + def listen_addrs(self) -> list[Address]: + return [ + addr + for server in self.servers.values() + for addr in server.listen_addrs + ] def inject_event(self, event: events.MessageInjected): connection_id = ( @@ -369,9 +242,9 @@ def inject_event(self, event: events.MessageInjected): event.flow.client_conn.peername, event.flow.client_conn.sockname, ) - if connection_id not in self._connections: + if connection_id not in self.connections: raise ValueError("Flow is not from a live connection.") - self._connections[connection_id].server_event(event) + self.connections[connection_id].server_event(event) @command.command("inject.websocket") def inject_websocket( @@ -400,23 +273,29 @@ def inject_tcp(self, flow: Flow, to_client: bool, message: bytes): except ValueError as e: ctx.log.warn(str(e)) - def server_connect(self, ctx: server_hooks.ServerConnectionHookData): - assert ctx.server.address - # FIXME: Move this to individual proxy modes. - self_connect = ctx.server.address[1] in ( - self.options.dns_listen_port, - self.options.listen_port, - ) and ctx.server.address[0] in ( - "localhost", - "127.0.0.1", - "::1", - self.options.listen_host, - self.options.dns_listen_host, - ) - if self_connect: - ctx.server.error = ( - "Request destination unknown. " - "Unable to figure out where this request should be forwarded to." - ) - if ctx.server.sockname is None: - ctx.server.sockname = self.connect_addr + def server_connect(self, data: server_hooks.ServerConnectionHookData): + if data.server.sockname is None: + data.server.sockname = self._connect_addr + + # Prevent mitmproxy from recursively connecting to itself. + assert data.server.address + connect_host, connect_port, *_ = data.server.address + + for server in self.servers.values(): + for listen_host, listen_port, *_ in server.listen_addrs: + self_connect = ( + connect_port == listen_port + and connect_host in ( + "localhost", + "127.0.0.1", + "::1", + listen_host + ) + and server.mode.transport_protocol == data.server.transport_protocol + ) + if self_connect: + data.server.error = ( + "Request destination unknown. " + "Unable to figure out where this request should be forwarded to." + ) + return diff --git a/mitmproxy/addons/upstream_auth.py b/mitmproxy/addons/upstream_auth.py index 0c1cb1d625..da9b395348 100644 --- a/mitmproxy/addons/upstream_auth.py +++ b/mitmproxy/addons/upstream_auth.py @@ -5,6 +5,7 @@ from mitmproxy import exceptions from mitmproxy import ctx from mitmproxy import http +from mitmproxy.proxy import mode_specs from mitmproxy.utils import strutils @@ -52,7 +53,7 @@ def http_connect_upstream(self, f: http.HTTPFlow): def requestheaders(self, f: http.HTTPFlow): if self.auth: - if ctx.options.mode.startswith("upstream") and f.request.scheme == "http": + if isinstance(f.client_conn.proxy_mode, mode_specs.UpstreamMode) and f.request.scheme == "http": f.request.headers["Proxy-Authorization"] = self.auth - elif ctx.options.mode.startswith("reverse"): + elif isinstance(f.client_conn.proxy_mode, mode_specs.ReverseMode): f.request.headers["Authorization"] = self.auth diff --git a/mitmproxy/connection.py b/mitmproxy/connection.py index 6ed08602f6..286677197a 100644 --- a/mitmproxy/connection.py +++ b/mitmproxy/connection.py @@ -6,6 +6,7 @@ from typing import Literal, Optional from mitmproxy import certs +from mitmproxy.proxy import mode_specs from mitmproxy.coretypes import serializable from mitmproxy.net import server_spec from mitmproxy.utils import human @@ -158,6 +159,9 @@ class Client(Connection): The certificate used by mitmproxy to establish TLS with the client. """ + proxy_mode: mode_specs.ProxyMode + """The proxy server type this client has been connecting to.""" + timestamp_start: float """*Timestamp:* TCP SYN received""" @@ -168,6 +172,7 @@ def __init__( timestamp_start: float, *, transport_protocol: TransportProtocol = "tcp", + proxy_mode: mode_specs.ProxyMode = mode_specs.ProxyMode.parse("regular"), ): self.id = str(uuid.uuid4()) self.peername = peername @@ -175,6 +180,7 @@ def __init__( self.timestamp_start = timestamp_start self.state = ConnectionState.OPEN self.transport_protocol = transport_protocol + self.proxy_mode = proxy_mode def __str__(self): if self.alpn: @@ -211,6 +217,7 @@ def get_state(self): "certificate_list": [x.get_state() for x in self.certificate_list], "alpn_offers": self.alpn_offers, "cipher_list": self.cipher_list, + "proxy_mode": self.proxy_mode.get_state(), } @classmethod @@ -244,6 +251,7 @@ def set_state(self, state): ) self.alpn_offers = state["alpn_offers"] self.cipher_list = state["cipher_list"] + self.proxy_mode = mode_specs.ProxyMode.from_state(state["proxy_mode"]) @property def address(self): # pragma: no cover diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index 3b77818f6c..19229a4f14 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -380,6 +380,12 @@ def convert_16_17(data): return data +def convert_17_18(data): + data["version"] = 18 + data["client_conn"]["proxy_mode"] = "regular" + return data + + def _convert_dict_keys(o: Any) -> Any: if isinstance(o, dict): return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} @@ -441,6 +447,7 @@ def convert_unicode(data: dict) -> dict: 14: convert_14_15, 15: convert_15_16, 16: convert_16_17, + 17: convert_17_18, } diff --git a/mitmproxy/master.py b/mitmproxy/master.py index ec10032c0f..be9f960bb7 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -8,8 +8,8 @@ from mitmproxy import http from mitmproxy import log from mitmproxy import options -from mitmproxy.net import server_spec from . import ctx as mitmproxy_ctx +from .proxy.mode_specs import ReverseMode class Master: @@ -45,6 +45,7 @@ async def run(self) -> None: # Handle scheduled tasks (configure()) first. await asyncio.sleep(0) + # TODO: Bind proxy server here before invoking running(). await self.running() try: await self.should_exit.wait() @@ -90,14 +91,14 @@ async def load_flow(self, f): Loads a flow """ - if isinstance(f, http.HTTPFlow): - if self.options.mode.startswith("reverse:"): - # When we load flows in reverse proxy mode, we adjust the target host to - # the reverse proxy destination for all flows we load. This makes it very - # easy to replay saved flows against a different host. - _, upstream_spec = server_spec.parse_with_mode(self.options.mode) - f.request.host, f.request.port = upstream_spec.address - f.request.scheme = upstream_spec.scheme + if isinstance(f, http.HTTPFlow) and len(self.options.mode) == 1 and self.options.mode[0].startswith("reverse:"): + # When we load flows in reverse proxy mode, we adjust the target host to + # the reverse proxy destination for all flows we load. This makes it very + # easy to replay saved flows against a different host. + # We may change this in the future so that clientplayback always replays to the first mode. + mode = ReverseMode.parse(self.options.mode[0]) + f.request.host, f.request.port, *_ = mode.address + f.request.scheme = mode.scheme for e in eventsequence.iterate(f): await self.addons.handle_lifecycle(e) diff --git a/mitmproxy/net/server_spec.py b/mitmproxy/net/server_spec.py index ee95ad6471..376b13aaa2 100644 --- a/mitmproxy/net/server_spec.py +++ b/mitmproxy/net/server_spec.py @@ -1,17 +1,16 @@ """ Server specs are used to describe an upstream proxy or server. """ -import functools import re -from typing import Literal, NamedTuple +from functools import cache +from typing import Literal from mitmproxy.net import check - -class ServerSpec(NamedTuple): - scheme: Literal["http", "https"] - address: tuple[str, int] - +ServerSpec = tuple[ + Literal["http", "https", "tcp", "tls", "dns"], + tuple[str, int] +] server_spec_re = re.compile( r""" @@ -26,8 +25,8 @@ class ServerSpec(NamedTuple): ) -@functools.lru_cache -def parse(server_spec: str) -> ServerSpec: +@cache +def parse(server_spec: str, default_scheme: str) -> ServerSpec: """ Parses a server mode specification, e.g.: @@ -45,8 +44,8 @@ def parse(server_spec: str) -> ServerSpec: if m.group("scheme"): scheme = m.group("scheme") else: - scheme = "https" if m.group("port") in ("443", None) else "http" - if scheme not in ("http", "https"): + scheme = default_scheme + if scheme not in ("tcp", "tls", "dns", "http", "https"): raise ValueError(f"Invalid server scheme: {scheme}") host = m.group("host") @@ -59,19 +58,15 @@ def parse(server_spec: str) -> ServerSpec: if m.group("port"): port = int(m.group("port")) else: - port = {"http": 80, "https": 443}[scheme] + try: + port = { + "http": 80, + "https": 443, + "dns": 53, + }[scheme] + except KeyError: + raise ValueError(f"Port specification missing.") if not check.is_valid_port(port): raise ValueError(f"Invalid port: {port}") - return ServerSpec(scheme, (host, port)) # type: ignore - - -def parse_with_mode(mode: str) -> tuple[str, ServerSpec]: - """ - Parse a proxy mode specification, which is usually just `(reverse|upstream):server-spec`. - - *Raises:* - - ValueError, if the specification is invalid. - """ - mode, server_spec = mode.split(":", maxsplit=1) - return mode, parse(server_spec) + return scheme, (host, port) # type: ignore diff --git a/mitmproxy/net/udp.py b/mitmproxy/net/udp.py index c70647800e..90fd02aeaf 100644 --- a/mitmproxy/net/udp.py +++ b/mitmproxy/net/udp.py @@ -76,8 +76,13 @@ def recvfrom( ) -> tuple[bytes, tuple[SockAddress, SockAddress]]: """Same as recvfrom, but always returns source and destination addresses.""" + # (unavailable on Windows, hence the type checking exclusion) + space = socket.CMSG_SPACE(1024) # type: ignore + data, ancdata, _, client_addr = self._recvmsg( - bufsize, socket.CMSG_SPACE(1024), flags + bufsize, + space, + flags ) for cmsg_level, cmsg_type, cmsg_data in ancdata: if ( @@ -87,7 +92,7 @@ def recvfrom( server_addr = TransparentSocket._unpack_addr(cmsg_data) break else: - raise OSError("recvmsg did not return th original destination address") + raise OSError("recvmsg did not return the original destination address") return data, (client_addr, server_addr) diff --git a/mitmproxy/options.py b/mitmproxy/options.py index c9aa9c3335..c8892098df 100644 --- a/mitmproxy/options.py +++ b/mitmproxy/options.py @@ -5,7 +5,6 @@ CONF_DIR = "~/.mitmproxy" CONF_BASENAME = "mitmproxy" -LISTEN_PORT = 8080 CONTENT_VIEW_LINES_CUTOFF = 512 KEY_SIZE = 2048 @@ -91,16 +90,25 @@ def __init__(self, **kwargs) -> None: """, ) self.add_option("allow_hosts", Sequence[str], [], "Opposite of --ignore-hosts.") - self.add_option("listen_host", str, "", "Address to bind proxy to.") - self.add_option("listen_port", int, LISTEN_PORT, "Proxy service port.") + self.add_option("listen_host", str, "", + "Address to bind proxy server(s) to (may be overridden for individual modes, see `mode`).") + self.add_option("listen_port", Optional[int], None, + "Port to bind proxy server(s) to (may be overridden for individual modes, see `mode`). " + "By default, the port is mode-specific. The default regular HTTP proxy spawns on port 8080.") self.add_option( "mode", - str, - "regular", + Sequence[str], + ["regular"], """ - Mode can be "regular", "transparent", "socks5", "reverse:SPEC", - or "upstream:SPEC". For reverse and upstream proxy modes, SPEC + The proxy server type(s) to spawn. Can be passed multiple times. + + Mitmproxy supports "regular" (HTTP), "transparent", "socks5", "reverse:SPEC", + and "upstream:SPEC" proxy servers. For reverse and upstream proxy modes, SPEC is host specification in the form of "http[s]://host[:port]". + + You may append `@listen_port` or `@listen_host:listen_port` to override `listen_host` or `listen_port` for + a specific proxy mode. Features such as client playback will use the first mode to determine + which upstream server to use. """, ) self.add_option( diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 08d6d7fb3b..4031006a01 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -2,6 +2,7 @@ import enum import time from dataclasses import dataclass +from functools import cached_property from typing import Optional, Union import wsproto.handshake @@ -42,6 +43,7 @@ from ._http1 import Http1Client, Http1Connection, Http1Server from ._http2 import Http2Client, Http2Server from ...context import Context +from ...mode_specs import ReverseMode, UpstreamMode class HTTPMode(enum.Enum): @@ -124,7 +126,7 @@ class HttpStream(layer.Layer): stream_id: StreamId child_layer: Optional[layer.Layer] = None - @property + @cached_property def mode(self): i = self.context.layers.index(self) parent: HttpLayer = self.context.layers[i - 1] @@ -222,7 +224,7 @@ def state_wait_for_request_headers( # update host header in reverse proxy mode if ( - self.context.options.mode.startswith("reverse:") + isinstance(self.context.client.proxy_mode, ReverseMode) and not self.context.options.keep_host_header ): assert self.context.server.address @@ -835,9 +837,9 @@ def _handle_event(self, event: events.Event): if isinstance(event, events.Start): yield from self.event_to_child(self.connections[self.context.client], event) if self.mode is HTTPMode.upstream: - self.context.server.via = server_spec.parse_with_mode( - self.context.options.mode - )[1] + proxy_mode = self.context.client.proxy_mode + assert isinstance(proxy_mode, UpstreamMode) + self.context.server.via = (proxy_mode.scheme, proxy_mode.address) elif isinstance(event, events.Wakeup): stream = self.command_sources.pop(event.command) yield from self.event_to_child(stream, event) @@ -996,7 +998,6 @@ def get_connection( if event.via: context.server.via = event.via - assert event.via.scheme in ("http", "https") # We always send a CONNECT request, *except* for plaintext absolute-form HTTP requests in upstream mode. send_connect = event.tls or self.mode != HTTPMode.upstream stack /= _upstream_proxy.HttpUpstreamProxy.make(context, send_connect) diff --git a/mitmproxy/proxy/layers/http/_upstream_proxy.py b/mitmproxy/proxy/layers/http/_upstream_proxy.py index cdbde2d67d..02b26c4bf8 100644 --- a/mitmproxy/proxy/layers/http/_upstream_proxy.py +++ b/mitmproxy/proxy/layers/http/_upstream_proxy.py @@ -26,16 +26,16 @@ def __init__( @classmethod def make(cls, ctx: context.Context, send_connect: bool) -> tunnel.LayerStack: - spec = ctx.server.via - assert spec - assert spec.scheme in ("http", "https") + assert ctx.server.via + scheme, address = ctx.server.via + assert scheme in ("http", "https") - http_proxy = connection.Server(spec.address) + http_proxy = connection.Server(address) stack = tunnel.LayerStack() - if spec.scheme == "https": + if scheme == "https": http_proxy.alpn_offers = tls.HTTP1_ALPNS - http_proxy.sni = spec.address[0] + http_proxy.sni = address[0] stack /= tls.ServerTLSLayer(ctx, http_proxy) stack /= cls(ctx, http_proxy, send_connect) diff --git a/mitmproxy/proxy/layers/modes.py b/mitmproxy/proxy/layers/modes.py index cf9bf960a5..3e4024dca6 100644 --- a/mitmproxy/proxy/layers/modes.py +++ b/mitmproxy/proxy/layers/modes.py @@ -5,10 +5,10 @@ from typing import Optional from mitmproxy import connection, platform -from mitmproxy.net import server_spec from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.layers import tls +from mitmproxy.proxy.mode_specs import ReverseMode from mitmproxy.proxy.utils import expect @@ -20,6 +20,14 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: yield from child_layer.handle_event(event) +class HttpUpstreamProxy(layer.Layer): + @expect(events.Start) + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + child_layer = layer.NextLayer(self.context) + self._handle_event = child_layer.handle_event + yield from child_layer.handle_event(event) + + class DestinationKnown(layer.Layer, metaclass=ABCMeta): """Base layer for layers that gather connection destination info and then delegate.""" @@ -47,7 +55,8 @@ def done(self, _) -> layer.CommandGenerator[None]: class ReverseProxy(DestinationKnown): @expect(events.Start) def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: - spec = server_spec.parse_with_mode(self.context.options.mode)[1] + spec = self.context.client.proxy_mode + assert isinstance(spec, ReverseMode) self.context.server.address = spec.address if spec.scheme not in ("http", "tcp"): diff --git a/mitmproxy/proxy/mode_servers.py b/mitmproxy/proxy/mode_servers.py new file mode 100644 index 0000000000..3fc8de6683 --- /dev/null +++ b/mitmproxy/proxy/mode_servers.py @@ -0,0 +1,260 @@ +""" +This module defines "server instances", which manage +the TCP/UDP servers spawned my mitmproxy as specified by the proxy mode. + +Example: + + mode = ProxyMode.parse("reverse:https://example.com") + inst = ServerInstance.make(mode, manager_that_handles_callbacks) + await inst.start() + # TCP server is running now. +""" +from __future__ import annotations + +import asyncio +import struct +import typing +from abc import ABCMeta, abstractmethod +from contextlib import contextmanager +from functools import cached_property +from typing import ClassVar, Generic, TypeVar, cast, get_args + +from mitmproxy import ctx, flow, log +from mitmproxy.connection import Address +from mitmproxy.master import Master +from mitmproxy.net import udp +from mitmproxy.proxy import commands, layers, mode_specs, server +from mitmproxy.proxy.context import Context +from mitmproxy.proxy.layer import Layer +from mitmproxy.utils import asyncio_utils, human + + +class ProxyConnectionHandler(server.LiveConnectionHandler): + master: Master + + def __init__(self, master, r, w, options, mode): + self.master = master + super().__init__(r, w, options, mode) + self.log_prefix = f"{human.format_address(self.client.peername)}: " + + async def handle_hook(self, hook: commands.StartHook) -> None: + with self.timeout_watchdog.disarm(): + # We currently only support single-argument hooks. + (data,) = hook.args() + await self.master.addons.handle_lifecycle(hook) + if isinstance(data, flow.Flow): + await data.wait_for_resume() # pragma: no cover + + def log(self, message: str, level: str = "info") -> None: + x = log.LogEntry(self.log_prefix + message, level) + asyncio_utils.create_task( + self.master.addons.handle_lifecycle(log.AddLogHook(x)), + name="ProxyConnectionHandler.log", + ) + + +M = TypeVar('M', bound=mode_specs.ProxyMode) + + +class ServerManager(typing.Protocol): + connections: dict[tuple, ProxyConnectionHandler] + + @contextmanager + def register_connection(self, connection_id: tuple, handler: ProxyConnectionHandler): + ... # pragma: no cover + + +class ServerInstance(Generic[M], metaclass=ABCMeta): + + __modes: ClassVar[dict[str, type[ServerInstance]]] = {} + + def __init__(self, mode: M, manager: ServerManager): + self.mode: M = mode + self.manager: ServerManager = manager + + def __init_subclass__(cls, **kwargs): + """Register all subclasses so that make() finds them.""" + # extract mode from Generic[Mode]. + mode = get_args(cls.__orig_bases__[0])[0] + if mode != M: + assert mode.type not in ServerInstance.__modes + ServerInstance.__modes[mode.type] = cls + + @staticmethod + def make( + mode: mode_specs.ProxyMode | str, + manager: ServerManager, + ) -> ServerInstance: + if isinstance(mode, str): + mode = mode_specs.ProxyMode.parse(mode) + return ServerInstance.__modes[mode.type](mode, manager) + + @abstractmethod + async def start(self) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass + + @property + @abstractmethod + def listen_addrs(self) -> tuple[Address, ...]: + pass + + +class TcpServerInstance(ServerInstance[M], metaclass=ABCMeta): + server: asyncio.Server | None = None + + @abstractmethod + def make_top_layer(self, context: Context) -> Layer: + pass + + async def handle_tcp_connection( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + connection_id = ( + "tcp", + writer.get_extra_info("peername"), + writer.get_extra_info("sockname"), + ) + handler = ProxyConnectionHandler( + ctx.master, reader, writer, ctx.options, self.mode + ) + handler.layer = self.make_top_layer(handler.layer.context) + with self.manager.register_connection(connection_id, handler): + await handler.handle_client() + + async def start(self): + assert not self.server + self.server = await asyncio.start_server( + self.handle_tcp_connection, + self.mode.listen_host(ctx.options.listen_host), + self.mode.listen_port(ctx.options.listen_port), + ) + + addrs = {f"{human.format_address(s)}" for s in self.listen_addrs} + ctx.log.info( + f"{self.log_desc} listening at {' and '.join(addrs)}." + ) + + @property + @abstractmethod + def log_desc(self) -> str: + pass + + async def stop(self): + assert self.server + self.server.close() + await self.server.wait_closed() + ctx.log.info(f"Stopped {self.mode.type} proxy server.") + + @cached_property + def listen_addrs(self) -> tuple[Address, ...]: + assert self.server + return tuple(s.getsockname() for s in self.server.sockets) + + +class RegularInstance(TcpServerInstance[mode_specs.RegularMode]): + log_desc = "HTTP(S) proxy" + + def make_top_layer(self, context: Context) -> Layer: + return layers.modes.HttpProxy(context) + + +class UpstreamInstance(TcpServerInstance[mode_specs.UpstreamMode]): + log_desc = "HTTP(S) proxy (upstream mode)" + + def make_top_layer(self, context: Context) -> Layer: + return layers.modes.HttpUpstreamProxy(context) + + +class TransparentInstance(TcpServerInstance[mode_specs.TransparentMode]): + log_desc = "Transparent proxy" + + def make_top_layer(self, context: Context) -> Layer: + return layers.modes.TransparentProxy(context) + + +class ReverseInstance(TcpServerInstance[mode_specs.ReverseMode]): + @property + def log_desc(self) -> str: + return f"Reverse proxy to {self.mode.data}" + + def make_top_layer(self, context: Context) -> Layer: + return layers.modes.ReverseProxy(context) + + +class Socks5Instance(TcpServerInstance[mode_specs.Socks5Mode]): + log_desc = "SOCKS v5 proxy" + + def make_top_layer(self, context: Context) -> Layer: + return layers.modes.Socks5Proxy(context) + + +class DnsInstance(ServerInstance[mode_specs.DnsMode]): + server: udp.UdpServer | None = None + + async def start(self): + assert not self.server + self.server = await udp.start_server( + self.handle_dns_datagram, + self.mode.listen_host(ctx.options.listen_host), + self.mode.listen_port(ctx.options.listen_port), + transparent=False + ) + addrs = {f"{human.format_address(s)}" for s in self.listen_addrs} + ctx.log.info( + f"DNS server listening at {' and '.join(addrs)}." + ) + + async def stop(self): + assert self.server + self.server.close() + await self.server.wait_closed() + ctx.log.info(f"Stopped {self.mode.type} proxy server.") + + def handle_dns_datagram( + self, + transport: asyncio.DatagramTransport, + data: bytes, + remote_addr: Address, + local_addr: Address, + ) -> None: + try: + dns_id = struct.unpack_from("!H", data, 0) + except struct.error: + ctx.log.info( + f"Invalid DNS datagram received from {human.format_address(remote_addr)}." + ) + return + connection_id = ("udp", dns_id, remote_addr, local_addr) + if connection_id not in self.manager.connections: + reader = udp.DatagramReader() + writer = udp.DatagramWriter(transport, remote_addr, reader) + handler = ProxyConnectionHandler( + ctx.master, reader, writer, ctx.options, self.mode + ) + handler.timeout_watchdog.CONNECTION_TIMEOUT = 20 + handler.layer = layers.DNSLayer(handler.layer.context) + handler.layer.context.server.address = (self.mode.data or "resolve-local", 53) + handler.layer.context.server.transport_protocol = "udp" + + # pre-register here - we may get datagrams before the task is executed. + self.manager.connections[connection_id] = handler + asyncio.create_task(self.handle_dns_connection(connection_id, handler)) + else: + handler = self.manager.connections[connection_id] + reader = cast(udp.DatagramReader, handler.transports[handler.client].reader) + reader.feed_data(data, remote_addr) + + async def handle_dns_connection(self, connection_id, handler): + with self.manager.register_connection(connection_id, handler): + await handler.handle_client() + + @cached_property + def listen_addrs(self) -> tuple[Address, ...]: + assert self.server + return tuple(s.getsockname() for s in self.server.sockets) diff --git a/mitmproxy/proxy/mode_specs.py b/mitmproxy/proxy/mode_specs.py new file mode 100644 index 0000000000..1173d321c6 --- /dev/null +++ b/mitmproxy/proxy/mode_specs.py @@ -0,0 +1,202 @@ +""" +This module is responsible for parsing proxy mode specifications such as +`"regular"`, `"reverse:https://example.com"`, or `"socks5@1234"`. The general syntax is + + mode [: mode_configuration] [@ [listen_addr:]listen_port] + +For a full example, consider `reverse:https://example.com@127.0.0.1:443`. +This would spawn a reverse proxy on port 443 bound to localhost. +The mode is `reverse`, and the mode data is `https://example.com`. +Examples: + + mode = ProxyMode.parse("regular@1234") + assert mode.listen_port == 1234 + assert isinstance(mode, RegularMode) + + ProxyMode.parse("reverse:example.com@invalid-port") # ValueError + +""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from functools import cache +from typing import ClassVar, Literal, Type, TypeVar + +from mitmproxy.coretypes.serializable import Serializable +from mitmproxy.net import server_spec + + +# Python 3.11: Use typing.Self +Self = TypeVar("Self", bound="ProxyMode") + + +@dataclass(frozen=True) # type: ignore +class ProxyMode(Serializable, metaclass=ABCMeta): + """ + Parsed representation of a proxy mode spec. Subclassed for each specific mode, + which then does its own data validation. + """ + full_spec: str + data: str + custom_listen_host: str | None + custom_listen_port: int | None + + transport_protocol: ClassVar[Literal["tcp", "udp"]] = "tcp" + """ + The transport protocol used by this mode. Used to detect multiple servers targeting the same proto+port. + """ + default_port: ClassVar[int] = 8080 + __modes: ClassVar[dict[str, type[ProxyMode]]] = {} + + @abstractmethod + def __post_init__(self) -> None: + """Validation of data happens here.""" + + def listen_host(self, default: str | None = None) -> str: + if self.custom_listen_host is not None: + return self.custom_listen_host + elif default is not None: + return default + else: + return "" + + def listen_port(self, default: int | None = None) -> int: + if self.custom_listen_port is not None: + return self.custom_listen_port + elif default is not None: + return default + else: + return self.default_port + + @classmethod + @property + def type(cls) -> str: + return cls.__name__.removesuffix("Mode").lower() + + @classmethod + @cache + def parse(cls: Type[Self], spec: str) -> Self: + head, _, listen_at = spec.rpartition("@") + if not head: + head = listen_at + listen_at = "" + + mode, _, data = head.partition(":") + + if listen_at: + if ":" in listen_at: + host, _, port_str = listen_at.rpartition(":") + else: + host = None + port_str = listen_at + try: + port = int(port_str) + if port < 0 or 65535 < port: + raise ValueError + except ValueError: + raise ValueError(f"invalid port: {port_str}") + else: + host = None + port = None + + try: + mode_cls = ProxyMode.__modes[mode.lower()] + except KeyError: + raise ValueError(f"unknown mode") + + if not issubclass(mode_cls, cls): + raise ValueError(f"{mode!r} is not a spec for a {cls.type} mode") + + return mode_cls( + full_spec=spec, + data=data, + custom_listen_host=host, + custom_listen_port=port + ) + + def __init_subclass__(cls, **kwargs): + t = cls.type.lower() + assert t not in ProxyMode.__modes + ProxyMode.__modes[t] = cls + + @classmethod + def from_state(cls, state): + return ProxyMode.parse(state) + + def get_state(self): + return self.full_spec + + def set_state(self, state): + if state != self.full_spec: + raise RuntimeError("Proxy modes are frozen.") + + +def _check_empty(data): + if data: + raise ValueError("mode takes no arguments") + + +class RegularMode(ProxyMode): + def __post_init__(self) -> None: + _check_empty(self.data) + + +class TransparentMode(ProxyMode): + def __post_init__(self) -> None: + _check_empty(self.data) + + +class UpstreamMode(ProxyMode): + scheme: Literal["http", "https"] + address: tuple[str, int] + + # noinspection PyDataclass + def __post_init__(self) -> None: + scheme, self.address = server_spec.parse(self.data, default_scheme="http") + if scheme != "http" and scheme != "https": + raise ValueError("invalid upstream proxy scheme") + self.scheme = scheme + + +class ReverseMode(ProxyMode): + scheme: Literal["http", "https", "tcp", "tls"] + address: tuple[str, int] + + # noinspection PyDataclass + def __post_init__(self) -> None: + scheme, self.address = server_spec.parse(self.data, default_scheme="https") + if scheme != "http" and scheme != "https" and scheme != "tcp" and scheme != "tls": + raise ValueError("invalid reverse proxy scheme") + self.scheme = scheme + + +class Socks5Mode(ProxyMode): + default_port = 1080 + + def __post_init__(self) -> None: + _check_empty(self.data) + + +class DnsMode(ProxyMode): + default_port = 53 + transport_protocol: ClassVar[Literal["tcp", "udp"]] = "udp" + scheme: Literal["dns"] # DoH, DoQ, ... + address: tuple[str, int] | None = None + + # noinspection PyDataclass + def __post_init__(self) -> None: + if self.data in ["", "resolve-local", "transparent"]: + return + m, _, server = self.data.partition(":") + if m != "reverse": + raise ValueError("invalid dns mode") + scheme, self.address = server_spec.parse(server, "dns") + if scheme != "dns": + raise ValueError("invalid dns scheme") + self.scheme = scheme + + @property + def resolve_local(self) -> bool: + return self.data in ["", "resolve-local"] diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 20fd4233bf..7c1beed9e7 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -20,7 +20,7 @@ from mitmproxy import http, options as moptions, tls from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy import commands, events, layer, layers, server_hooks +from mitmproxy.proxy import commands, events, layer, layers, mode_specs, server_hooks from mitmproxy.connection import Address, Client, Connection, ConnectionState from mitmproxy.net import udp from mitmproxy.utils import asyncio_utils @@ -406,11 +406,13 @@ def __init__( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options, + mode: mode_specs.ProxyMode, ) -> None: client = Client( writer.get_extra_info("peername"), writer.get_extra_info("sockname"), time.time(), + proxy_mode=mode, ) context = Context(client, options) super().__init__(context) @@ -424,8 +426,8 @@ class SimpleConnectionHandler(LiveConnectionHandler): # pragma: no cover hook_handlers: dict[str, Callable] - def __init__(self, reader, writer, options, hooks): - super().__init__(reader, writer, options) + def __init__(self, reader, writer, options, mode, hooks): + super().__init__(reader, writer, options, mode) self.hook_handlers = hooks async def handle_hook(self, hook: commands.StartHook) -> None: @@ -459,7 +461,6 @@ def log(self, message: str, level: str = "info"): to the reverse proxy target. """, ) - opts.mode = "reverse:http://127.0.0.1:3000/" async def handle(reader, writer): layer_stack = [ @@ -517,6 +518,7 @@ def tls_start_server(tls_start: tls.TlsData): reader, writer, opts, + mode_specs.ProxyMode.parse("reverse:http://127.0.0.1:3000/"), { "next_layer": next_layer, "request": request, diff --git a/mitmproxy/test/taddons.py b/mitmproxy/test/taddons.py index a52bb5eef5..5ea919593d 100644 --- a/mitmproxy/test/taddons.py +++ b/mitmproxy/test/taddons.py @@ -3,20 +3,18 @@ import mitmproxy.master import mitmproxy.options -from mitmproxy import addonmanager, hooks, log +from mitmproxy import hooks, log from mitmproxy import command from mitmproxy import eventsequence from mitmproxy.addons import script, core -class TestAddons(addonmanager.AddonManager): +class LogRecorder: def __init__(self, master): - super().__init__(master) + self.master: RecordingMaster = master - def trigger(self, event: hooks.Hook): - if isinstance(event, log.AddLogHook): - self.master.logs.append(event.entry) - super().trigger(event) + def add_log(self, entry: log.LogEntry): + self.master.logs.append(entry) class RecordingMaster(mitmproxy.master.Master): @@ -26,7 +24,7 @@ def __init__(self, *args, **kwargs): except RuntimeError: loop = asyncio.new_event_loop() super().__init__(*args, **kwargs, event_loop=loop) - self.addons = TestAddons(self) + self.addons.add(LogRecorder(self)) self.logs = [] def dump_log(self, outf=sys.stdout): diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index 571a22375d..def876b1f1 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -7,6 +7,7 @@ from mitmproxy import http from mitmproxy import tcp from mitmproxy import websocket +from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.test.tutils import tdnsreq, tdnsresp from mitmproxy.test.tutils import treq, tresp from wsproto.frame_protocol import Opcode @@ -103,6 +104,7 @@ def tdnsflow( """Create a DNS flow for testing.""" if client_conn is None: client_conn = tclient_conn() + client_conn.proxy_mode = ProxyMode.parse("dns") client_conn.transport_protocol = "udp" if server_conn is None: server_conn = tserver_conn() @@ -206,6 +208,7 @@ def tclient_conn() -> connection.Client: certificate_list=[], alpn_offers=[], cipher_list=[], + proxy_mode="regular", ) ) return c diff --git a/mitmproxy/tools/cmdline.py b/mitmproxy/tools/cmdline.py index 7f902a88a8..15353332b5 100644 --- a/mitmproxy/tools/cmdline.py +++ b/mitmproxy/tools/cmdline.py @@ -74,7 +74,6 @@ def common_options(parser, opts): opts.make_parser(group, "certs", metavar="SPEC") opts.make_parser(group, "cert_passphrase", metavar="PASS") opts.make_parser(group, "ssl_insecure", short="k") - opts.make_parser(group, "key_size", metavar="KEY_SIZE") # Client replay group = parser.add_argument_group("Client Replay") @@ -141,7 +140,6 @@ def mitmweb(opts): opts.make_parser(group, "web_open_browser") opts.make_parser(group, "web_port", metavar="PORT") opts.make_parser(group, "web_host", metavar="HOST") - opts.make_parser(group, "web_columns") common_options(parser, opts) group = parser.add_argument_group( diff --git a/mitmproxy/tools/console/statusbar.py b/mitmproxy/tools/console/statusbar.py index 8ab3a13392..65a2a4e1ac 100644 --- a/mitmproxy/tools/console/statusbar.py +++ b/mitmproxy/tools/console/statusbar.py @@ -8,6 +8,7 @@ from mitmproxy.tools.console import common from mitmproxy.tools.console import signals from mitmproxy.tools.console.commander import commander +from mitmproxy.utils import human class PromptPath: @@ -281,8 +282,8 @@ def get_status(self): if opts: r.append("[%s]" % (":".join(opts))) - if self.master.options.mode != "regular": - r.append("[%s]" % self.master.options.mode) + if self.master.options.mode != ["regular"]: + r.append(f"[{','.join(self.master.options.mode)}]") if self.master.options.scripts: r.append("[scripts:%s]" % len(self.master.options.scripts)) @@ -311,11 +312,12 @@ def redraw(self): ("heading", (f"{arrow} {marked} [{offset}/{fc}]").ljust(11)), ] - if self.master.options.server: - host = self.master.options.listen_host - if host == "0.0.0.0" or host == "": - host = "*" - boundaddr = f"[{host}:{self.master.options.listen_port}]" + listen_addrs: list[str] = list(dict.fromkeys( + human.format_address(a) + for a in self.master.addons.get("proxyserver").listen_addrs() + )) + if listen_addrs: + boundaddr = f"[{', '.join(listen_addrs)}]" else: boundaddr = "" t.extend(self.get_status()) diff --git a/mitmproxy/tools/main.py b/mitmproxy/tools/main.py index 76bd4389be..b58a501654 100644 --- a/mitmproxy/tools/main.py +++ b/mitmproxy/tools/main.py @@ -26,11 +26,12 @@ def process_options(parser, opts, args): args.termlog_verbosity = "debug" args.flow_detail = 2 - adict = {} - for n in dir(args): - if n in opts: - adict[n] = getattr(args, n) - opts.merge(adict) + adict = { + key: val + for key, val in vars(args).items() + if key in opts and val is not None + } + opts.update(**adict) T = TypeVar("T", bound=master.Master) diff --git a/mitmproxy/version.py b/mitmproxy/version.py index b87cabc790..6bcb1dcf08 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -7,7 +7,7 @@ # Serialization format version. This is displayed nowhere, it just needs to be incremented by one # for each change in the file format. -FLOW_FORMAT_VERSION = 17 +FLOW_FORMAT_VERSION = 18 def get_dev_version() -> str: diff --git a/test/mitmproxy/addons/test_asgiapp.py b/test/mitmproxy/addons/test_asgiapp.py index d282b33368..031dea2591 100644 --- a/test/mitmproxy/addons/test_asgiapp.py +++ b/test/mitmproxy/addons/test_asgiapp.py @@ -55,8 +55,8 @@ async def test_asgi_full(): tctx.master.addons.add(next_layer.NextLayer()) tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) await ps.running() - await tctx.master.await_log("Proxy server listening", level="info") - proxy_addr = ps.tcp_server.sockets[0].getsockname()[:2] + await tctx.master.await_log("HTTP(S) proxy listening", level="info") + proxy_addr = ("127.0.0.1", ps.listen_addrs()[0][1]) reader, writer = await asyncio.open_connection(*proxy_addr) req = f"GET http://testapp:80/ HTTP/1.1\r\n\r\n" diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 2b93d7aff7..b527a2e7a9 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -52,7 +52,7 @@ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): flow = tflow.tflow(live=False) flow.request.content = b"data" if mode == "upstream": - tctx.options.mode = f"upstream:http://{addr[0]}:{addr[1]}" + tctx.options.mode = [f"upstream:http://{addr[0]}:{addr[1]}"] flow.request.authority = f"{addr[0]}:{addr[1]}" flow.request.host, flow.request.port = "address", 22 else: @@ -86,7 +86,7 @@ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): flow = tflow.tflow(live=False) flow.request.scheme = b"https" flow.request.content = b"data" - tctx.options.mode = f"upstream:http://{addr[0]}:{addr[1]}" + tctx.options.mode = [f"upstream:http://{addr[0]}:{addr[1]}"] cp.start_replay([flow]) assert cp.count() == 1 await asyncio.wait_for(cp.queue.join(), 5) diff --git a/test/mitmproxy/addons/test_core.py b/test/mitmproxy/addons/test_core.py index 707d775f29..fc219769d3 100644 --- a/test/mitmproxy/addons/test_core.py +++ b/test/mitmproxy/addons/test_core.py @@ -1,5 +1,3 @@ -from unittest import mock - from mitmproxy.addons import core from mitmproxy.test import taddons from mitmproxy.test import tflow @@ -10,9 +8,9 @@ def test_set(): sa = core.Core() with taddons.context(loadcore=False) as tctx: - assert tctx.master.options.server - tctx.command(sa.set, "server", "false") - assert not tctx.master.options.server + assert tctx.master.options.upstream_cert + tctx.command(sa.set, "upstream_cert", "false") + assert not tctx.master.options.upstream_cert with pytest.raises(exceptions.CommandError): tctx.command(sa.set, "nonexistent") @@ -167,25 +165,6 @@ def test_validation_simple(): tctx.configure( sa, add_upstream_certs_to_client_chain=True, upstream_cert=False ) - with pytest.raises(exceptions.OptionsError, match="Invalid mode"): - tctx.configure(sa, mode="Flibble") - - -@mock.patch("mitmproxy.platform.original_addr", None) -def test_validation_no_transparent(): - sa = core.Core() - with taddons.context() as tctx: - with pytest.raises(Exception, match="Transparent mode not supported"): - tctx.configure(sa, mode="transparent") - - -@mock.patch("mitmproxy.platform.original_addr") -def test_validation_modes(m): - sa = core.Core() - with taddons.context() as tctx: - tctx.configure(sa, mode="reverse:http://localhost") - with pytest.raises(Exception, match="Invalid server specification"): - tctx.configure(sa, mode="reverse:") def test_client_certs(tdata): diff --git a/test/mitmproxy/addons/test_dns_resolver.py b/test/mitmproxy/addons/test_dns_resolver.py index a4b959eb9c..41ef67f82d 100644 --- a/test/mitmproxy/addons/test_dns_resolver.py +++ b/test/mitmproxy/addons/test_dns_resolver.py @@ -8,6 +8,7 @@ from mitmproxy import dns from mitmproxy.addons import dns_resolver, proxyserver from mitmproxy.connection import Address +from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.test import taddons, tflow, tutils @@ -17,13 +18,13 @@ async def test_simple(monkeypatch): ) dr = dns_resolver.DnsResolver() - with taddons.context(dr, proxyserver.Proxyserver()) as tctx: + with taddons.context(dr, proxyserver.Proxyserver()): f = tflow.tdnsflow() await dr.dns_request(f) assert f.response - tctx.options.dns_mode = "reverse:8.8.8.8" f = tflow.tdnsflow() + f.client_conn.proxy_mode = ProxyMode.parse("dns:reverse:8.8.8.8") await dr.dns_request(f) assert not f.response diff --git a/test/mitmproxy/addons/test_next_layer.py b/test/mitmproxy/addons/test_next_layer.py index 534fbd6fd6..b1dbd57b4b 100644 --- a/test/mitmproxy/addons/test_next_layer.py +++ b/test/mitmproxy/addons/test_next_layer.py @@ -5,7 +5,7 @@ from mitmproxy import connection from mitmproxy.addons.next_layer import NextLayer from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy import context, layers +from mitmproxy.proxy import context, layer, layers from mitmproxy.test import taddons @@ -76,30 +76,23 @@ def test_ignore_connection(self): is False ) - def test_make_top_layer(self): - nl = NextLayer() + def test_next_layer(self, monkeypatch): ctx = MagicMock() - with taddons.context(nl) as tctx: - tctx.configure(nl, mode="regular") - assert isinstance(nl.make_top_layer(ctx), layers.modes.HttpProxy) - - tctx.configure(nl, mode="transparent") - assert isinstance(nl.make_top_layer(ctx), layers.modes.TransparentProxy) - - tctx.configure(nl, mode="reverse:http://example.com") - assert isinstance(nl.make_top_layer(ctx), layers.modes.ReverseProxy) + nl_layer = layer.NextLayer(ctx) + monkeypatch.setattr(nl_layer, "data_client", lambda: b"\x16\x03\x03") + nl = NextLayer() - tctx.configure(nl, mode="socks5") - assert isinstance(nl.make_top_layer(ctx), layers.modes.Socks5Proxy) + with taddons.context(nl): + nl.next_layer(nl_layer) + assert nl_layer.layer - def test_next_layer(self): + def test_next_layer2(self): nl = NextLayer() ctx = MagicMock() ctx.client.alpn = None ctx.server.address = ("example.com", 443) with taddons.context(nl) as tctx: - ctx.layers = [] - assert isinstance(nl._next_layer(ctx, b"", b""), layers.modes.HttpProxy) + ctx.layers = [layers.modes.HttpProxy(ctx)] assert nl._next_layer(ctx, b"", b"") is None @@ -114,24 +107,20 @@ def test_next_layer(self): ) assert isinstance(ctx.layers[-1], layers.ClientTLSLayer) - ctx.layers = [] - assert isinstance(nl._next_layer(ctx, b"", b""), layers.modes.HttpProxy) + ctx.layers = [layers.modes.HttpProxy(ctx)] assert isinstance( nl._next_layer(ctx, client_hello_no_extensions, b""), layers.ClientTLSLayer, ) - ctx.layers = [] - assert isinstance(nl._next_layer(ctx, b"", b""), layers.modes.HttpProxy) + ctx.layers = [layers.modes.HttpProxy(ctx)] assert isinstance( nl._next_layer(ctx, b"GET http://example.com/ HTTP/1.1\r\n", b""), layers.HttpLayer, ) assert ctx.layers[-1].mode == HTTPMode.regular - ctx.layers = [] - tctx.configure(nl, mode="upstream:http://localhost:8081") - assert isinstance(nl._next_layer(ctx, b"", b""), layers.modes.HttpProxy) + ctx.layers = [layers.modes.HttpUpstreamProxy(ctx)] assert isinstance( nl._next_layer(ctx, b"GET http://example.com/ HTTP/1.1\r\n", b""), layers.HttpLayer, @@ -144,8 +133,3 @@ def test_next_layer(self): tctx.configure(nl, tcp_hosts=[]) assert isinstance(nl._next_layer(ctx, b"GET /foo", b""), layers.HttpLayer) assert isinstance(nl._next_layer(ctx, b"", b"hello"), layers.TCPLayer) - - l = MagicMock() - l.layer = None - nl.next_layer(l) - assert isinstance(l.layer, layers.modes.HttpProxy) diff --git a/test/mitmproxy/addons/test_proxyauth.py b/test/mitmproxy/addons/test_proxyauth.py index a27ccd4671..9bdb5b7ce3 100644 --- a/test/mitmproxy/addons/test_proxyauth.py +++ b/test/mitmproxy/addons/test_proxyauth.py @@ -7,101 +7,87 @@ from mitmproxy import exceptions from mitmproxy.addons import proxyauth from mitmproxy.proxy.layers import modes +from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.test import taddons from mitmproxy.test import tflow -class TestMkauth: - def test_mkauth_scheme(self): - assert ( - proxyauth.mkauth("username", "password") - == "basic dXNlcm5hbWU6cGFzc3dvcmQ=\n" - ) - - @pytest.mark.parametrize( - "scheme, expected", - [ - ("", " dXNlcm5hbWU6cGFzc3dvcmQ=\n"), - ("basic", "basic dXNlcm5hbWU6cGFzc3dvcmQ=\n"), - ("foobar", "foobar dXNlcm5hbWU6cGFzc3dvcmQ=\n"), - ], - ) - def test_mkauth(self, scheme, expected): - assert proxyauth.mkauth("username", "password", scheme) == expected - - -class TestParseHttpBasicAuth: - @pytest.mark.parametrize( - "input", - [ - "", - "foo bar", - "basic abc", - "basic " + binascii.b2a_base64(b"foo").decode("ascii"), - ], - ) - def test_parse_http_basic_auth_error(self, input): - with pytest.raises(ValueError): - proxyauth.parse_http_basic_auth(input) - - def test_parse_http_basic_auth(self): - input = proxyauth.mkauth("test", "test") - assert proxyauth.parse_http_basic_auth(input) == ("basic", "test", "test") +@pytest.mark.parametrize( + "scheme, expected", + [ + ("", " dXNlcm5hbWU6cGFzc3dvcmQ=\n"), + ("basic", "basic dXNlcm5hbWU6cGFzc3dvcmQ=\n"), + ("foobar", "foobar dXNlcm5hbWU6cGFzc3dvcmQ=\n"), + ], +) +def test_mkauth(scheme, expected): + assert proxyauth.mkauth("username", "password", scheme) == expected -class TestProxyAuth: - @pytest.mark.parametrize( - "mode, expected", - [ - ("", False), - ("foobar", False), - ("regular", True), - ("upstream:", True), - ("upstream:foobar", True), - ], - ) - def test_is_http_proxy(self, mode, expected): - up = proxyauth.ProxyAuth() - with taddons.context(up, loadcore=False) as ctx: - ctx.options.mode = mode - assert up.is_http_proxy is expected - - @pytest.mark.parametrize( - "is_http_proxy, expected", - [ - (True, "Proxy-Authorization"), - (False, "Authorization"), - ], - ) - def test_which_auth_header(self, is_http_proxy, expected): - up = proxyauth.ProxyAuth() - with mock.patch( - "mitmproxy.addons.proxyauth.ProxyAuth.is_http_proxy", new=is_http_proxy - ): - assert up.http_auth_header == expected - - @pytest.mark.parametrize( - "is_http_proxy, expected_status_code, expected_header", - [ - (True, 407, "Proxy-Authenticate"), - (False, 401, "WWW-Authenticate"), - ], - ) - def test_auth_required_response( - self, is_http_proxy, expected_status_code, expected_header - ): - up = proxyauth.ProxyAuth() - with mock.patch( - "mitmproxy.addons.proxyauth.ProxyAuth.is_http_proxy", new=is_http_proxy - ): - resp = up.make_auth_required_response() - assert resp.status_code == expected_status_code - assert expected_header in resp.headers.keys() +def test_parse_http_basic_auth(): + input = proxyauth.mkauth("test", "test") + assert proxyauth.parse_http_basic_auth(input) == ("basic", "test", "test") + +@pytest.mark.parametrize( + "input", + [ + "", + "foo bar", + "basic abc", + "basic " + binascii.b2a_base64(b"foo").decode("ascii"), + ], +) +def test_parse_http_basic_auth_error(input): + with pytest.raises(ValueError): + proxyauth.parse_http_basic_auth(input) + + +@pytest.mark.parametrize( + "mode, expected", + [ + ("regular", True), + ("upstream:proxy", True), + ("reverse:example.com", False), + ], +) +def test_is_http_proxy(mode, expected): + f = tflow.tflow() + f.client_conn.proxy_mode = ProxyMode.parse(mode) + assert proxyauth.is_http_proxy(f) == expected + + +@pytest.mark.parametrize( + "is_http_proxy, expected", + [ + (True, "Proxy-Authorization"), + (False, "Authorization"), + ], +) +def test_http_auth_header(is_http_proxy, expected): + assert proxyauth.http_auth_header(is_http_proxy) == expected + + +@pytest.mark.parametrize( + "is_http_proxy, expected_status_code, expected_header", + [ + (True, 407, "Proxy-Authenticate"), + (False, 401, "WWW-Authenticate"), + ], +) +def test_make_auth_required_response( + is_http_proxy, expected_status_code, expected_header +): + resp = proxyauth.make_auth_required_response(is_http_proxy) + assert resp.status_code == expected_status_code + assert expected_header in resp.headers.keys() + + +class TestProxyAuth: def test_socks5(self): pa = proxyauth.ProxyAuth() with taddons.context(pa, loadcore=False) as ctx: - ctx.configure(pa, proxyauth="foo:bar", mode="regular") + ctx.configure(pa, proxyauth="foo:bar") data = modes.Socks5AuthData(tflow.tclient_conn(), "foo", "baz") pa.socks5_auth(data) assert not data.valid @@ -112,9 +98,10 @@ def test_socks5(self): def test_authenticate(self): up = proxyauth.ProxyAuth() with taddons.context(up, loadcore=False) as ctx: - ctx.configure(up, proxyauth="any", mode="regular") + ctx.configure(up, proxyauth="any") f = tflow.tflow() + f.client_conn.proxy_mode = ProxyMode.parse("regular") assert not f.response up.authenticate_http(f) assert f.response.status_code == 407 @@ -126,12 +113,13 @@ def test_authenticate(self): assert not f.request.headers.get("Proxy-Authorization") f = tflow.tflow() - ctx.configure(up, mode="reverse") + f.client_conn.proxy_mode = ProxyMode.parse("reverse:https://example.com") assert not f.response up.authenticate_http(f) assert f.response.status_code == 401 f = tflow.tflow() + f.client_conn.proxy_mode = ProxyMode.parse("reverse:https://example.com") f.request.headers["Authorization"] = proxyauth.mkauth("test", "test") up.authenticate_http(f) assert not f.response @@ -210,16 +198,10 @@ def test_configure(self, monkeypatch, tdata): assert pa.validator("test", "test") assert not pa.validator("test", "foo") - with pytest.raises( - exceptions.OptionsError, - match="Proxy Authentication not supported in transparent mode.", - ): - ctx.configure(pa, proxyauth="any", mode="transparent") - def test_handlers(self): up = proxyauth.ProxyAuth() with taddons.context(up) as ctx: - ctx.configure(up, proxyauth="any", mode="regular") + ctx.configure(up, proxyauth="any") f = tflow.tflow() assert not f.response diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index 9abeb983ec..5b12cdd8ff 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -1,9 +1,11 @@ import asyncio from contextlib import asynccontextmanager import socket +from unittest.mock import Mock import pytest +import mitmproxy.platform from mitmproxy import dns, exceptions from mitmproxy.addons import dns_resolver from mitmproxy.addons.proxyserver import Proxyserver @@ -20,7 +22,6 @@ class HelperAddon: def __init__(self): self.flows = [] self.layers = [ - lambda ctx: layers.modes.HttpProxy(ctx), lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular), lambda ctx: layers.TCPLayer(ctx), ] @@ -60,12 +61,12 @@ async def server_handler( tctx.master.addons.add(state) async with tcp_server(server_handler) as addr: tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) - assert not ps.tcp_server + assert not ps.servers await ps.running() - await tctx.master.await_log("Proxy server listening", level="info") - assert ps.tcp_server + await tctx.master.await_log("HTTP(S) proxy listening", level="info") + assert ps.servers - proxy_addr = ps.tcp_server.sockets[0].getsockname()[:2] + proxy_addr = ps.listen_addrs()[0] reader, writer = await asyncio.open_connection(*proxy_addr) req = f"GET http://{addr[0]}:{addr[1]}/hello HTTP/1.1\r\n\r\n" writer.write(req.encode()) @@ -73,17 +74,20 @@ async def server_handler( await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n" ) - assert repr(ps) == "ProxyServer(running, 1 active conns)" + assert repr(ps) == "Proxyserver(1 active conns)" + await ps.setup_servers() # assert this can always be called without side effects tctx.configure(ps, server=False) - await tctx.master.await_log("Stopping Proxy server", level="info") - assert not ps.tcp_server + await tctx.master.await_log("Stopped regular proxy server.", level="info") + async with ps._lock: + pass # wait until start/stop is finished. + assert not ps.servers assert state.flows assert state.flows[0].request.path == "/hello" assert state.flows[0].response.status_code == 204 # Waiting here until everything is really torn down... takes some effort. - conn_handler = list(ps._connections.values())[0] + conn_handler = list(ps.connections.values())[0] client_handler = conn_handler.transports[conn_handler.client].handler writer.close() await writer.wait_closed() @@ -94,7 +98,7 @@ async def server_handler( for _ in range(5): # Get all other scheduled coroutines to run. await asyncio.sleep(0) - assert repr(ps) == "ProxyServer(stopped, 0 active conns)" + assert repr(ps) == "Proxyserver(0 active conns)" async def test_inject() -> None: @@ -111,8 +115,8 @@ async def server_handler( async with tcp_server(server_handler) as addr: tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) await ps.running() - await tctx.master.await_log("Proxy server listening", level="info") - proxy_addr = ps.tcp_server.sockets[0].getsockname()[:2] + await tctx.master.await_log("HTTP(S) proxy listening", level="info") + proxy_addr = ps.servers["regular"].listen_addrs[0] reader, writer = await asyncio.open_connection(*proxy_addr) req = f"CONNECT {addr[0]}:{addr[1]} HTTP/1.1\r\n\r\n" @@ -155,13 +159,9 @@ async def test_warn_no_nextlayer(): """ ps = Proxyserver() with taddons.context(ps) as tctx: - tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) + tctx.configure(ps, listen_host="127.0.0.1", listen_port=0, server=False) await ps.running() - await tctx.master.await_log("Proxy server listening at", level="info") - assert tctx.master.has_log( - "Warning: Running proxyserver without nextlayer addon!", level="warn" - ) - await ps.shutdown_server() + await tctx.master.await_log("Warning: Running proxyserver without nextlayer addon!", level="warn") async def test_self_connect(): @@ -170,40 +170,37 @@ async def test_self_connect(): server.address = ("localhost", 8080) ps = Proxyserver() with taddons.context(ps) as tctx: - # not calling .running() here to avoid unnecessary socket - ps.options = tctx.options + tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) + await ps.running() + await tctx.master.await_log("HTTP(S) proxy listening", level="info") + assert ps.servers + server.address = ("localhost", ps.servers["regular"].listen_addrs[0][1]) ps.server_connect(server_hooks.ServerConnectionHookData(server, client)) assert "Request destination unknown" in server.error + tctx.configure(ps, server=False) def test_options(): ps = Proxyserver() with taddons.context(ps) as tctx: - with pytest.raises(exceptions.OptionsError): - tctx.configure(ps, body_size_limit="invalid") - tctx.configure(ps, body_size_limit="1m") - with pytest.raises(exceptions.OptionsError): tctx.configure(ps, stream_large_bodies="invalid") tctx.configure(ps, stream_large_bodies="1m") - with pytest.raises(exceptions.OptionsError): - tctx.configure(ps, dns_mode="invalid") - tctx.configure(ps, dns_mode="regular") - - with pytest.raises(exceptions.OptionsError): - tctx.configure(ps, dns_mode="reverse") - tctx.configure(ps, dns_mode="reverse:8.8.8.8") - assert ps.dns_reverse_addr == ("8.8.8.8", 53) with pytest.raises(exceptions.OptionsError): - tctx.configure(ps, dns_mode="reverse:invalid:53") - tctx.configure(ps, dns_mode="reverse:8.8.8.8:53") - assert ps.dns_reverse_addr == ("8.8.8.8", 53) + tctx.configure(ps, body_size_limit="invalid") + tctx.configure(ps, body_size_limit="1m") with pytest.raises(exceptions.OptionsError): tctx.configure(ps, connect_addr="invalid") tctx.configure(ps, connect_addr="1.2.3.4") - assert ps.connect_addr == ("1.2.3.4", 0) + assert ps._connect_addr == ("1.2.3.4", 0) + + with pytest.raises(exceptions.OptionsError): + tctx.configure(ps, mode=["invalid!"]) + with pytest.raises(exceptions.OptionsError): + tctx.configure(ps, mode=["regular", "reverse:example.com"]) + tctx.configure(ps, mode=["regular"], server=False) async def test_startup_err(monkeypatch) -> None: @@ -219,19 +216,18 @@ async def _raise(*_): async def test_shutdown_err() -> None: - def _raise(*_): + async def _raise(*_): raise OSError("cannot close") ps = Proxyserver() with taddons.context(ps) as tctx: tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) await ps.running() - assert ps.running_servers - for server in ps.running_servers: - setattr(server, "close", _raise) - await ps.shutdown_server() + assert ps.servers + for server in ps.servers.values(): + setattr(server, "stop", _raise) + tctx.configure(ps, server=False) await tctx.master.await_log("cannot close", level="error") - assert ps.running_servers class DummyResolver: @@ -251,16 +247,12 @@ async def test_dns() -> None: with taddons.context(ps, DummyResolver()) as tctx: tctx.configure( ps, - server=False, - dns_server=True, - dns_listen_host="127.0.0.1", - dns_listen_port=0, - dns_mode="regular", + mode=["dns@127.0.0.1:0"], ) await ps.running() await tctx.master.await_log("DNS server listening at", level="info") - assert ps.dns_server - dns_addr = ps.dns_server.sockets[0].getsockname()[:2] + assert ps.servers + dns_addr = ps.servers["dns@127.0.0.1:0"].listen_addrs[0] r, w = await udp.open_connection(*dns_addr) w.write(b"\x00") await tctx.master.await_log("Invalid DNS datagram received", level="info") @@ -268,15 +260,33 @@ async def test_dns() -> None: w.write(req.packed) resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) assert req.id == resp.id and "8.8.8.8" in str(resp) - assert len(ps._connections) == 1 + assert len(ps.connections) == 1 w.write(req.packed) resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) assert req.id == resp.id and "8.8.8.8" in str(resp) - assert len(ps._connections) == 1 + assert len(ps.connections) == 1 req.id = req.id + 1 w.write(req.packed) resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) assert req.id == resp.id and "8.8.8.8" in str(resp) - assert len(ps._connections) == 2 - await ps.shutdown_server() - await tctx.master.await_log("Stopping DNS server", level="info") + assert len(ps.connections) == 2 + tctx.configure(ps, server=False) + await tctx.master.await_log("Stopped dns proxy server.", level="info") + + +def test_validation_no_transparent(monkeypatch): + monkeypatch.setattr(mitmproxy.platform, "original_addr", None) + ps = Proxyserver() + with taddons.context(ps) as tctx: + with pytest.raises(Exception, match="Transparent mode not supported"): + tctx.configure(ps, mode=["transparent"]) + + +def test_transparent_init(monkeypatch): + init = Mock() + monkeypatch.setattr(mitmproxy.platform, "original_addr", lambda: 1) + monkeypatch.setattr(mitmproxy.platform, "init_transparent_mode", init) + ps = Proxyserver() + with taddons.context(ps) as tctx: + tctx.configure(ps, mode=["transparent"], server=False) + assert init.called diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index 071911628b..8163532abd 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -185,14 +185,14 @@ async def test_simple(self, tdata): with taddons.context(loadcore=False) as tctx: tctx.master.addons.add(sc) sc.running() - assert len(tctx.master.addons) == 1 + assert len(tctx.master.addons) == 2 tctx.master.options.update( scripts=[tdata.path("mitmproxy/data/addonscripts/recorder/recorder.py")] ) - assert len(tctx.master.addons) == 1 + assert len(tctx.master.addons) == 2 assert len(sc.addons) == 1 tctx.master.options.update(scripts=[]) - assert len(tctx.master.addons) == 1 + assert len(tctx.master.addons) == 2 assert len(sc.addons) == 0 def test_dupes(self): diff --git a/test/mitmproxy/addons/test_upstream_auth.py b/test/mitmproxy/addons/test_upstream_auth.py index 67ec559277..1eb8eb0131 100644 --- a/test/mitmproxy/addons/test_upstream_auth.py +++ b/test/mitmproxy/addons/test_upstream_auth.py @@ -2,6 +2,7 @@ import pytest from mitmproxy import exceptions +from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.addons import upstream_auth @@ -41,10 +42,10 @@ def test_simple(): assert "proxy-authorization" not in f.request.headers assert "authorization" not in f.request.headers - tctx.configure(up, mode="upstream:127.0.0.1") + f.client_conn.proxy_mode = ProxyMode.parse("upstream:127.0.0.1") up.requestheaders(f) assert "proxy-authorization" in f.request.headers - tctx.configure(up, mode="reverse:127.0.0.1") + f.client_conn.proxy_mode = ProxyMode.parse("reverse:127.0.0.1") up.requestheaders(f) assert "authorization" in f.request.headers diff --git a/test/mitmproxy/net/test_server_spec.py b/test/mitmproxy/net/test_server_spec.py index ba527a20ee..1fe5590182 100644 --- a/test/mitmproxy/net/test_server_spec.py +++ b/test/mitmproxy/net/test_server_spec.py @@ -4,40 +4,34 @@ @pytest.mark.parametrize( - "spec,out", + "spec,default_scheme,out", [ - ("example.com", ("https", ("example.com", 443))), - ("http://example.com", ("http", ("example.com", 80))), - ("smtp.example.com:25", ("http", ("smtp.example.com", 25))), - ("http://127.0.0.1", ("http", ("127.0.0.1", 80))), - ("http://[::1]", ("http", ("::1", 80))), - ("http://[::1]/", ("http", ("::1", 80))), - ("https://[::1]/", ("https", ("::1", 443))), - ("http://[::1]:8080", ("http", ("::1", 8080))), + ("example.com", "https", ("https", ("example.com", 443))), + ("http://example.com", "https", ("http", ("example.com", 80))), + ("smtp.example.com:25", "tcp", ("tcp", ("smtp.example.com", 25))), + ("http://127.0.0.1", "https", ("http", ("127.0.0.1", 80))), + ("http://[::1]", "https", ("http", ("::1", 80))), + ("http://[::1]/", "https", ("http", ("::1", 80))), + ("https://[::1]/", "https", ("https", ("::1", 443))), + ("http://[::1]:8080", "https", ("http", ("::1", 8080))), ], ) -def test_parse(spec, out): - assert server_spec.parse(spec) == out +def test_parse(spec, default_scheme, out): + assert server_spec.parse(spec, default_scheme) == out def test_parse_err(): with pytest.raises(ValueError, match="Invalid server specification"): - server_spec.parse(":") + server_spec.parse(":", "https") with pytest.raises(ValueError, match="Invalid server scheme"): - server_spec.parse("ftp://example.com") + server_spec.parse("ftp://example.com", "https") with pytest.raises(ValueError, match="Invalid hostname"): - server_spec.parse("$$$") + server_spec.parse("$$$", "https") with pytest.raises(ValueError, match="Invalid port"): - server_spec.parse("example.com:999999") + server_spec.parse("example.com:999999", "https") - -def test_parse_with_mode(): - assert server_spec.parse_with_mode("m:example.com") == ( - "m", - ("https", ("example.com", 443)), - ) - with pytest.raises(ValueError): - server_spec.parse_with_mode("moo") + with pytest.raises(ValueError, match="Port specification missing"): + server_spec.parse("example.com", "tcp") diff --git a/test/mitmproxy/proxy/layers/http/test_http.py b/test/mitmproxy/proxy/layers/http/test_http.py index 6c82a58143..21d3ab32e4 100644 --- a/test/mitmproxy/proxy/layers/http/test_http.py +++ b/test/mitmproxy/proxy/layers/http/test_http.py @@ -4,7 +4,6 @@ from mitmproxy.connection import ConnectionState, Server from mitmproxy.http import HTTPFlow, Response -from mitmproxy.net.server_spec import ServerSpec from mitmproxy.proxy import layer from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData from mitmproxy.proxy.events import ConnectionClosed, DataReceived @@ -12,6 +11,7 @@ from mitmproxy.proxy.layers.http import HTTPMode from mitmproxy.proxy.layers.tcp import TcpMessageInjected, TcpStartHook from mitmproxy.proxy.layers.websocket import WebsocketStartHook +from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.tcp import TCPFlow, TCPMessage from test.mitmproxy.proxy.tutils import ( BytesMatching, @@ -729,7 +729,7 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): server = Placeholder(Server) server2 = Placeholder(Server) flow = Placeholder(HTTPFlow) - tctx.options.mode = "upstream:http://proxy:8080" + tctx.client.proxy_mode = ProxyMode.parse("upstream:http://proxy:8080") playbook = Playbook(http.HttpLayer(tctx, HTTPMode.upstream), hooks=False) if scheme == "http": @@ -782,7 +782,7 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): flow().request.host = domain + b".test" flow().request.host_header = domain elif redirect == "change-proxy": - flow().server_conn.via = ServerSpec("http", address=("other-proxy", 1234)) + flow().server_conn.via = ("http", ("other-proxy", 1234)) playbook >> reply() if redirect: @@ -829,10 +829,10 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): if redirect == "change-proxy": assert ( - server2().address == flow().server_conn.via.address == ("other-proxy", 1234) + server2().address == flow().server_conn.via[1] == ("other-proxy", 1234) ) else: - assert server2().address == flow().server_conn.via.address == ("proxy", 8080) + assert server2().address == flow().server_conn.via[1] == ("proxy", 8080) playbook >> ConnectionClosed(tctx.client) playbook << CloseConnection(tctx.client) @@ -848,10 +848,10 @@ def test_http_proxy_tcp(tctx, mode, close_first): tctx.options.connection_strategy = "lazy" if mode == "upstream": - tctx.options.mode = "upstream:http://proxy:8080" + tctx.client.proxy_mode = ProxyMode.parse("upstream:http://proxy:8080") toplayer = http.HttpLayer(tctx, HTTPMode.upstream) else: - tctx.options.mode = "regular" + tctx.client.proxy_mode = ProxyMode.parse("regular") toplayer = http.HttpLayer(tctx, HTTPMode.regular) playbook = Playbook(toplayer, hooks=False) diff --git a/test/mitmproxy/proxy/layers/test_modes.py b/test/mitmproxy/proxy/layers/test_modes.py index 6ebe4943bd..eacded4093 100644 --- a/test/mitmproxy/proxy/layers/test_modes.py +++ b/test/mitmproxy/proxy/layers/test_modes.py @@ -23,6 +23,7 @@ TlsStartClientHook, TlsStartServerHook, ) +from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.tcp import TCPFlow from test.mitmproxy.proxy.layers.test_tls import ( reply_tls_start_client, @@ -44,15 +45,15 @@ def test_upstream_https(tctx): Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), copy.deepcopy(tctx.options), ) - tctx1.options.mode = "upstream:https://example.mitmproxy.org:8081" + tctx1.client.proxy_mode = ProxyMode.parse("upstream:https://example.mitmproxy.org:8081") tctx2 = Context( Client(("client", 4321), ("127.0.0.1", 8080), 1605699329), copy.deepcopy(tctx.options), ) - assert tctx2.options.mode == "regular" + assert tctx2.client.proxy_mode == ProxyMode.parse("regular") del tctx - proxy1 = Playbook(modes.HttpProxy(tctx1), hooks=False) + proxy1 = Playbook(modes.HttpUpstreamProxy(tctx1), hooks=False) proxy2 = Playbook(modes.HttpProxy(tctx2), hooks=False) upstream = Placeholder(Server) @@ -124,7 +125,7 @@ def test_reverse_proxy(tctx, keep_host_header): - make sure that we include non-standard ports in the host header (#4280) """ server = Placeholder(Server) - tctx.options.mode = "reverse:http://localhost:8000" + tctx.client.proxy_mode = ProxyMode.parse("reverse:http://localhost:8000") tctx.options.connection_strategy = "lazy" tctx.options.keep_host_header = keep_host_header assert ( @@ -165,7 +166,7 @@ def test_reverse_proxy_tcp_over_tls( flow = Placeholder(TCPFlow) data = Placeholder(bytes) - tctx.options.mode = "reverse:https://localhost:8000" + tctx.client.proxy_mode = ProxyMode.parse("reverse:https://localhost:8000") tctx.options.connection_strategy = connection_strategy playbook = Playbook(modes.ReverseProxy(tctx)) if connection_strategy == "eager": @@ -273,7 +274,7 @@ def test_reverse_eager_connect_failure(tctx: Context): reverse proxying. """ - tctx.options.mode = "reverse:https://localhost:8000" + tctx.client.proxy_mode = ProxyMode.parse("reverse:https://localhost:8000") tctx.options.connection_strategy = "eager" playbook = Playbook(modes.ReverseProxy(tctx)) assert ( diff --git a/test/mitmproxy/proxy/test_mode_servers.py b/test/mitmproxy/proxy/test_mode_servers.py new file mode 100644 index 0000000000..7dacbfb951 --- /dev/null +++ b/test/mitmproxy/proxy/test_mode_servers.py @@ -0,0 +1,77 @@ +import asyncio +from typing import cast +from unittest.mock import AsyncMock, MagicMock, Mock + +from mitmproxy.net import udp +from mitmproxy.proxy.mode_servers import DnsInstance, ServerInstance +from mitmproxy.test import taddons + + +def test_make(): + manager = Mock() + context = MagicMock() + assert ServerInstance.make("regular", manager) + + for mode in ["regular", "upstream:example.com", "transparent", "reverse:example.com", "socks5"]: + inst = ServerInstance.make(mode, manager) + assert inst + assert inst.make_top_layer(context) + assert inst.log_desc + + +async def test_tcp_start_stop(): + manager = MagicMock() + + with taddons.context() as tctx: + inst = ServerInstance.make("regular@127.0.0.1:0", manager) + await inst.start() + assert await tctx.master.await_log("proxy listening") + + host, port, *_ = inst.listen_addrs[0] + reader, writer = await asyncio.open_connection(host, port) + assert await tctx.master.await_log("client connect") + + writer.close() + await writer.wait_closed() + assert await tctx.master.await_log("client disconnect") + + await inst.stop() + assert await tctx.master.await_log("Stopped regular proxy server.") + + +async def test_udp_start_stop(): + manager = MagicMock() + + with taddons.context() as tctx: + inst = ServerInstance.make("dns@127.0.0.1:0", manager) + await inst.start() + assert await tctx.master.await_log("server listening") + + host, port, *_ = inst.listen_addrs[0] + reader, writer = await udp.open_connection(host, port) + writer.write(b"\x00") + assert await tctx.master.await_log("Invalid DNS datagram received") + + writer.write(b"\x00\x00\x01") + assert await tctx.master.await_log("sent an invalid message") + + writer.close() + + await inst.stop() + assert await tctx.master.await_log("Stopped") + + +async def test_udp_connection_reuse(monkeypatch): + manager = MagicMock() + manager.connections = {} + + monkeypatch.setattr(udp, "DatagramWriter", MagicMock()) + monkeypatch.setattr(DnsInstance, "handle_dns_connection", AsyncMock()) + + with taddons.context(): + inst = cast(DnsInstance, ServerInstance.make("dns", manager)) + inst.handle_dns_datagram(MagicMock(), b"\x00\x00\x01", ("remoteaddr", 0), ("localaddr", 0)) + inst.handle_dns_datagram(MagicMock(), b"\x00\x00\x02", ("remoteaddr", 0), ("localaddr", 0)) + await asyncio.sleep(0) + + assert len(inst.manager.connections) == 1 diff --git a/test/mitmproxy/proxy/test_mode_specs.py b/test/mitmproxy/proxy/test_mode_specs.py new file mode 100644 index 0000000000..c6895010d0 --- /dev/null +++ b/test/mitmproxy/proxy/test_mode_specs.py @@ -0,0 +1,73 @@ +import pytest + +from mitmproxy.proxy.mode_specs import ProxyMode, Socks5Mode + + +def test_parse(): + m = ProxyMode.parse("reverse:https://example.com/@127.0.0.1:443") + m = ProxyMode.from_state(m.get_state()) + + assert m.type == "reverse" + assert m.full_spec == "reverse:https://example.com/@127.0.0.1:443" + assert m.data == "https://example.com/" + assert m.custom_listen_host == "127.0.0.1" + assert m.custom_listen_port == 443 + + with pytest.raises(ValueError, match="unknown mode"): + ProxyMode.parse("flibbel") + + with pytest.raises(ValueError, match="invalid port"): + ProxyMode.parse("regular@invalid-port") + + with pytest.raises(ValueError, match="invalid port"): + ProxyMode.parse("regular@99999") + + m.set_state(m.get_state()) + with pytest.raises(RuntimeError, match="Proxy modes are frozen"): + m.set_state("regular") + + +def test_parse_subclass(): + assert Socks5Mode.parse("socks5") + with pytest.raises(ValueError, match="'regular' is not a spec for a socks5 mode"): + Socks5Mode.parse("regular") + + +def test_listen_addr(): + assert ProxyMode.parse("regular").listen_port() == 8080 + assert ProxyMode.parse("regular@1234").listen_port() == 1234 + assert ProxyMode.parse("regular").listen_port(default=4424) == 4424 + assert ProxyMode.parse("regular@1234").listen_port(default=4424) == 1234 + + assert ProxyMode.parse("regular").listen_host() == "" + assert ProxyMode.parse("regular@127.0.0.2:8080").listen_host() == "127.0.0.2" + assert ProxyMode.parse("regular").listen_host(default="127.0.0.3") == "127.0.0.3" + assert ProxyMode.parse("regular@127.0.0.2:8080").listen_host(default="127.0.0.3") == "127.0.0.2" + + +def test_parse_specific_modes(): + assert ProxyMode.parse("regular") + assert ProxyMode.parse("transparent") + assert ProxyMode.parse("upstream:https://proxy") + assert ProxyMode.parse("reverse:https://host@443") + assert ProxyMode.parse("socks5") + assert ProxyMode.parse("dns").resolve_local + assert ProxyMode.parse("dns:reverse:8.8.8.8") + + with pytest.raises(ValueError, match="invalid port"): + ProxyMode.parse("regular@invalid-port") + + with pytest.raises(ValueError, match="takes no arguments"): + ProxyMode.parse("regular:configuration") + + with pytest.raises(ValueError, match="invalid upstream proxy scheme"): + ProxyMode.parse("upstream:dns://example.com") + + with pytest.raises(ValueError, match="invalid reverse proxy scheme"): + ProxyMode.parse("reverse:dns://example.com") + + with pytest.raises(ValueError, match="invalid dns mode"): + ProxyMode.parse("dns:invalid") + + with pytest.raises(ValueError, match="invalid dns scheme"): + ProxyMode.parse("dns:reverse:https://example.com") diff --git a/test/mitmproxy/test_addonmanager.py b/test/mitmproxy/test_addonmanager.py index 1686d4923a..da9151c5b8 100644 --- a/test/mitmproxy/test_addonmanager.py +++ b/test/mitmproxy/test_addonmanager.py @@ -131,7 +131,7 @@ async def test_mixed_async_sync(): with taddons.context(loadcore=False) as tctx: a = tctx.master.addons - assert len(a) == 0 + assert len(a) == 1 a1 = TAddon("sync") a2 = AsyncTAddon("async") a.add(a1) @@ -177,15 +177,16 @@ async def test_simple(): with taddons.context(loadcore=False) as tctx: a = tctx.master.addons - assert len(a) == 0 + assert len(a) == 1 a.add(TAddon("one")) assert a.get("one") assert not a.get("two") - assert len(a) == 1 + assert len(a) == 2 a.clear() assert len(a) == 0 assert not a.chain + with taddons.context(loadcore=False) as tctx: a.add(TAddon("one")) a.trigger("nonexistent") diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 55beb7d42d..a44408c259 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -114,7 +114,7 @@ def test_copy(self): class TestFlowMaster: async def test_load_http_flow_reverse(self): - opts = options.Options(mode="reverse:https://use-this-domain") + opts = options.Options(mode=["reverse:https://use-this-domain"]) s = State() with taddons.context(s, options=opts) as ctx: f = tflow.tflow(resp=True) @@ -122,7 +122,7 @@ async def test_load_http_flow_reverse(self): assert s.flows[0].request.host == "use-this-domain" async def test_all(self): - opts = options.Options(mode="reverse:https://use-this-domain") + opts = options.Options(mode=["reverse:https://use-this-domain"]) s = State() with taddons.context(s, options=opts) as ctx: f = tflow.tflow(req=None) diff --git a/test/mitmproxy/tools/console/test_statusbar.py b/test/mitmproxy/tools/console/test_statusbar.py index a54b983fe3..08b3439919 100644 --- a/test/mitmproxy/tools/console/test_statusbar.py +++ b/test/mitmproxy/tools/console/test_statusbar.py @@ -24,7 +24,7 @@ async def test_statusbar(monkeypatch): server_replay_kill_extra=True, upstream_cert=False, stream_large_bodies="3m", - mode="transparent", + mode=["transparent"], ) m.options.update(view_order="url", console_focus_follow=True) diff --git a/test/mitmproxy/tools/web/test_app.py b/test/mitmproxy/tools/web/test_app.py index 8cd88074fa..1c86750ac3 100644 --- a/test/mitmproxy/tools/web/test_app.py +++ b/test/mitmproxy/tools/web/test_app.py @@ -110,6 +110,8 @@ def ts_type(t): return "string[]" if t == Optional[str]: return "string | undefined" + if t == Optional[int]: + return "number | undefined" raise RuntimeError(t) with redirect_stdout(io.StringIO()) as s: diff --git a/web/src/js/__tests__/components/Modal/__snapshots__/ModalSpec.tsx.snap b/web/src/js/__tests__/components/Modal/__snapshots__/ModalSpec.tsx.snap index 32358e6103..3a0c54b2f5 100644 --- a/web/src/js/__tests__/components/Modal/__snapshots__/ModalSpec.tsx.snap +++ b/web/src/js/__tests__/components/Modal/__snapshots__/ModalSpec.tsx.snap @@ -193,6 +193,15 @@ exports[`Modal Component 2`] = ` value="8080" /> +
+ Default: + + 8080 + + +
diff --git a/web/src/js/components/Footer.tsx b/web/src/js/components/Footer.tsx index 4df2683b26..4770709df5 100644 --- a/web/src/js/components/Footer.tsx +++ b/web/src/js/components/Footer.tsx @@ -6,14 +6,14 @@ import {useAppSelector} from "../ducks"; export default function Footer() { const version = useAppSelector(state => state.conf.version); let { - mode, intercept, showhost, upstream_cert, rawtcp, dns_server, http2, websocket, anticache, anticomp, + mode, intercept, showhost, upstream_cert, rawtcp, http2, websocket, anticache, anticomp, stickyauth, stickycookie, stream_large_bodies, listen_host, listen_port, server, ssl_insecure } = useAppSelector(state => state.options); return (