diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 51c66af..22deaad 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -20,6 +20,9 @@ jobs: run: pip install -U setuptools wheel - name: install run: pip install .[dev,ci] + - name: install async requirements + if: matrix.python != '2.7' + run: pip install trio curio - name: test run: python -m pytest --reruns 5 tests/ --cov oscpy/ --cov-branch - name: coveralls diff --git a/oscpy/server/__init__.py b/oscpy/server/__init__.py index dce997a..6186db3 100644 --- a/oscpy/server/__init__.py +++ b/oscpy/server/__init__.py @@ -9,6 +9,7 @@ from time import time from functools import partial import socket +from select import select from oscpy import __version__ from oscpy.parser import read_packet, UNICODE @@ -263,13 +264,22 @@ def close(self, sock=None): elif not sock: raise RuntimeError('no default socket yet and no socket provided') + if sock == self.default_socket: + self.default_socket = None + + if sock not in self.sockets: + return + + self.sockets.remove(sock) + read = select([sock], [], [], 0) if platform != 'win32' and sock.family == socket.AF_UNIX: + print(sock.getsockname()) os.unlink(sock.getsockname()) else: sock.close() - if sock == self.default_socket: - self.default_socket = None + if sock in read: + sock.recvfrom(UDP_MAX_SIZE) def getaddress(self, sock=None): """Wrap call to getsockname. diff --git a/oscpy/server/curio_server.py b/oscpy/server/curio_server.py index 3ed7646..1a6e4b3 100644 --- a/oscpy/server/curio_server.py +++ b/oscpy/server/curio_server.py @@ -1,5 +1,7 @@ import logging from typing import Awaitable +from sys import platform +import os from curio import TaskGroup, socket from oscpy.server import OSCBaseServer, UDP_MAX_SIZE @@ -61,6 +63,24 @@ async def process(self): for s in self.sockets: await g.spawn(self._listen, s) + async def close(self, sock=None): + """Close a socket opened by the server.""" + if not sock and self.default_socket: + sock = self.default_socket + elif not sock: + raise RuntimeError('no default socket yet and no socket provided') + + if sock not in self.sockets: + logger.warning("Ignoring requested to close an unknown socket %s" % sock) + + if sock == self.default_socket: + self.default_socket = None + + if platform != 'win32' and sock.family == socket.AF_UNIX: + os.unlink(sock.getsockname()) + else: + await sock.close() + async def stop_all(self): await self.tasks_group.cancel_remaining() diff --git a/oscpy/server/thread_server.py b/oscpy/server/thread_server.py index a19e152..7e806f7 100644 --- a/oscpy/server/thread_server.py +++ b/oscpy/server/thread_server.py @@ -37,13 +37,7 @@ def stop(self, s=None): s = self.default_socket if s in self.sockets: - read = select([s], [], [], 0) - s.close() - if s in read: - s.recvfrom(UDP_MAX_SIZE) - self.sockets.remove(s) - if s is self.default_socket: - self.default_socket = None + self.close(s) else: raise RuntimeError('{} is not one of my sockets!'.format(s)) diff --git a/oscpy/server/trio_server.py b/oscpy/server/trio_server.py index c2c3e1e..d21eeb3 100644 --- a/oscpy/server/trio_server.py +++ b/oscpy/server/trio_server.py @@ -1,7 +1,10 @@ +import os import logging from functools import partial +from sys import platform +from typing import Awaitable -from trio import socket, open_nursery +from trio import socket, open_nursery, move_on_after from oscpy.server import OSCBaseServer, UDP_MAX_SIZE logging.basicConfig() @@ -33,24 +36,34 @@ async def listen( "Unknown socket family, accepted values are 'unix' and 'inet'" ) - sock = await self.get_socket(family_, (address, port)) + if family == 'unix': + addr = address + else: + addr = (address, port) + sock = await self.get_socket(family_, addr) self.add_socket(sock, default) return sock async def _listen(self, sock): async with open_nursery() as nursery: self.nurseries[sock] = nursery - while True: - data, addr = await sock.recvfrom(UDP_MAX_SIZE) - nursery.start_soon( - partial( - self.handle_message, - data, - addr, - drop_late=False, - sender_socket=sock + try: + while True: + data, addr = await sock.recvfrom(UDP_MAX_SIZE) + nursery.start_soon( + partial( + self.handle_message, + data, + addr, + drop_late=False, + sender_socket=sock + ) ) - ) + finally: + with move_on_after(1) as cleanup_scope: + cleanup_scope.shield = True + logger.info("socket %s cancelled", sock) + await self.stop(sock) async def handle_message(self, data, sender, drop_late, sender_socket): for callbacks, values, address in self.callbacks(data, sender, sender_socket): @@ -60,13 +73,17 @@ async def _execute_callbacks(self, callbacks_list, address, values): for cb, get_address in callbacks_list: try: if get_address: - await cb(address, *values) + result = cb(address, *values) else: - await cb(*values) + result = cb(*values) + if isinstance(result, Awaitable): + await result + except Exception: if self.intercept_errors: - logger.error("Unhandled exception caught in oscpy server", exc_info=True) + logger.error("Ignoring unhandled exception caught in oscpy server", exc_info=True) else: + logger.exception("Unhandled exception caught in oscpy server") raise async def process(self): @@ -80,9 +97,41 @@ async def stop_all(self): """ self.nursery.cancel_scope.deadline = 0 - async def stop(self, sock): - nursery = self.nurseries.pop(sock) - nursery.cancel_scope.deadline = 0 + async def stop(self, sock=None): + if sock is None: + if self.default_socket: + sock = self.default_socket + else: + raise RuntimeError('no default socket yet and no socket provided') + if sock in self.sockets: + self.sockets.remove(sock) + else: + raise RuntimeError("Socket %s is not managed by this server" % sock) + sock.close() + if sock in self.nurseries: + nursery = self.nurseries.pop(sock) + nursery.cancel_scope.deadline = 0 + + if sock is self.default_socket: + self.default_socket = None + + async def close(self, sock=None): + """Close a socket opened by the server.""" + if not sock and self.default_socket: + sock = self.default_socket + elif not sock: + raise RuntimeError('no default socket yet and no socket provided') + + if sock not in self.sockets: + logger.warning("Ignoring requested to close an unknown socket %s" % sock) + + if sock == self.default_socket: + self.default_socket = None + + if platform != 'win32' and sock.family == socket.AF_UNIX: + os.unlink(sock.getsockname()) + else: + sock.close() def getaddress(self, sock=None): """Wrap call to getsockname. diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_server.py b/tests/test_server.py index dfac5d5..f7e8985 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -12,7 +12,7 @@ from oscpy.client import send_message, send_bundle, OSCClient from oscpy import __version__ -from utils import runner, _await, _callback +from tests.utils import runner, _await, _callback if version_info > (3, 5, 0): from oscpy.server.curio_server import OSCCurioServer @@ -20,12 +20,20 @@ from oscpy.server.asyncio_server import OSCAsyncioServer server_classes = { OSCThreadServer, - # OSCTrioServer, + OSCTrioServer, OSCAsyncioServer, OSCCurioServer, } else: - server_classes = [OSCThreadServer] + # so we can refer to them safely + OSCTrioServer = OSCAsyncioServer = OSCCurioServer = None + server_classes = {OSCThreadServer} + + +# # force a one second interval between each test to avoid messages hitting the +# # wrong server +# def teardown_function(function): +# sleep(1) @pytest.mark.parametrize("cls", server_classes) @@ -34,10 +42,11 @@ def test_instance(cls): @pytest.mark.parametrize("cls", server_classes) -def test_listen(cls): +def test_listen_simple(cls): osc = cls() - sock = osc.listen() + sock = _await(osc.listen, osc) runner(osc, timeout=1, socket=sock) + _await(osc.close, osc, (sock,)) @pytest.mark.parametrize("cls", server_classes) @@ -63,18 +72,15 @@ def test_listen_default(cls): # osc.listen(default=True) _await(osc.listen, osc, kwargs=dict(default=True)) - osc.close(sock) + _await(osc.close, osc, (sock,)) _await(osc.listen, osc, kwargs=dict(default=True)) @pytest.mark.parametrize("cls", server_classes) def test_close(cls): osc = cls() - _await(osc.listen, osc, kwargs=dict(default=True)) - - osc.close() - with pytest.raises(RuntimeError) as e_info: # noqa - osc.close() + sock = _await(osc.listen, osc, kwargs=dict(default=True)) + _await(osc.close, osc, (sock,)) @pytest.mark.skipif(platform == 'win32', reason="unix sockets not available on windows") @@ -82,10 +88,9 @@ def test_close(cls): def test_close_unix(cls): osc = cls() filename = mktemp() - # unix = osc.listen(address=filename, family='unix') unix = _await(osc.listen, osc, kwargs=dict(address=filename, family='unix')) assert exists(filename) - osc.close(unix) + _await(osc.close, osc, (unix,)) assert not exists(filename) @@ -101,7 +106,7 @@ def test_stop_default(cls): osc = cls() _await(osc.listen, osc, kwargs=dict(default=True)) assert len(osc.sockets) == 1 - osc.stop() + _await(osc.stop, osc) assert len(osc.sockets) == 0 @@ -112,12 +117,14 @@ def test_stop_all(cls): host, port = sock.getsockname() sock2 = _await(osc.listen, osc) assert len(osc.sockets) == 2 - osc.stop_all() + runner(osc, timeout=.2) + _await(osc.stop_all, osc) assert len(osc.sockets) == 0 sleep(.1) sock3 = _await(osc.listen, osc, kwargs=dict(default=True)) assert len(osc.sockets) == 1 - osc.stop_all() + runner(osc, timeout=.2) + _await(osc.stop_all, osc) @pytest.mark.parametrize("cls", {OSCThreadServer}) @@ -149,7 +156,7 @@ def broken_callback(*values): raise ValueError("some bad value") osc = cls() - sock = osc.listen() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] osc.bind(b'/broken_callback', broken_callback, sock) osc.bind(b'/success', success, sock) @@ -165,11 +172,14 @@ def broken_callback(*values): assert record.exc_info osc = cls(intercept_errors=False) - sock = osc.listen() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] osc.bind(b'/broken_callback', broken_callback, sock) - send_message(b'/broken_callback', [b'test'], 'localhost', port) - runner(osc, timeout=.2) + try: + send_message(b'/broken_callback', [b'test'], 'localhost', port) + runner(osc, timeout=.2) + except Exception: + pass assert len(caplog.records) == 2, caplog.records # Unchanged @@ -179,7 +189,7 @@ def test_send_bundle_without_socket(cls): with pytest.raises(RuntimeError): osc.send_bundle([], 'localhost', 0) - osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) osc.send_bundle( ( (b'/test', []), @@ -191,7 +201,7 @@ def test_send_bundle_without_socket(cls): @pytest.mark.parametrize("cls", server_classes) def test_bind1(cls): osc = cls() - sock = osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] event = Event() @@ -208,7 +218,7 @@ def success(*values): @pytest.mark.parametrize("cls", server_classes) def test_bind_get_address(cls): osc = cls() - sock = osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] event = Event() @@ -227,7 +237,7 @@ def success(address, *values): @pytest.mark.parametrize("cls", server_classes) def test_bind_get_address_smart(cls): osc = cls(advanced_matching=True) - sock = osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] event = Event() @@ -244,7 +254,7 @@ def success(address, *values): @pytest.mark.parametrize("cls", server_classes) def test_reuse_callback(cls): osc = cls() - sock = osc.listen() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] def success(*values): @@ -260,7 +270,7 @@ def success(*values): @pytest.mark.parametrize("cls", server_classes) def test_unbind(cls): osc = cls() - sock = osc.listen() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] event = Event() @@ -280,7 +290,7 @@ def failure(*values): @pytest.mark.parametrize("cls", server_classes) def test_unbind_default(cls): osc = cls() - sock = osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] event = Event() @@ -299,12 +309,12 @@ def failure(*values): def test_bind_multi(cls): osc = cls() - sock1 = osc.listen() + sock1 = _await(osc.listen, osc) port1 = sock1.getsockname()[1] event1 = Event() osc.bind(b'/success', _callback(osc, lambda *_: event1.set()), sock1) - sock2 = osc.listen() + sock2 = _await(osc.listen, osc) port2 = sock2.getsockname()[1] event2 = Event() osc.bind(b'/success', _callback(osc, lambda *_: event2.set()), sock2) @@ -323,7 +333,7 @@ def test_bind_multi(cls): @pytest.mark.parametrize("cls", server_classes) def test_bind_address(cls): osc = cls() - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) result = [] event = Event() @@ -342,7 +352,7 @@ def success(*args): @pytest.mark.parametrize("cls", server_classes) def test_bind_address_class(cls): osc = cls() - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) @ServerClass class Test(object): @@ -373,7 +383,7 @@ def success(*values): @pytest.mark.parametrize("cls", server_classes) def test_bind_default(cls): osc = cls() - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) port = osc.getaddress()[1] event = Event() @@ -468,7 +478,7 @@ def test_smart_address_cache(cls): @pytest.mark.parametrize("cls", server_classes) def test_advanced_matching(cls): osc = cls(advanced_matching=True) - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) port = osc.getaddress()[1] result = {} event = Event() @@ -653,7 +663,7 @@ def done(*values): @pytest.mark.parametrize("cls", server_classes) def test_decorator(cls): osc = cls() - sock = osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs=dict(default=True)) port = sock.getsockname()[1] event1 = Event() event2 = Event() @@ -678,7 +688,7 @@ def test_answer(cls): event = Event() osc_1 = cls(intercept_errors=False) - osc_1.listen(default=True) + _await(osc_1.listen, osc_1, kwargs=dict(default=True)) @osc_1.address(b'/ping') def ping(*values): @@ -692,14 +702,14 @@ def ping(*values): ) osc_2 = OSCThreadServer(intercept_errors=False) - osc_2.listen(default=True) + _await(osc_2.listen, osc_2, kwargs=dict(default=True)) @osc_2.address(b'/pong') def pong(*values): osc_2.answer(b'/ping', [True]) osc_3 = OSCThreadServer(intercept_errors=False) - osc_3.listen(default=True) + _await(osc_3.listen, osc_3, kwargs=dict(default=True)) @osc_3.address(b'/zap') def zap(*values): @@ -720,26 +730,28 @@ def zap(*values): @pytest.mark.parametrize("cls", server_classes) def test_socket_family(cls): osc = cls() - assert osc.listen().family == socket.AF_INET + sock = _await(osc.listen, osc) + assert sock.family == socket.AF_INET filename = mktemp() if platform != 'win32': - assert osc.listen(address=filename, family='unix').family == socket.AF_UNIX # noqa + sock = _await(osc.listen, osc, kwargs=dict(address=filename, family='unix')) + assert sock.family == socket.AF_UNIX # noqa else: with pytest.raises(AttributeError) as e_info: - osc.listen(address=filename, family='unix') + _await(osc.listen, osc, kwargs=dict(family='unix')) if exists(filename): unlink(filename) with pytest.raises(ValueError) as e_info: # noqa - osc.listen(family='') + _await(osc.listen, osc, kwargs=dict(family='')) @pytest.mark.parametrize("cls", server_classes) def test_encoding_send(cls): osc = cls() - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) values = [] event = Event() @@ -763,7 +775,7 @@ def encoded(*val): @pytest.mark.parametrize("cls", server_classes) def test_encoding_receive(cls): osc = cls(encoding='utf8') - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) values = [] event = Event() @@ -790,7 +802,7 @@ def encoded(*val): @pytest.mark.parametrize("cls", server_classes) def test_encoding_send_receive(cls): osc = cls(encoding='utf8') - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) event = Event() values = [] @@ -821,7 +833,7 @@ def test(address, *values): event.set() osc = cls(default_handler=test) - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) @osc.address(b'/passthrough') def passthrough(*values): @@ -851,7 +863,7 @@ def passthrough(*values): @pytest.mark.parametrize("cls", {OSCThreadServer}) def test_get_version(cls): osc = cls(encoding='utf8') - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) values = [] @@ -876,7 +888,7 @@ def cb(val): @pytest.mark.parametrize("cls", {OSCThreadServer}) def test_get_routes(cls): osc = cls(encoding='utf8') - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) event = Event() values = [] @@ -908,7 +920,7 @@ def cb(*routes): @pytest.mark.parametrize("cls", server_classes) def test_get_sender(cls): osc = cls(encoding='utf8') - osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) event = Event() @osc.address(u'/test_route') @@ -941,7 +953,7 @@ def callback(index): # server, will be tested: osc = cls(encoding='utf8') - sock = osc.listen(address='0.0.0.0', default=True) + sock = _await(osc.listen, osc, kwargs=dict(address='0.0.0.0', default=True)) port = sock.getsockname()[1] osc.bind('/callback', callback) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..2bbc665 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,20 @@ +from sys import version_info +from time import sleep + +if version_info > (3, 5, 0): + from tests.utils_async import _await, runner, _callback +else: + def runner(osc, timeout=5, socket=None): + sleep(timeout) + if socket: + osc.stop(socket) + else: + osc.stop_all() + + def _await(something, osc, args=None, kwargs=None): + args = args or [] + kwargs = kwargs or {} + return something(*args, **kwargs) + + def _callback(osc, function): + return function diff --git a/tests/utils_async.py b/tests/utils_async.py new file mode 100644 index 0000000..cacdc43 --- /dev/null +++ b/tests/utils_async.py @@ -0,0 +1,65 @@ +from functools import partial +from typing import Awaitable +import inspect + +import curio +import trio +import asyncio +from oscpy.server.curio_server import OSCCurioServer +from oscpy.server.trio_server import OSCTrioServer +from oscpy.server.asyncio_server import OSCAsyncioServer +from oscpy.server.thread_server import OSCThreadServer + +def _await(something, osc, args=None, kwargs=None, timeout=1): + args = args or [] + kwargs = kwargs or {} + if isinstance(osc, OSCTrioServer): + return trio.run(partial(something, *args, **kwargs)) + if isinstance(osc, OSCCurioServer): + async def wrapper(): + result = something(*args, **kwargs) + if isinstance(result, Awaitable): + result = await result + return result + return curio.run(wrapper) + else: + return something(*args, **kwargs) + +def handle_exception(loop, context): + # context["message"] will always be there; but context["exception"] may not + msg = context.get("exception", context["message"]) + logging.error(f"Caught exception: {msg}") + logging.info("Shutting down...") + asyncio.create_task(shutdown(loop)) + +async def _trio_with_timout(process, timeout): + with trio.move_on_after(timeout): + await process() + +def runner(osc, timeout=1, socket=None): + if isinstance(osc, OSCThreadServer): + sleep(timeout) + if socket: + osc.stop(socket) + else: + osc.stop_all() + elif isinstance(osc, OSCCurioServer): + try: + curio.run(curio.timeout_after(timeout, osc.process)) + except curio.TaskTimeout: + ... + elif isinstance(osc, OSCTrioServer): + trio.run(lambda: _trio_with_timout(osc.process, timeout)) + elif isinstance(osc, OSCAsyncioServer): + loop = asyncio.get_event_loop() + # loop.set_debug(True) + loop.set_exception_handler(handle_exception) + loop.run_until_complete(osc.process()) + +def _callback(osc, function): + if isinstance(osc, OSCAsyncioServer): + async def _(*args, **kwargs): + return function(*args, **kwargs) + return _ + return function +