From 1a6a727eeee75eb208b327fc66b228b805316faa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:08:26 +0200 Subject: [PATCH] Add asyncio message reassembler. --- src/websockets/asyncio/messages.py | 261 +++++++++++++++++++ tests/asyncio/__init__.py | 0 tests/asyncio/test_messages.py | 391 +++++++++++++++++++++++++++++ 3 files changed, 652 insertions(+) create mode 100644 src/websockets/asyncio/messages.py create mode 100644 tests/asyncio/__init__.py create mode 100644 tests/asyncio/test_messages.py diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py new file mode 100644 index 000000000..236139058 --- /dev/null +++ b/src/websockets/asyncio/messages.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import asyncio +import codecs +from typing import Iterator + +from ..frames import Frame, Opcode +from ..typing import Data, List, Optional +from .compatibility import asyncio_timeout + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class Assembler: + """ + Assemble messages from frames. + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_event_loop() + + # We create a latch with two futures to ensure proper interleaving of + # writing and reading messages. + # put() sets this future to tell get() that a message can be fetched. + self.message_complete: asyncio.Future[None] = self.loop.create_future() + # get() sets this future to let put() that the message was fetched. + self.message_fetched: asyncio.Future[None] = self.loop.create_future() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + # This flag prevents concurrent calls to put() by library code. + self.put_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder: Optional[codecs.IncrementalDecoder] = None + + # Buffer of frames belonging to the same message. 
+ self.chunks: List[Data] = [] + + # When switching from "buffering" to "streaming", we use a queue for + # transferring frames from the writing coroutine (library code) to the + # reading coroutine (user code). We're buffering when chunks_queue is + # None and streaming when it's a Queue. None is a sentinel value marking + # the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue: Optional[asyncio.Queue[Optional[Data]]] = None + + # This flag marks the end of the stream. + self.closed = False + + async def get(self, timeout: Optional[float] = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + timeout: If a timeout is provided and elapses before a complete + message is received, :meth:`get` raises :exc:`TimeoutError`. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + TimeoutError: If a timeout is provided and elapses before a + complete message is received. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + # If the message_complete future isn't set yet, yield control to allow + # put() to run and eventually set it. + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + try: + async with asyncio_timeout(timeout): + await self.message_complete + finally: + self.get_in_progress = False + + # get() was unblocked by close() rather than put(). 
+ if self.closed: + raise EOFError("stream of frames ended") + + assert self.message_complete.done() + self.message_complete = self.loop.create_future() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + + self.message_fetched.set_result(None) + + self.chunks = [] + assert self.chunks_queue is None + + return message + + async def get_iter(self) -> Iterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + chunks = self.chunks + self.chunks = [] + self.chunks_queue: asyncio.Queue[Optional[Data]] = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.done(): + self.chunks_queue.put_nowait(None) + + # Locking with get_in_progress ensures only one coroutine can get here. 
+ self.get_in_progress = True + try: + for chunk in chunks: + yield chunk + while (chunk := await self.chunks_queue.get()) is not None: + yield chunk + finally: + self.get_in_progress = False + + assert self.message_complete.done() + self.message_complete = self.loop.create_future() + + # get_iter() was unblocked by close() rather than put(). + if self.closed: + raise EOFError("stream of frames ended") + + self.message_fetched.set_result(None) + + assert self.chunks == [] + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + When ``frame`` is the final frame in a message, :meth:`put` waits until + the message is fetched, either by calling :meth:`get` or by fully + consuming the return value of :meth:`get_iter`. + + :meth:`put` assumes that the stream of frames respects the protocol. If + it doesn't, the behavior is undefined. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`put` concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.put_in_progress: + raise RuntimeError("put is already running") + + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + + if self.chunks_queue is None: + self.chunks.append(data) + else: + self.chunks_queue.put_nowait(data) + + if not frame.fin: + return + + # Message is complete. Wait until it's fetched to return. + + self.message_complete.set_result(None) + + if self.chunks_queue is not None: + self.chunks_queue.put_nowait(None) + + # Yield control to allow get() to run and eventually set the future. + # Locking with put_in_progress ensures only one coroutine can get here. 
+ self.put_in_progress = True + try: + assert not self.message_fetched.done() + await self.message_fetched + finally: + self.put_in_progress = False + + assert self.message_fetched.done() + self.message_fetched = self.loop.create_future() + + # put() was unblocked by close() rather than get() or get_iter(). + if self.closed: + raise EOFError("stream of frames ended") + + self.decoder = None + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get or get_iter. + if self.get_in_progress: + self.message_complete.set_result(None) + if self.chunks_queue is not None: + self.chunks_queue.put_nowait(None) + + # Unblock put(). + if self.put_in_progress: + self.message_fetched.set_result(None) diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py new file mode 100644 index 000000000..c2723c27a --- /dev/null +++ b/tests/asyncio/test_messages.py @@ -0,0 +1,391 @@ +import asyncio +import unittest + +from websockets.asyncio.messages import * +from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame + + +class AssemblerTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.assembler = Assembler() + + def tearDown(self): + """ + Check that the assembler goes back to its default state after each test. + + This removes the need for testing various sequences. 
# tests/asyncio/test_messages.py
#
# Extracted from the patch "Add asyncio message reassembler."
# The module header, setUp() and the start of tearDown() are reproduced here so
# that the test class is complete.

