From 83be09db0d9858732834f653b5489b9b81af34ac Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 3 Dec 2023 16:25:29 +0000 Subject: [PATCH] Support the `websocket.http.response` ASGI extension --- uvicorn/protocols/websockets/wsproto_impl.py | 188 +++++++++---------- 1 file changed, 85 insertions(+), 103 deletions(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 08033b198..dab89314f 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -256,120 +256,103 @@ async def run_asgi(self) -> None: self.logger.error(msg, result) self.transport.close() - async def send(self, message: "ASGISendEvent") -> None: + async def send(self, message: ASGISendEvent) -> None: await self.writable.wait() - message_type = message["type"] - if not self.handshake_complete: - if not (self.response_started or self.reject_event): - # a rejection event has not been sent yet - if message_type == "websocket.accept": - message = typing.cast("WebSocketAcceptEvent", message) - self.logger.info( - '%s - "WebSocket %s" [accepted]', - self.scope["client"], - get_path_with_query_string(self.scope), - ) - subprotocol = message.get("subprotocol") - extra_headers = self.default_headers + list( - message.get("headers", []) - ) - extensions: typing.List[Extension] = [] - if self.config.ws_per_message_deflate: - extensions.append(PerMessageDeflate()) - if not self.transport.is_closing(): - self.handshake_complete = True - output = self.conn.send( - wsproto.events.AcceptConnection( - subprotocol=subprotocol, - extensions=extensions, - extra_headers=extra_headers, - ) + if message["type"] == "websocket.accept": + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + subprotocol = message.get("subprotocol") + extra_headers = self.default_headers + list(message.get("headers", [])) + extensions: typing.List[Extension] = [] + if self.config.ws_per_message_deflate: + extensions.append(PerMessageDeflate()) + if not self.transport.is_closing(): + self.handshake_complete = True + output = self.conn.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extensions=extensions, + extra_headers=extra_headers, ) - self.transport.write(output) + ) + self.transport.write(output) + + elif message["type"] == "websocket.close": + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + self.handshake_complete = True + self.close_sent = True + event = events.RejectConnection(status_code=403, headers=[]) + output = self.conn.send(event) + self.transport.write(output) + self.transport.close() - elif message_type == "websocket.close": + elif message["type"] == "websocket.http.response.start": + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + event = events.RejectConnection( + status_code=message["status"], + headers=list(message["headers"]), + has_body=True, + ) + self.handshake_complete = True + output = self.conn.send(event) + self.transport.write(output) + + else: + msg = ( + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " + "but got '%s'." + ) + raise RuntimeError(msg % message["type"]) + else: + # we have started a rejection process with http.response.start + if message["type"] == "websocket.http.response.body": + body_finished = not message.get("more_body", False) + reject_data = events.RejectData( + data=message["body"], body_finished=body_finished + ) + if self.reject_event is not None: + # Prepend with the reject event now that we have a body event. + output = self.conn.send(self.reject_event) + self.transport.write(output) + self.reject_event = None + self.response_started = True + + output = self.conn.send(reject_data) + self.transport.write(output) + + if body_finished: self.queue.put_nowait( {"type": "websocket.disconnect", "code": 1006} ) - self.logger.info( - '%s - "WebSocket %s" 403', - self.scope["client"], - get_path_with_query_string(self.scope), - ) self.handshake_complete = True self.close_sent = True - event = events.RejectConnection(status_code=403, headers=[]) - output = self.conn.send(event) - self.transport.write(output) self.transport.close() - elif message_type == "websocket.http.response.start": - message = typing.cast("WebSocketResponseStartEvent", message) - # ensure status code is in the valid range - if not (100 <= message["status"] < 600): - msg = "Invalid HTTP status code '%d' in response." - raise RuntimeError(msg % message["status"]) - self.logger.info( - '%s - "WebSocket %s" %d', - self.scope["client"], - get_path_with_query_string(self.scope), - message["status"], - ) - event = events.RejectConnection( - status_code=message["status"], - headers=list(message["headers"]), - has_body=True, - ) - # Create the event here but do not send it, the ASGI spec - # suggest that we wait for the body event before sending. - # https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event - self.reject_event = event - - else: - msg = ( - "Expected ASGI message 'websocket.accept', 'websocket.close' " - "or 'websocket.http.response.start' " - "but got '%s'." - ) - raise RuntimeError(msg % message_type) else: - # we have started a rejection process with http.response.start - if message_type == "websocket.http.response.body": - message = typing.cast("WebSocketResponseBodyEvent", message) - body_finished = not message.get("more_body", False) - reject_data = events.RejectData( - data=message["body"], body_finished=body_finished - ) - if self.reject_event is not None: - # Prepend with the reject event now that we have a body event. - output = self.conn.send(self.reject_event) - self.transport.write(output) - self.reject_event = None - self.response_started = True - - output = self.conn.send(reject_data) - self.transport.write(output) - - if body_finished: - self.queue.put_nowait( - {"type": "websocket.disconnect", "code": 1006} - ) - self.handshake_complete = True - self.close_sent = True - self.transport.close() - - else: - msg = ( - "Expected ASGI message 'websocket.http.response.body' " - "but got '%s'." - ) - raise RuntimeError(msg % message_type) + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message["type"]) elif not self.close_sent: - if message_type == "websocket.send": - message = typing.cast("WebSocketSendEvent", message) + if message["type"] == "websocket.send": bytes_data = message.get("bytes") text_data = message.get("text") data = text_data if bytes_data is None else bytes_data @@ -379,8 +362,7 @@ async def send(self, message: "ASGISendEvent") -> None: if not self.transport.is_closing(): self.transport.write(output) - elif message_type == "websocket.close": - message = typing.cast("WebSocketCloseEvent", message) + elif message["type"] == "websocket.close": self.close_sent = True code = message.get("code", 1000) reason = message.get("reason", "") or "" @@ -397,11 +379,11 @@ async def send(self, message: "ASGISendEvent") -> None: "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." ) - raise RuntimeError(msg % message_type) + raise RuntimeError(msg % message["type"]) else: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." - raise RuntimeError(msg % message_type) + raise RuntimeError(msg % message["type"]) async def receive(self) -> "WebSocketEvent": message = await self.queue.get()