Skip to content

Commit

Permalink
Add WebSocketsSansIOProtocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 14, 2024
1 parent d79d86e commit 7ee1e15
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 69 deletions.
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def unused_tcp_port() -> int:
),
pytest.param("uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", id="websockets"),
pytest.param(
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketSansIOProtocol", id="websockets-sansio"
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio"
),
]
)
Expand Down
1 change: 0 additions & 1 deletion tests/middleware/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging
assert any(" - HTTP connection lost" in message for message in messages)


@pytest.mark.skip()
async def test_trace_logging_on_ws_protocol(
ws_protocol_cls: WSProtocol,
caplog: pytest.LogCaptureFixture,
Expand Down
4 changes: 2 additions & 2 deletions tests/middleware/test_proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import httpx
import httpx._transports.asgi
import pytest
from websockets.asyncio.client import connect

import websockets.client
from tests.response import Response
from tests.utils import run_server
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
Expand Down Expand Up @@ -479,7 +479,7 @@ async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISe
async with run_server(config):
url = f"ws://127.0.0.1:{unused_tcp_port}"
headers = {X_FORWARDED_FOR: "1.2.3.4", X_FORWARDED_PROTO: forwarded_proto}
async with connect(url, additional_headers=headers) as websocket:
async with websockets.client.connect(url, extra_headers=headers) as websocket:
data = await websocket.recv()
assert data == expected

Expand Down
35 changes: 20 additions & 15 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,20 +601,20 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
await send_accept_task.wait()
disconnect_message = await receive() # type: ignore

response: httpx.Response | None = None

async def websocket_session(uri: str):
nonlocal response
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)
try:
await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)
except httpx.RemoteProtocolError:
pass # pragma: no cover

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
Expand All @@ -623,9 +623,6 @@ async def websocket_session(uri: str):
send_accept_task.set()
await asyncio.sleep(0.1)

assert response is not None
assert response.status_code == 500, response.text
assert response.text == "Internal Server Error"
assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}
await task

