diff --git a/src/labone/core/session.py b/src/labone/core/session.py index 7eb1834..47b52e2 100644 --- a/src/labone/core/session.py +++ b/src/labone/core/session.py @@ -31,7 +31,7 @@ request_field_type_description, ) from labone.core.result import unwrap -from labone.core.subscription import DataQueue, streaming_handle_factory +from labone.core.subscription import DataQueue, QueueProtocol, streaming_handle_factory from labone.core.value import AnnotatedValue if t.TYPE_CHECKING: @@ -553,11 +553,33 @@ async def get_with_expression( for raw_result in response.result ] + @t.overload async def subscribe( self, path: LabOneNodePath, + *, parser_callback: t.Callable[[AnnotatedValue], AnnotatedValue] | None = None, + queue_type: None, ) -> DataQueue: + ... + + @t.overload + async def subscribe( + self, + path: LabOneNodePath, + *, + parser_callback: t.Callable[[AnnotatedValue], AnnotatedValue] | None = None, + queue_type: type[QueueProtocol], + ) -> QueueProtocol: + ... + + async def subscribe( + self, + path: LabOneNodePath, + *, + parser_callback: t.Callable[[AnnotatedValue], AnnotatedValue] | None = None, + queue_type: type[QueueProtocol] | None = None, + ) -> QueueProtocol | DataQueue: """Register a new subscription to a node. Registers a new subscription to a node on the kernel/server. All @@ -581,6 +603,10 @@ async def subscribe( parser_callback: Function to bring values obtained from data-queue into desired format. This may involve parsing them or putting them into an enum. + queue_type: The type of the queue to be returned. This can be + any class matching the DataQueue interface. Only needed if the + default DataQueue class is not sufficient. If None is passed + the default DataQueue class is used. (default=None) Returns: An instance of the DataQueue class. This async queue will receive @@ -611,7 +637,8 @@ async def subscribe( request.subscription = subscription response = await _send_and_wait_request(request) unwrap(response.result) # Result(Void, Error) - return DataQueue( + new_queue_type = queue_type or DataQueue + return new_queue_type( path=path, register_function=streaming_handle.register_data_queue, ) diff --git a/src/labone/core/subscription.py b/src/labone/core/subscription.py index 7e53a8c..c785474 100644 --- a/src/labone/core/subscription.py +++ b/src/labone/core/subscription.py @@ -103,7 +103,21 @@ def __repr__(self) -> str: f"connected={self.connected})", ) - def fork(self) -> DataQueue: + @t.overload + def fork(self, queue_type: None) -> DataQueue: + ... + + @t.overload + def fork( + self, + queue_type: type[QueueProtocol], + ) -> QueueProtocol: + ... + + def fork( + self, + queue_type: type[QueueProtocol] | None = None, + ) -> DataQueue | QueueProtocol: """Create a fork of the subscription. The forked subscription will receive all updates that the original @@ -114,6 +128,12 @@ def fork(self) -> DataQueue: Warning: The forked subscription will not contain any values before the fork. + Args: + queue_type: The type of the queue to be returned. This can be + any class matching the DataQueue interface. Only needed if the + default DataQueue class is not sufficient. If None is passed + the default DataQueue class is used. (default=None) + Returns: A new data queue to the same underlying subscription. """ @@ -123,7 +143,8 @@ def fork(self) -> DataQueue: "sense as it would never receive data.", ) raise errors.StreamingError(msg) - return DataQueue( + new_queue_type = queue_type or DataQueue + return new_queue_type( path=self._path, register_function=self._register_function, ) @@ -209,6 +230,89 @@ def maxsize(self, maxsize: int) -> None: self._maxsize = maxsize +QueueProtocol = t.TypeVar("QueueProtocol", bound=DataQueue) + + +class CircularDataQueue(DataQueue): + """Circular data queue. + + This data queue is identical to the DataQueue, with the exception that it + will remove the oldest item from the queue if the queue is full and a new + item is added. + """ + + async def put(self, item: AnnotatedValue) -> None: + """Put an item into the queue. + + If the queue is full the oldest item will be removed and the new item + will be added to the end of the queue. + + Args: + item: The item to the put in the queue. + + Raises: + StreamingError: If the data queue has been disconnected. + """ + if self.full(): + self.get_nowait() + await super().put(item) + + def put_nowait(self, item: AnnotatedValue) -> None: + """Put an item into the queue without blocking. + + If the queue is full the oldest item will be removed and the new item + will be added to the end of the queue. + + Args: + item: The item to the put in the queue. + + Raises: + StreamingError: If the data queue has been disconnected. + """ + if self.full(): + self.get_nowait() + super().put_nowait(item) + + @t.overload + def fork(self, queue_type: None) -> CircularDataQueue: + ... # pragma: no cover + + @t.overload + def fork( + self, + queue_type: type[QueueProtocol], + ) -> QueueProtocol: + ... # pragma: no cover + + def fork( + self, + queue_type: type[QueueProtocol] | None = None, + ) -> CircularDataQueue | QueueProtocol: + """Create a fork of the subscription. + + The forked subscription will receive all updates that the original + subscription receives. Its connection state is independent of the original + subscription, meaning even if the original subscription is disconnected, + the forked subscription will still receive updates. + + Warning: + The forked subscription will not contain any values before the fork. + + Args: + queue_type: The type of the queue to be returned. This can be + any class matching the DataQueue interface. Only needed if the + default DataQueue class is not sufficient. If None is passed + the default DataQueue class is used. (default=None) + + Returns: + A new data queue to the same underlying subscription. + """ + return DataQueue.fork( + self, + queue_type=queue_type if queue_type is not None else CircularDataQueue, + ) + + class StreamingHandle(ABC): """Streaming Handle server. @@ -238,7 +342,10 @@ def __init__( ... @abstractmethod - def register_data_queue(self, data_queue: weakref.ReferenceType[DataQueue]) -> None: + def register_data_queue( + self, + data_queue: weakref.ReferenceType[QueueProtocol], + ) -> None: """Register a new data queue. Args: @@ -303,7 +410,7 @@ def __init__( *, parser_callback: t.Callable[[AnnotatedValue], AnnotatedValue] | None = None, ) -> None: - self._data_queues: list[weakref.ReferenceType[DataQueue]] = [] + self._data_queues = [] # type: ignore[var-annotated] if parser_callback is None: @@ -314,7 +421,7 @@ def parser_callback(x: AnnotatedValue) -> AnnotatedValue: def register_data_queue( self, - data_queue: weakref.ReferenceType[DataQueue], + data_queue: weakref.ReferenceType[QueueProtocol], ) -> None: """Register a new data queue. @@ -326,7 +433,7 @@ def register_data_queue( def _add_to_data_queue( self, - data_queue: DataQueue | None, + data_queue: QueueProtocol | None, value: AnnotatedValue, ) -> bool: """Add a value to the data queue. diff --git a/src/labone/nodetree/helper.py b/src/labone/nodetree/helper.py index dd9f719..3f626bc 100644 --- a/src/labone/nodetree/helper.py +++ b/src/labone/nodetree/helper.py @@ -15,7 +15,7 @@ from typing_extensions import TypeAlias from labone.core.session import NodeInfo - from labone.core.subscription import DataQueue + from labone.core.subscription import QueueProtocol NormalizedPathSegment: TypeAlias = str @@ -111,7 +111,8 @@ async def subscribe( path: LabOneNodePath, *, parser_callback: t.Callable[[AnnotatedValue], AnnotatedValue] | None = None, - ) -> DataQueue: + queue_type: type[QueueProtocol], + ) -> QueueProtocol: """Register a new subscription to a node.""" ... diff --git a/src/labone/nodetree/node.py b/src/labone/nodetree/node.py index 7534271..e37d7b6 100644 --- a/src/labone/nodetree/node.py +++ b/src/labone/nodetree/node.py @@ -15,6 +15,7 @@ from deprecation import deprecated +from labone.core.subscription import DataQueue from labone.core.value import AnnotatedValue, Value from labone.nodetree.errors import ( LabOneInappropriateNodeTypeError, @@ -40,7 +41,7 @@ from labone.core.helper import LabOneNodePath from labone.core.session import NodeInfo as NodeInfoType from labone.core.session import NodeType - from labone.core.subscription import DataQueue + from labone.core.subscription import QueueProtocol from labone.nodetree.enum import NodeEnum T = t.TypeVar("T") @@ -1088,10 +1089,24 @@ async def wait_for_state_change( any value except the passed value. (default = False) Useful when waiting for value to change from existing one. """ - ... # pragma: no cover + ... - @abstractmethod + @t.overload async def subscribe(self) -> DataQueue: + ... + + @t.overload + async def subscribe( + self, + queue_type: type[QueueProtocol], + ) -> QueueProtocol: + ... + + @abstractmethod + async def subscribe( + self, + queue_type: type[QueueProtocol] | None = None, + ) -> QueueProtocol | DataQueue: """Subscribe to a node. Subscribing to a node will cause the server to send updates that happen @@ -1115,6 +1130,12 @@ async def subscribe(self) -> DataQueue: happen when the queue is garbage collected or when the queue is closed manually. + Args: + queue_type: The type of the queue to be returned. This can be + any class matching the DataQueue interface. Only needed if the + default DataQueue class is not sufficient. If None is passed + the default DataQueue class is used. (default=None) + Returns: A DataQueue, which can be used to receive any changes to the node in a flexible manner. @@ -1186,7 +1207,21 @@ def try_generate_subnode( ) raise LabOneInvalidPathError(msg) + @t.overload async def subscribe(self) -> DataQueue: + ... + + @t.overload + async def subscribe( + self, + queue_type: type[QueueProtocol], + ) -> QueueProtocol: + ... + + async def subscribe( + self, + queue_type: type[QueueProtocol] | None = None, + ) -> QueueProtocol | DataQueue: """Subscribe to a node. Subscribing to a node will cause the server to send updates that happen @@ -1207,6 +1242,12 @@ async def subscribe(self) -> DataQueue: happen when the queue is garbage collected or when the queue is closed manually. + Args: + queue_type: The type of the queue to be returned. This can be + any class matching the DataQueue interface. Only needed if the + default DataQueue class is not sufficient. If None is passed + the default DataQueue class is used. (default=None) + Returns: A DataQueue, which can be used to receive any changes to the node in a flexible manner. @@ -1214,6 +1255,7 @@ async def subscribe(self) -> DataQueue: return await self._tree_manager.session.subscribe( self.path, parser_callback=self._tree_manager.parser, + queue_type=queue_type or DataQueue, ) async def wait_for_state_change( @@ -1322,7 +1364,21 @@ def _package_get_response( """ ... + @t.overload async def subscribe(self) -> DataQueue: + ... + + @t.overload + async def subscribe( + self, + queue_type: type[QueueProtocol], + ) -> QueueProtocol: + ... + + async def subscribe( + self, + queue_type: type[QueueProtocol] | None = None, # noqa: ARG002 + ) -> QueueProtocol | DataQueue: """Subscribe to a node. Currently not supported for wildcard and partial nodes. diff --git a/tests/core/test_subscription.py b/tests/core/test_subscription.py index 1aad161..9ee5a8f 100644 --- a/tests/core/test_subscription.py +++ b/tests/core/test_subscription.py @@ -5,6 +5,7 @@ import pytest from labone.core import errors from labone.core.subscription import ( + CircularDataQueue, DataQueue, streaming_handle_factory, ) @@ -146,6 +147,78 @@ async def test_data_queue_get_disconnected_empty(): await queue.get() +@pytest.mark.asyncio() +async def test_circular_data_queue_put_enough_space(): + subscription = FakeSubscription() + queue = CircularDataQueue( + path="dummy", + register_function=subscription.register_data_queue, + ) + queue.maxsize = 2 + await asyncio.wait_for(queue.put("test"), timeout=0.01) + assert queue.qsize() == 1 + assert queue.get_nowait() == "test" + + +@pytest.mark.asyncio() +async def test_circular_data_queue_put_full(): + subscription = FakeSubscription() + queue = CircularDataQueue( + path="dummy", + register_function=subscription.register_data_queue, + ) + queue.maxsize = 2 + await asyncio.wait_for(queue.put("test1"), timeout=0.01) + await asyncio.wait_for(queue.put("test2"), timeout=0.01) + await asyncio.wait_for(queue.put("test3"), timeout=0.01) + assert queue.qsize() == 2 + assert queue.get_nowait() == "test2" + assert queue.get_nowait() == "test3" + + +@pytest.mark.asyncio() +async def test_circular_data_queue_put_no_wait_enough_space(): + subscription = FakeSubscription() + queue = CircularDataQueue( + path="dummy", + register_function=subscription.register_data_queue, + ) + queue.maxsize = 2 + queue.put_nowait("test") + assert queue.qsize() == 1 + assert queue.get_nowait() == "test" + + +@pytest.mark.asyncio() +async def test_circular_data_queue_put_no_wait_full(): + subscription = FakeSubscription() + queue = CircularDataQueue( + path="dummy", + register_function=subscription.register_data_queue, + ) + queue.maxsize = 2 + queue.put_nowait("test1") + queue.put_nowait("test2") + queue.put_nowait("test3") + assert queue.qsize() == 2 + assert queue.get_nowait() == "test2" + assert queue.get_nowait() == "test3" + + +def test_circular_data_queue_fork(): + subscription = FakeSubscription() + queue = CircularDataQueue( + path="dummy", + register_function=subscription.register_data_queue, + ) + assert len(subscription.data_queues) == 1 + forked_queue = queue.fork() + assert isinstance(forked_queue, CircularDataQueue) + assert len(subscription.data_queues) == 2 + assert forked_queue.path == queue.path + assert forked_queue.connected + + def test_streaming_handle_register(reflection_server): streaming_handle_class = streaming_handle_factory(reflection_server) streaming_handle = streaming_handle_class() diff --git a/tests/nodetree/test_node.py b/tests/nodetree/test_node.py index d2d1e7a..1dfc26b 100644 --- a/tests/nodetree/test_node.py +++ b/tests/nodetree/test_node.py @@ -798,6 +798,7 @@ async def test_subscribe(self, mock_path): node._tree_manager.session.subscribe.assert_called_once_with( "path", parser_callback=node._tree_manager.parser, + queue_type=DataQueue, ) @pytest.mark.parametrize(