diff --git a/aiosip/application.py b/aiosip/application.py index c8e6f85..d226cc3 100644 --- a/aiosip/application.py +++ b/aiosip/application.py @@ -167,11 +167,12 @@ async def reply(*args, **kwargs): method=msg.method, from_details=Contact.from_header(msg.headers['To']), to_details=Contact.from_header(msg.headers['From']), - call_id=call_id + call_id=call_id, + inbound=True ) await dialog.reply(*args, **kwargs) - await dialog.close(fast=True) + await dialog.close() try: route = await self.dialplan.resolve( @@ -226,7 +227,12 @@ def finish(self): def register_on_finish(self, func, *args, **kwargs): self._finish_callbacks.insert(0, (func, args, kwargs)) - async def close(self): + async def close(self, timeout=5): + for dialog in set(self._dialogs.values()): + try: + await dialog.close(timeout=timeout) + except asyncio.TimeoutError: + pass for connector in self._connectors.values(): await connector.close() for task in self._tasks: diff --git a/aiosip/dialog.py b/aiosip/dialog.py index 51e8b63..84a43dd 100644 --- a/aiosip/dialog.py +++ b/aiosip/dialog.py @@ -4,6 +4,7 @@ from collections import defaultdict from multidict import CIMultiDict +from async_timeout import timeout as Timeout from . import utils from .message import Request, Response @@ -151,10 +152,19 @@ def _maybe_close(self, msg): def _close(self): LOG.debug('Closing: %s', self) + if self._closing: + self._closing.cancel() + for transactions in self.transactions.values(): for transaction in transactions.values(): transaction.close() + # Should not be necessary once dialog are correctly tracked + try: + del self.app._dialogs[self.dialog_id] + except KeyError as e: + pass + def _connection_lost(self): for transactions in self.transactions.values(): for transaction in transactions.values(): @@ -187,10 +197,11 @@ def end_transaction(self, transaction): for item in to_delete: del self.transactions[item[0]][item[1]] - async def request(self, method, contact_details=None, headers=None, payload=None): + async def request(self, method, contact_details=None, headers=None, payload=None, timeout=None): msg = self._prepare_request(method, contact_details, headers, payload) if msg.method != 'ACK': - return await self.start_unreliable_transaction(msg) + async with Timeout(timeout): + return await self.start_unreliable_transaction(msg) else: self.peer.send_message(msg) @@ -274,7 +285,7 @@ async def refresh(self, headers=None, expires=1800, *args, **kwargs): headers['Expires'] = int(expires) return await self.request(self.original_msg.method, headers=headers, *args, **kwargs) - async def close(self, fast=False, headers=None, *args, **kwargs): + async def close(self, headers=None, *args, **kwargs): if not self._closed: self._closed = True result = None @@ -282,7 +293,11 @@ async def close(self, fast=False, headers=None, *args, **kwargs): headers = CIMultiDict(headers or {}) if 'Expires' not in headers: headers['Expires'] = 0 - result = await self.request(self.original_msg.method, headers=headers, *args, **kwargs) + try: + result = await self.request(self.original_msg.method, headers=headers, *args, **kwargs) + finally: + self._close() + self._close() return result @@ -409,7 +424,7 @@ def end_transaction(self, transaction): for item in to_delete: del self.transactions[item[0]][item[1]] - async def close(self): + async def close(self, timeout=None): if not self._closed: self._closed = True @@ -422,9 +437,11 @@ async def close(self): if msg: transaction = UnreliableTransaction(self, original_msg=msg, loop=self.app.loop) self.transactions[msg.method][msg.cseq] = transaction - await transaction.start() - self._close() + try: + async with Timeout(timeout): + await transaction.start() + finally: + self._close() - def _close(self): - pass + self._close() diff --git a/tests/conftest.py b/tests/conftest.py index 6f0e3e1..da91bf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,7 @@ -import asyncio - import aiosip import pytest - +import asyncio +import itertools pytest_plugins = ['aiosip.pytest_plugin'] @@ -119,3 +118,8 @@ def to_details(request): @pytest.fixture def loop(event_loop): return event_loop + + +@pytest.fixture(params=itertools.permutations(('client', 'server'))) +def close_order(request): + return request.param diff --git a/tests/test_sip_proxy.py b/tests/test_sip_proxy.py index 6398254..626806a 100644 --- a/tests/test_sip_proxy.py +++ b/tests/test_sip_proxy.py @@ -1,8 +1,11 @@ -import asyncio import aiosip +import pytest +import asyncio +import itertools -async def test_proxy_subscribe(test_server, test_proxy, protocol, loop, from_details, to_details): +@pytest.mark.parametrize('close_order', itertools.permutations(('client', 'server', 'proxy'))) +async def test_proxy_subscribe(test_server, test_proxy, protocol, loop, from_details, to_details, close_order): callback_complete = loop.create_future() callback_complete_proxy = loop.create_future() @@ -59,12 +62,19 @@ async def proxy_subscribe(self, request, message): assert received_request_server.payload == received_request_proxy.payload assert received_request_server.headers == received_request_proxy.headers - await server_app.close() - await proxy_app.close() - await app.close() + for item in close_order: + if item == 'client': + await app.close() + elif item == 'server': + await server_app.close() + elif item == 'proxy': + await proxy_app.close() + else: + raise ValueError('Invalid close_order') -async def test_proxy_notify(test_server, test_proxy, protocol, loop, from_details, to_details): # noQa: C901 +@pytest.mark.parametrize('close_order', itertools.permutations(('client', 'server', 'proxy'))) +async def test_proxy_notify(test_server, test_proxy, protocol, loop, from_details, to_details, close_order): callback_complete = loop.create_future() callback_complete_proxy = loop.create_future() @@ -140,6 +150,12 @@ async def proxy_subscribe(self, request, message): assert received_notify_server.payload == received_notify_proxy.payload assert received_notify_server.headers == received_notify_proxy.headers - await server_app.close() - await proxy_app.close() - await app.close() + for item in close_order: + if item == 'client': + await app.close() + elif item == 'server': + await server_app.close() + elif item == 'proxy': + await proxy_app.close() + else: + raise ValueError('Invalid close_order') diff --git a/tests/test_sip_scenario.py b/tests/test_sip_scenario.py index fdb1a89..d78ebb4 100644 --- a/tests/test_sip_scenario.py +++ b/tests/test_sip_scenario.py @@ -2,8 +2,10 @@ import pytest import asyncio +from async_timeout import timeout -async def test_notify(test_server, protocol, loop, from_details, to_details): + +async def test_notify(test_server, protocol, loop, from_details, to_details, close_order): notify_list = [0, 1, 2, 3, 4] subscribe_future = loop.create_future() @@ -14,13 +16,19 @@ async def resolve(self, *args, **kwargs): return self.subscribe async def subscribe(self, request, msg): - dialog = await request.prepare(status_code=200) + expires = int(msg.headers['Expires']) + dialog = await request.prepare(status_code=200, headers={'Expires': expires}) await asyncio.sleep(0.1) for i in notify_list: await dialog.notify(payload=str(i)) subscribe_future.set_result(None) + async for msg in dialog: + if msg.method == 'SUBSCRIBE': + expires = int(msg.headers['Expires']) + await dialog.reply(msg, status_code=200, headers={'Expires': expires}) + app = aiosip.Application(loop=loop) server_app = aiosip.Application(loop=loop, dialplan=Dialplan()) server = await test_server(server_app) @@ -43,11 +51,15 @@ async def subscribe(self, request, msg): await subscribe_future - await server_app.close() - await app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() -async def test_authentication(test_server, protocol, loop, from_details, to_details): +async def test_authentication(test_server, protocol, loop, from_details, to_details, close_order): password = 'abcdefg' received_messages = list() @@ -66,8 +78,10 @@ async def subscribe(self, request, message): async for message in dialog: received_messages.append(message) - assert dialog.validate_auth(message, password) - await dialog.reply(message, 200) + if dialog.validate_auth(message, password): + await dialog.reply(message, 200) + else: + await dialog.unauthorized(message) app = aiosip.Application(loop=loop) server_app = aiosip.Application(loop=loop, dialplan=Dialplan()) @@ -88,11 +102,15 @@ async def subscribe(self, request, message): assert len(received_messages) == 2 assert 'Authorization' in received_messages[1].headers - await server_app.close() - await app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() -async def test_authentication_rejection(test_server, protocol, loop, from_details, to_details): +async def test_authentication_rejection(test_server, protocol, loop, from_details, to_details, close_order): received_messages = list() class Dialplan(aiosip.BaseDialplan): @@ -135,11 +153,15 @@ async def subscribe(self, request, message): assert len(received_messages) == 2 assert result.status_code == 401 - await server_app.close() - await app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() -async def test_invite(test_server, protocol, loop, from_details, to_details): +async def test_invite(test_server, protocol, loop, from_details, to_details, close_order): call_established = loop.create_future() call_disconnected = loop.create_future() @@ -189,11 +211,15 @@ async def invite(self, request, message): assert responses == [100, 180, 200] - await app.close() - await server_app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() -async def test_cancel(test_server, protocol, loop, from_details, to_details): +async def test_cancel(test_server, protocol, loop, from_details, to_details, close_order): cancel_future = loop.create_future() class Dialplan(aiosip.BaseDialplan): @@ -232,5 +258,9 @@ async def cancel(self, request, message): result = await cancel_future assert result.method == 'CANCEL' - await app.close() - await server_app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() diff --git a/tests/test_sip_server.py b/tests/test_sip_server.py index 3f854f1..aa3c3fb 100644 --- a/tests/test_sip_server.py +++ b/tests/test_sip_server.py @@ -1,7 +1,7 @@ import aiosip -async def test_subscribe(test_server, protocol, loop, from_details, to_details): +async def test_subscribe(test_server, protocol, loop, from_details, to_details, close_order): callback_complete = loop.create_future() class Dialplan(aiosip.BaseDialplan): @@ -35,11 +35,15 @@ async def on_subscribe(self, request, message): assert received_request.method == 'SUBSCRIBE' - await server_app.close() - await app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() -async def test_response_501(test_server, protocol, loop, from_details, to_details): +async def test_response_501(test_server, protocol, loop, from_details, to_details, close_order): app = aiosip.Application(loop=loop) server_app = aiosip.Application(loop=loop) server = await test_server(server_app) @@ -56,11 +60,15 @@ async def test_response_501(test_server, protocol, loop, from_details, to_detail assert subscription.status_code == 501 assert subscription.status_message == 'Not Implemented' - await server_app.close() - await app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close() -async def test_exception_in_handler(test_server, protocol, loop, from_details, to_details): +async def test_exception_in_handler(test_server, protocol, loop, from_details, to_details, close_order): class Dialplan(aiosip.BaseDialplan): @@ -90,5 +98,9 @@ async def on_subscribe(self, request, message): assert subscription.status_code == 500 assert subscription.status_message == 'Server Internal Error' - await server_app.close() - await app.close() + if close_order[0] == 'client': + await app.close() + await server_app.close() + else: + await server_app.close() + await app.close()