Skip to content

Commit

Permalink
Support the websocket.http.response ASGI extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 3, 2023
1 parent 80797c9 commit 83be09d
Showing 1 changed file with 85 additions and 103 deletions.
188 changes: 85 additions & 103 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,120 +256,103 @@ async def run_asgi(self) -> None:
self.logger.error(msg, result)
self.transport.close()

async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()

message_type = message["type"]

if not self.handshake_complete:
if not (self.response_started or self.reject_event):
# a rejection event has not been sent yet
if message_type == "websocket.accept":
message = typing.cast("WebSocketAcceptEvent", message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(
message.get("headers", [])
)
extensions: typing.List[Extension] = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
self.handshake_complete = True
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extensions=extensions,
extra_headers=extra_headers,
)
if message["type"] == "websocket.accept":
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions: typing.List[Extension] = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
self.handshake_complete = True
output = self.conn.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extensions=extensions,
extra_headers=extra_headers,
)
self.transport.write(output)
)
self.transport.write(output)

elif message["type"] == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.handshake_complete = True
self.close_sent = True
event = events.RejectConnection(status_code=403, headers=[])
output = self.conn.send(event)
self.transport.write(output)
self.transport.close()

elif message_type == "websocket.close":
elif message["type"] == "websocket.http.response.start":
self.logger.info(
'%s - "WebSocket %s" %d',
self.scope["client"],
get_path_with_query_string(self.scope),
message["status"],
)
event = events.RejectConnection(
status_code=message["status"],
headers=list(message["headers"]),
has_body=True,
)
self.handshake_complete = True
output = self.conn.send(event)
self.transport.write(output)

else:
msg = (
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
)
raise RuntimeError(msg % message["type"])
else:
# we have started a rejection process with http.response.start
if message["type"] == "websocket.http.response.body":
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(
data=message["body"], body_finished=body_finished
)
if self.reject_event is not None:
# Prepend with the reject event now that we have a body event.
output = self.conn.send(self.reject_event)
self.transport.write(output)
self.reject_event = None
self.response_started = True

output = self.conn.send(reject_data)
self.transport.write(output)

if body_finished:
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": 1006}
)
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.handshake_complete = True
self.close_sent = True
event = events.RejectConnection(status_code=403, headers=[])
output = self.conn.send(event)
self.transport.write(output)
self.transport.close()

elif message_type == "websocket.http.response.start":
message = typing.cast("WebSocketResponseStartEvent", message)
# ensure status code is in the valid range
if not (100 <= message["status"] < 600):
msg = "Invalid HTTP status code '%d' in response."
raise RuntimeError(msg % message["status"])
self.logger.info(
'%s - "WebSocket %s" %d',
self.scope["client"],
get_path_with_query_string(self.scope),
message["status"],
)
event = events.RejectConnection(
status_code=message["status"],
headers=list(message["headers"]),
has_body=True,
)
# Create the event here but do not send it, the ASGI spec
# suggest that we wait for the body event before sending.
# https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event
self.reject_event = event

else:
msg = (
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)
else:
# we have started a rejection process with http.response.start
if message_type == "websocket.http.response.body":
message = typing.cast("WebSocketResponseBodyEvent", message)
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(
data=message["body"], body_finished=body_finished
)
if self.reject_event is not None:
# Prepend with the reject event now that we have a body event.
output = self.conn.send(self.reject_event)
self.transport.write(output)
self.reject_event = None
self.response_started = True

output = self.conn.send(reject_data)
self.transport.write(output)

if body_finished:
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": 1006}
)
self.handshake_complete = True
self.close_sent = True
self.transport.close()

else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
raise RuntimeError(msg % message["type"])

elif not self.close_sent:
if message_type == "websocket.send":
message = typing.cast("WebSocketSendEvent", message)
if message["type"] == "websocket.send":
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
Expand All @@ -379,8 +362,7 @@ async def send(self, message: "ASGISendEvent") -> None:
if not self.transport.is_closing():
self.transport.write(output)

elif message_type == "websocket.close":
message = typing.cast("WebSocketCloseEvent", message)
elif message["type"] == "websocket.close":
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
Expand All @@ -397,11 +379,11 @@ async def send(self, message: "ASGISendEvent") -> None:
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
raise RuntimeError(msg % message["type"])

else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)
raise RuntimeError(msg % message["type"])

async def receive(self) -> "WebSocketEvent":
message = await self.queue.get()
Expand Down

0 comments on commit 83be09d

Please sign in to comment.