Skip to content

Commit

Permalink
Don't wait_closed() on servers until connections are closed
Browse files Browse the repository at this point in the history
The asyncio.Server.wait_closed() method was a no-op in versions of Python
earlier than 3.12. The intention of this method was to block until all
existing connections are closed, but due to a bug in its implementation,
it wouldn't actually wait. This bug was fixed in Python 3.12, which
exposed uvicorn's dependence on the buggy behavior: the implementation of
uvicorn.Server.shutdown() called wait_closed() before asking existing
connections to close. As a result, attempting to stop a Uvicorn server
with an open connection would result in a blocked process, with further
attempts to Ctrl+C having no effect.
  • Loading branch information
jcheng5 committed Nov 1, 2023
1 parent 07c2b36 commit a670b3b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
24 changes: 24 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,30 @@ async def open_connection(url):
assert is_open


@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_shutdown(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls,
unused_tcp_port: int,
):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config) as server:
async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"):
# Attempt shutdown while connection is still open
await server.shutdown()


@pytest.mark.anyio
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_supports_permessage_deflate_extension(
Expand Down
16 changes: 14 additions & 2 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ async def shutdown(self, sockets: Optional[List[socket.socket]] = None) -> None:
server.close()
for sock in sockets or []:
sock.close()
for server in self.servers:
await server.wait_closed()

# Request shutdown on all existing connections.
for connection in list(self.server_state.connections):
Expand Down Expand Up @@ -312,6 +310,20 @@ async def _wait_tasks_to_complete(self) -> None:
while self.server_state.tasks and not self.force_exit:
await asyncio.sleep(0.1)

# Wait for servers to close. They won't do so until all connections are
# closed, which we've already waited for above, so this should be quick.
servers_closed = asyncio.gather(
*[server.wait_closed() for server in self.servers]
)
# Give the servers_closed future a chance to complete so we don't
# spuriously log about this operation.
await asyncio.sleep(0.1)
if not servers_closed.done() and not self.force_exit:
msg = "Waiting for servers to close. (CTRL+C to force quit)"
logger.info(msg)
while not servers_closed.done() and not self.force_exit:
await asyncio.sleep(0.1)

def install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
Expand Down

0 comments on commit a670b3b

Please sign in to comment.