From add86fb0ee75d92f02e7581739a9f4c412f6b421 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:12:28 +0200 Subject: [PATCH] Add asyncio message assembler. --- src/websockets/asyncio/messages.py | 239 ++++++++++++++++++ tests/asyncio/__init__.py | 0 tests/asyncio/test_messages.py | 386 +++++++++++++++++++++++++++++ 3 files changed, 625 insertions(+) create mode 100644 src/websockets/asyncio/messages.py create mode 100644 tests/asyncio/__init__.py create mode 100644 tests/asyncio/test_messages.py diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py new file mode 100644 index 000000000..ae30c5307 --- /dev/null +++ b/src/websockets/asyncio/messages.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +from typing import ( + Any, + AsyncIterator, + Callable, + Generic, + Optional, + TypeVar, +) + +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class SimpleQueue(Generic[T]): + """ + Simplified version of asyncio.Queue. + + Doesn't support maxsize nor concurrent calls to get(). + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_running_loop() + self.get_waiter: Optional[asyncio.Future[None]] = None + self.queue: collections.deque[T] = collections.deque() + + def __len__(self) -> int: + return len(self.queue) + + def put(self, item: T) -> None: + """Put an item into the queue without waiting.""" + self.queue.append(item) + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_result(None) + + async def get(self) -> T: + """Remove and return an item from the queue, waiting if necessary.""" + if not self.queue: + if self.get_waiter is not None: + raise RuntimeError("get is already running") + self.get_waiter = self.loop.create_future() + try: + await self.get_waiter + finally: + self.get_waiter.cancel() + self.get_waiter = None + return self.queue.popleft() + + def abort(self) -> None: + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_exception(EOFError("stream of frames ended")) + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frame and that the stream of + frames respects the protocol. If it doesn't, the behavior is undefined. + + """ + + def __init__( + self, + high: int = 16, + low: int = 4, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming messages. Each item is a queue of frames. + self.frames: SimpleQueue[Frame] = SimpleQueue() + + # We cannot put a hard limit on the size of the queues because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + self.paused = False + self.high = high + self.low = low + self.pause = pause + self.resume = resume + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + async def get(self, decode: Optional[bool] = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + try: + # First frame + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + # Following frames, for fragmented messages + while not frame.fin: + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + finally: + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + async def get_iter(self, decode: Optional[bool] = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + try: + # First frame + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + # Following frames, for fragmented messages + while not frame.fin: + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + finally: + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + if len(self.frames) < self.low and self.paused: + self.paused = False + self.resume() + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + if len(self.frames) >= self.high and not self.paused: + self.paused = True + self.pause() + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get or get_iter. + self.frames.abort() diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py new file mode 100644 index 000000000..b039cdc85 --- /dev/null +++ b/tests/asyncio/test_messages.py @@ -0,0 +1,386 @@ +import asyncio +import unittest +import unittest.mock + +from websockets.asyncio.messages import * +from websockets.asyncio.messages import SimpleQueue +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame + + +class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.queue = SimpleQueue() + + async def test_len(self): + """__len__ returns queue length.""" + self.assertEqual(len(self.queue), 0) + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + await self.queue.get() + self.assertEqual(len(self.queue), 0) + + async def test_put_then_get(self): + """get returns an item that is already put.""" + self.queue.put(42) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_get_then_put(self): + """get returns an item when it is put.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.put(42) + item = await getter_task + self.assertEqual(item, 42) + + async def test_get_concurrently(self): + """get cannot be called concurrently with itself.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + with self.assertRaises(RuntimeError): + await self.queue.get() + getter_task.cancel() + + +class AssemblerTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(low=2, high=3, pause=self.pause, resume=self.resume) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # Test get_iter + + async def run_get_iter(self, **kwargs): + self.fragments = [] + async for fragment in self.assembler.get_iter(**kwargs): + self.fragments.append(fragment) + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + await self.run_get_iter() + self.assertEqual(self.fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + await self.run_get_iter() + self.assertEqual(self.fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + asyncio.create_task(self.run_get_iter()) + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + asyncio.create_task(self.run_get_iter()) + self.assembler.put(Frame(OP_BINARY, b"tea")) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await self.run_get_iter() + self.assertEqual(self.fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + await self.run_get_iter() + self.assertEqual(self.fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + asyncio.create_task(self.run_get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["ca"]) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["ca", "f"]) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + asyncio.create_task(self.run_get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"t"]) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"t", b"e"]) + self.assembler.put(Frame(OP_CONT, b"a")) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + asyncio.create_task(self.run_get_iter()) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["ca"]) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["ca", "f"]) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + asyncio.create_task(self.run_get_iter()) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"t"]) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"t", b"e"]) + self.assembler.put(Frame(OP_CONT, b"a")) + await asyncio.sleep(0) # let run_get_iter() run + self.assertEqual(self.fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await self.run_get_iter(decode=False) + self.assertEqual(self.fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + await self.run_get_iter(decode=True) + self.assertEqual(self.fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + getter_iter = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(getter_iter) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(getter_iter) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(getter_iter) + self.resume.assert_called_once_with() + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + asyncio.create_task(self.run_get_iter()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.run_get_iter() + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + asyncio.create_task(self.run_get_iter()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.run_get_iter() + self.assembler.close() # let task terminate