import asyncio
import unittest

from websockets.asyncio.messages import *
from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame


class AssemblerTests(unittest.IsolatedAsyncioTestCase):
    """Tests for Assembler: get, get_iter, control frames, locking, close."""

    def setUp(self):
        # Each test gets a fresh assembler.
        self.assembler = Assembler()

    def tearDown(self):
        """
        Check that the assembler goes back to its default state after each test.

        This removes the need for testing various sequences.

        """
        self.assertFalse(self.assembler.get_in_progress)
        self.assertFalse(self.assembler.put_in_progress)
        # The remaining invariants are only checked when the test did not
        # close the assembler.
        if not self.assembler.closed:
            self.assertFalse(self.assembler.message_complete.done())
            self.assertFalse(self.assembler.message_fetched.done())
            self.assertIsNone(self.assembler.decoder)
            self.assertEqual(self.assembler.chunks, [])
            self.assertIsNone(self.assembler.chunks_queue)

    # Test get

    # In the "already received" tests, put() runs in a background task
    # because it only returns once the message has been fetched.

    async def test_get_text_message_already_received(self):
        """get returns a text message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

        asyncio.create_task(putter())
        message = await self.assembler.get()

        self.assertEqual(message, "café")

    async def test_get_binary_message_already_received(self):
        """get returns a binary message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_BINARY, b"tea"))

        asyncio.create_task(putter())
        message = await self.assembler.get()

        self.assertEqual(message, b"tea")

    # In the "not received yet" tests, get() runs in a background task and
    # put() is awaited in the test; when put() returns, the message has been
    # fetched, so the getter task's result is read without awaiting it.

    async def test_get_text_message_not_received_yet(self):
        """get returns a text message when it is received."""
        getter_task = asyncio.create_task(self.assembler.get())
        await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
        message = getter_task.result()

        self.assertEqual(message, "café")

    async def test_get_binary_message_not_received_yet(self):
        """get returns a binary message when it is received."""
        getter_task = asyncio.create_task(self.assembler.get())
        await self.assembler.put(Frame(OP_BINARY, b"tea"))
        message = getter_task.result()

        self.assertEqual(message, b"tea")

    async def test_get_fragmented_text_message_already_received(self):
        """get reassembles a fragmented text message that is already received."""

        async def putter():
            # The second frame ends in the middle of the UTF-8 encoding of "é".
            await self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"\xa9"))

        asyncio.create_task(putter())
        message = await self.assembler.get()

        self.assertEqual(message, "café")

    async def test_get_fragmented_binary_message_already_received(self):
        """get reassembles a fragmented binary message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"a"))

        asyncio.create_task(putter())
        message = await self.assembler.get()

        self.assertEqual(message, b"tea")

    async def test_get_fragmented_text_message_being_received(self):
        """get reassembles a fragmented text message that is partially received."""
        await self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
        getter_task = asyncio.create_task(self.assembler.get())
        await self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
        await self.assembler.put(Frame(OP_CONT, b"\xa9"))
        message = getter_task.result()

        self.assertEqual(message, "café")

    async def test_get_fragmented_binary_message_being_received(self):
        """get reassembles a fragmented binary message that is partially received."""
        await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
        getter_task = asyncio.create_task(self.assembler.get())
        await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
        await self.assembler.put(Frame(OP_CONT, b"a"))
        message = getter_task.result()

        self.assertEqual(message, b"tea")

    async def test_get_fragmented_text_message_not_received_yet(self):
        """get reassembles a fragmented text message when it is received."""
        getter_task = asyncio.create_task(self.assembler.get())
        await self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
        await self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
        await self.assembler.put(Frame(OP_CONT, b"\xa9"))
        message = getter_task.result()

        self.assertEqual(message, "café")

    async def test_get_fragmented_binary_message_not_received_yet(self):
        """get reassembles a fragmented binary message when it is received."""
        getter_task = asyncio.create_task(self.assembler.get())
        await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
        await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
        await self.assembler.put(Frame(OP_CONT, b"a"))
        message = getter_task.result()

        self.assertEqual(message, b"tea")

    # Test get_iter

    async def run_get_iter(self):
        # Helper: consume one message with get_iter() and record every
        # fragment in self.fragments so tests can inspect partial progress.
        self.fragments = []
        async for fragment in self.assembler.get_iter():
            self.fragments.append(fragment)

    async def test_get_iter_text_message_already_received(self):
        """get_iter yields a text message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

        asyncio.create_task(putter())
        await self.run_get_iter()

        self.assertEqual(self.fragments, ["café"])

    async def test_get_iter_binary_message_already_received(self):
        """get_iter yields a binary message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_BINARY, b"tea"))

        asyncio.create_task(putter())
        await self.run_get_iter()

        self.assertEqual(self.fragments, [b"tea"])

    # In the tests below, asyncio.sleep(0) yields to the event loop so that
    # the background run_get_iter() task can make progress.

    async def test_get_iter_text_message_not_received_yet(self):
        """get_iter yields a text message when it is received."""
        asyncio.create_task(self.run_get_iter())
        await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
        await asyncio.sleep(0)

        self.assertEqual(self.fragments, ["café"])

    async def test_get_iter_binary_message_not_received_yet(self):
        """get_iter yields a binary message when it is received."""
        asyncio.create_task(self.run_get_iter())
        await self.assembler.put(Frame(OP_BINARY, b"tea"))
        await asyncio.sleep(0)

        self.assertEqual(self.fragments, [b"tea"])

    async def test_get_iter_fragmented_text_message_already_received(self):
        """get_iter yields a fragmented text message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"\xa9"))

        asyncio.create_task(putter())
        await self.run_get_iter()

        # The lone b"\xc3" byte is held back by the incremental decoder and
        # comes out with the next frame as "é".
        self.assertEqual(self.fragments, ["ca", "f", "é"])

    async def test_get_iter_fragmented_binary_message_already_received(self):
        """get_iter yields a fragmented binary message that is already received."""

        async def putter():
            await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
            await self.assembler.put(Frame(OP_CONT, b"a"))

        asyncio.create_task(putter())
        await self.run_get_iter()

        self.assertEqual(self.fragments, [b"t", b"e", b"a"])

    async def test_get_iter_fragmented_text_message_being_received(self):
        """get_iter yields a fragmented text message that is partially received."""
        await self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
        asyncio.create_task(self.run_get_iter())
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, ["ca"])
        await self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, ["ca", "f"])
        await self.assembler.put(Frame(OP_CONT, b"\xa9"))

        self.assertEqual(self.fragments, ["ca", "f", "é"])

    async def test_get_iter_fragmented_binary_message_being_received(self):
        """get_iter yields a fragmented binary message that is partially received."""
        await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
        asyncio.create_task(self.run_get_iter())
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, [b"t"])
        await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, [b"t", b"e"])
        await self.assembler.put(Frame(OP_CONT, b"a"))

        self.assertEqual(self.fragments, [b"t", b"e", b"a"])

    async def test_get_iter_fragmented_text_message_not_received_yet(self):
        """get_iter yields a fragmented text message when it is received."""
        asyncio.create_task(self.run_get_iter())
        await asyncio.sleep(0)
        await self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, ["ca"])
        await self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, ["ca", "f"])
        await self.assembler.put(Frame(OP_CONT, b"\xa9"))

        self.assertEqual(self.fragments, ["ca", "f", "é"])

    async def test_get_iter_fragmented_binary_message_not_received_yet(self):
        """get_iter yields a fragmented binary message when it is received."""
        asyncio.create_task(self.run_get_iter())
        await asyncio.sleep(0)
        await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, [b"t"])
        await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
        await asyncio.sleep(0)
        self.assertEqual(self.fragments, [b"t", b"e"])
        await self.assembler.put(Frame(OP_CONT, b"a"))

        self.assertEqual(self.fragments, [b"t", b"e", b"a"])

    # NOTE(review): the timeout tests below are disabled. As written they
    # cannot run: ``MS`` is not defined or imported in this module, and the
    # calls to ``self.assembler.get(MS)`` are not awaited although get() is a
    # coroutine function. The timeout path of get() is therefore untested.

    # # Test timeouts

    # async def test_get_with_timeout_completes(self):
    #     """get returns a message when it is received before the timeout."""

    #     async def putter():
    #         await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

    #     asyncio.create_task(putter())
    #     message = self.assembler.get(MS)

    #     self.assertEqual(message, "café")

    # async def test_get_with_timeout_times_out(self):
    #     """get raises TimeoutError when no message is received before the timeout."""
    #     with self.assertRaises(TimeoutError):
    #         self.assembler.get(MS)

    # Test control frames

    async def test_control_frame_before_message_is_ignored(self):
        """get ignores control frames between messages."""

        async def putter():
            await self.assembler.put(Frame(OP_PING, b""))
            await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

        asyncio.create_task(putter())
        message = await self.assembler.get()

        self.assertEqual(message, "café")

    async def test_control_frame_in_fragmented_message_is_ignored(self):
        """get ignores control frames within fragmented messages."""

        async def putter():
            await self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
            await self.assembler.put(Frame(OP_PING, b""))
            await self.assembler.put(Frame(OP_CONT, b"e", fin=False))
            await self.assembler.put(Frame(OP_PONG, b""))
            await self.assembler.put(Frame(OP_CONT, b"a"))

        asyncio.create_task(putter())
        message = await self.assembler.get()

        self.assertEqual(message, b"tea")

    # Test concurrency

    # Each test starts a first reader (or writer) in a background task, lets
    # it reach its waiting point with sleep(0), checks that a second call is
    # rejected, then completes a message so the background task can finish
    # and tearDown() finds the assembler in its default state.

    async def test_get_fails_when_get_is_running(self):
        """get cannot be called concurrently with itself."""
        asyncio.create_task(self.assembler.get())
        await asyncio.sleep(0)
        with self.assertRaises(RuntimeError):
            await self.assembler.get()
        await self.assembler.put(Frame(OP_TEXT, b""))  # unlock other coroutine

    async def test_get_fails_when_get_iter_is_running(self):
        """get cannot be called concurrently with get_iter."""
        asyncio.create_task(self.run_get_iter())
        await asyncio.sleep(0)
        with self.assertRaises(RuntimeError):
            await self.assembler.get()
        await self.assembler.put(Frame(OP_TEXT, b""))  # unlock other coroutine

    async def test_get_iter_fails_when_get_is_running(self):
        """get_iter cannot be called concurrently with get."""
        asyncio.create_task(self.assembler.get())
        await asyncio.sleep(0)
        with self.assertRaises(RuntimeError):
            await self.run_get_iter()
        await self.assembler.put(Frame(OP_TEXT, b""))  # unlock other coroutine

    async def test_get_iter_fails_when_get_iter_is_running(self):
        """get_iter cannot be called concurrently with itself."""
        asyncio.create_task(self.run_get_iter())
        await asyncio.sleep(0)
        with self.assertRaises(RuntimeError):
            await self.run_get_iter()
        await self.assembler.put(Frame(OP_TEXT, b""))  # unlock other coroutine

    async def test_put_fails_when_put_is_running(self):
        """put cannot be called concurrently with itself."""
        asyncio.create_task(self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")))
        await asyncio.sleep(0)
        with self.assertRaises(RuntimeError):
            await self.assembler.put(Frame(OP_BINARY, b"tea"))
        await self.assembler.get()  # unblock other coroutine

    # Test termination

    # In the "interrupted by close" tests, close() is scheduled with
    # call_soon() so that it runs while the awaited call is waiting.

    async def test_get_fails_when_interrupted_by_close(self):
        """get raises EOFError when close is called."""
        asyncio.get_event_loop().call_soon(self.assembler.close)
        with self.assertRaises(EOFError):
            await self.assembler.get()

    async def test_get_iter_fails_when_interrupted_by_close(self):
        """get_iter raises EOFError when close is called."""
        asyncio.get_event_loop().call_soon(self.assembler.close)
        with self.assertRaises(EOFError):
            async for _ in self.assembler.get_iter():
                self.fail("no fragment expected")

    async def test_put_fails_when_interrupted_by_close(self):
        """put raises EOFError when close is called."""
        asyncio.get_event_loop().call_soon(self.assembler.close)
        with self.assertRaises(EOFError):
            await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

    async def test_get_fails_after_close(self):
        """get raises EOFError after close is called."""
        self.assembler.close()
        with self.assertRaises(EOFError):
            await self.assembler.get()

    async def test_get_iter_fails_after_close(self):
        """get_iter raises EOFError after close is called."""
        self.assembler.close()
        with self.assertRaises(EOFError):
            async for _ in self.assembler.get_iter():
                self.fail("no fragment expected")

    async def test_put_fails_after_close(self):
        """put raises EOFError after close is called."""
        self.assembler.close()
        with self.assertRaises(EOFError):
            await self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))

    async def test_close_is_idempotent(self):
        """close can be called multiple times safely."""
        self.assembler.close()
        self.assembler.close()