Expand Down Expand Up @@ -920,6 +917,9 @@ async def websocket_session(url: str):
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
assert "extensions" in scope and "websocket.http.response" in scope["extensions"]
Expand Down Expand Up @@ -951,6 +951,9 @@ async def websocket_session(url: str):
async def test_server_reject_connection_with_missing_body(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
assert "extensions" in scope and "websocket.http.response" in scope["extensions"]
Expand Down Expand Up @@ -986,6 +989,8 @@ async def test_server_multiple_websocket_http_response_start_events(
The server should raise an exception if it sends multiple
websocket.http.response.start events.
"""
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
exception_message: str | None = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
"none": None,
"websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
"websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketSansIOProtocol",
"websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
"wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
}
LIFESPAN: dict[LifespanType, str] = {
Expand Down
117 changes: 70 additions & 47 deletions uvicorn/protocols/websockets/websockets_sansio_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Literal, cast
from urllib.parse import unquote

from websockets import InvalidState
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.frames import Frame, Opcode
from websockets.http11 import Request
Expand All @@ -26,11 +27,17 @@
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState


class WebSocketSansIOProtocol(asyncio.Protocol):
class WebSocketsSansIOProtocol(asyncio.Protocol):
def __init__(
self,
config: Config,
Expand Down Expand Up @@ -96,12 +103,20 @@ def connection_made(self, transport: BaseTransport) -> None:
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc: Exception | None) -> None:
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
if self.handshake_initiated and not self.close_sent:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})

self.handshake_complete = True
if exc is None:
self.transport.close()

def eof_received(self) -> None:
pass

def shutdown(self) -> None:
if not self.transport.is_closing():
Expand All @@ -110,8 +125,8 @@ def shutdown(self) -> None:
self.close_sent = True
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.writelines(output)
elif self.handshake_initiated:
self.transport.write(b"".join(output))
elif not self.handshake_initiated:
self.send_500_response()
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.transport.close()
Expand Down Expand Up @@ -152,7 +167,7 @@ def handle_connect(self, event: Request) -> None:
self.close_sent = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.transport.write(b"".join(output))
self.transport.close()
return

Expand Down Expand Up @@ -213,29 +228,29 @@ def send_receive_event_to_app(self) -> None:

def handle_ping(self, event: Frame) -> None:
output = self.conn.data_to_send()
self.transport.writelines(output)
self.transport.write(b"".join(output))

def handle_close(self, event: Frame) -> None:
if not self.close_sent and self.conn.close_rcvd and not self.transport.is_closing():
if not self.close_sent and not self.transport.is_closing():
disconnect_event: WebSocketDisconnectEvent = {
"type": "websocket.disconnect",
"code": self.conn.close_rcvd.code,
"reason": self.conn.close_rcvd.reason,
"code": self.conn.close_rcvd.code, # type: ignore[union-attr]
"reason": self.conn.close_rcvd.reason, # type: ignore[union-attr]
}
self.queue.put_nowait(disconnect_event)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.close_sent = True
self.transport.write(b"".join(output))
self.transport.close()

def handle_parser_exception(self) -> None:
disconnect_event: WebSocketDisconnectEvent = {
"type": "websocket.disconnect",
"code": self.conn.close_sent.code if self.conn.close_sent else 1006,
"code": self.conn.close_sent.code, # type: ignore[union-attr]
"reason": self.conn.close_sent.reason, # type: ignore[union-attr]
}
self.queue.put_nowait(disconnect_event)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.transport.write(b"".join(output))
self.close_sent = True
self.transport.close()

Expand All @@ -245,10 +260,11 @@ def on_task_complete(self, task: asyncio.Task[None]) -> None:
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except ClientDisconnected:
self.transport.close()
except BaseException:
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_complete:
self.send_500_response()
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
Expand All @@ -262,10 +278,12 @@ async def run_asgi(self) -> None:
self.transport.close()

def send_500_response(self) -> None:
if self.initial_response or self.handshake_complete:
return
response = self.conn.reject(500, "Internal Server Error")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.transport.write(b"".join(output))

async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
Expand Down Expand Up @@ -293,7 +311,7 @@ async def send(self, message: ASGISendEvent) -> None:
self.handshake_complete = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.transport.write(b"".join(output))

elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
Expand All @@ -308,10 +326,12 @@ async def send(self, message: ASGISendEvent) -> None:
output = self.conn.data_to_send()
self.close_sent = True
self.handshake_complete = True
self.transport.writelines(output)
self.transport.write(b"".join(output))
self.transport.close()
elif message_type == "websocket.http.response.start":
elif message_type == "websocket.http.response.start" and self.initial_response is None:
message = cast(WebSocketResponseStartEvent, message)
if not (100 <= message["status"] < 600):
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
self.logger.info(
'%s - "WebSocket %s" %d',
self.scope["client"],
Expand All @@ -329,34 +349,36 @@ async def send(self, message: ASGISendEvent) -> None:
"or 'websocket.http.response.start' "
"but got '%s'."
)
print(message)
raise RuntimeError(msg % message_type)

elif not self.close_sent and self.initial_response is None:
if message_type == "websocket.send" and not self.transport.is_closing():
message = cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
if text_data:
self.conn.send_text(text_data.encode())
elif bytes_data:
self.conn.send_binary(bytes_data)
output = self.conn.data_to_send()
self.transport.writelines(output)

elif message_type == "websocket.close" and not self.transport.is_closing():
message = cast(WebSocketCloseEvent, message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.conn.send_close(code, reason)
output = self.conn.data_to_send()
self.transport.writelines(output)
self.close_sent = True
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
raise RuntimeError(msg % message_type)
try:
if message_type == "websocket.send":
message = cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
if text_data:
self.conn.send_text(text_data.encode())
elif bytes_data:
self.conn.send_binary(bytes_data)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))

elif message_type == "websocket.close" and not self.transport.is_closing():
message = cast(WebSocketCloseEvent, message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.conn.send_close(code, reason)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.close_sent = True
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
raise RuntimeError(msg % message_type)
except InvalidState:
raise ClientDisconnected()
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast(WebSocketResponseBodyEvent, message)
Expand All @@ -365,10 +387,11 @@ async def send(self, message: ASGISendEvent) -> None:
if not message.get("more_body", False):
response = self.conn.reject(self.initial_response[0], body.decode())
response.headers.update(self.initial_response[1])
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.transport.writelines(output)
self.transport.write(b"".join(output))
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
Expand Down
4 changes: 2 additions & 2 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketSansIOProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketSansIOProtocol]
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol]

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
Expand Down

0 comments on commit 7ee1e15

Please sign in to comment.