Skip to content

Commit

Permalink
Ensure async https() requests are bounded in total time
Browse files Browse the repository at this point in the history
according to the timeout [#978].

Unfortunately we do not currently have a good way to
make this guarantee for sync https() calls.
  • Loading branch information
rthalley committed Aug 8, 2023
1 parent 9d0262a commit a22644d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
3 changes: 3 additions & 0 deletions dns/_asyncbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,6 @@ async def sleep(self, interval):

def get_transport_class(self):
raise NotImplementedError

async def wait_for(self, awaitable, timeout):
raise NotImplementedError
3 changes: 3 additions & 0 deletions dns/_asyncio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,6 @@ def datagram_connection_required(self):

def get_transport_class(self):
return _HTTPTransport

async def wait_for(self, awaitable, timeout):
return await _maybe_wait_for(awaitable, timeout)
7 changes: 7 additions & 0 deletions dns/_trio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,10 @@ async def sleep(self, interval):

def get_transport_class(self):
return _HTTPTransport

async def wait_for(self, awaitable, timeout):
with _maybe_timeout(timeout):
return await awaitable
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
8 changes: 4 additions & 4 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,14 +563,14 @@ async def https(
"content-length": str(len(wire)),
}
)
response = await the_client.post(
url, headers=headers, content=wire, timeout=timeout
response = await backend.wait_for(
the_client.post(url, headers=headers, content=wire), timeout
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await the_client.get(
url, headers=headers, timeout=timeout, params={"dns": twire}
response = await backend.wait_for(
the_client.get(url, headers=headers, params={"dns": twire}), timeout
)

# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
Expand Down

0 comments on commit a22644d

Please sign in to comment.