From a22644d7ced90ee796592622eebd9629377b39ba Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 7 Aug 2023 18:16:42 -0700 Subject: [PATCH] Ensure async https() requests are bounded in total time according to the timeout [#978]. Unfortunately we do not currently have a good way to make this guarantee for sync https() calls. --- dns/_asyncbackend.py | 3 +++ dns/_asyncio_backend.py | 3 +++ dns/_trio_backend.py | 7 +++++++ dns/asyncquery.py | 8 ++++---- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index cebcbdfd4..49f14fed6 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -94,3 +94,6 @@ async def sleep(self, interval): def get_transport_class(self): raise NotImplementedError + + async def wait_for(self, awaitable, timeout): + raise NotImplementedError diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 0021f84fe..2631228ec 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -270,3 +270,6 @@ def datagram_connection_required(self): def get_transport_class(self): return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + return await _maybe_wait_for(awaitable, timeout) diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index d414f0b37..4d9fb8204 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -237,3 +237,10 @@ async def sleep(self, interval): def get_transport_class(self): return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + with _maybe_timeout(timeout): + return await awaitable + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 737e1c922..ecf9c1a5f 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -563,14 +563,14 @@ async def https( "content-length": str(len(wire)), } ) - response = await the_client.post( - url, headers=headers, content=wire, timeout=timeout + response = await backend.wait_for( + the_client.post(url, headers=headers, content=wire), timeout ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes - response = await the_client.get( - url, headers=headers, timeout=timeout, params={"dns": twire} + response = await backend.wait_for( + the_client.get(url, headers=headers, params={"dns": twire}), timeout ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH