diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 44784ff8..cff92ffa 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -8,6 +8,7 @@ import asyncio import unittest +import weakref from uvloop import _testbase as tb @@ -25,12 +26,11 @@ async def on_request(request): app = aiohttp.web.Application() app.router.add_get('/', on_request) - f = self.loop.create_server( - app.make_handler(), - '0.0.0.0', '0') - srv = self.loop.run_until_complete(f) - - port = srv.sockets[0].getsockname()[1] + runner = aiohttp.web.AppRunner(app) + self.loop.run_until_complete(runner.setup()) + site = aiohttp.web.TCPSite(runner, '0.0.0.0', '0') + self.loop.run_until_complete(site.start()) + port = site._server.sockets[0].getsockname()[1] async def test(): # Make sure we're using the correct event loop. @@ -45,11 +45,61 @@ async def test(): self.assertEqual(result, PAYLOAD) self.loop.run_until_complete(test()) - self.loop.run_until_complete(app.shutdown()) - self.loop.run_until_complete(app.cleanup()) + self.loop.run_until_complete(runner.cleanup()) + + def test_aiohttp_graceful_shutdown(self): + async def websocket_handler(request): + ws = aiohttp.web.WebSocketResponse() + await ws.prepare(request) + request.app['websockets'].add(ws) + try: + async for msg in ws: + await ws.send_str(msg.data) + finally: + request.app['websockets'].discard(ws) + return ws + + async def on_shutdown(app): + for ws in set(app['websockets']): + await ws.close( + code=aiohttp.WSCloseCode.GOING_AWAY, + message='Server shutdown') + + asyncio.set_event_loop(self.loop) + app = aiohttp.web.Application() + app.router.add_get('/', websocket_handler) + app.on_shutdown.append(on_shutdown) + app['websockets'] = weakref.WeakSet() + + runner = aiohttp.web.AppRunner(app) + self.loop.run_until_complete(runner.setup()) + site = aiohttp.web.TCPSite(runner, '0.0.0.0', '0') + self.loop.run_until_complete(site.start()) + port = site._server.sockets[0].getsockname()[1] + + async def client(): + async with aiohttp.ClientSession() as client: + async with client.ws_connect( + 'http://127.0.0.1:{}'.format(port)) as ws: + await ws.send_str("hello") + async for msg in ws: + assert msg.data == "hello" + + client_task = asyncio.ensure_future(client()) + + async def stop(): + await asyncio.sleep(0.1) + try: + await asyncio.wait_for(runner.cleanup(), timeout=0.1) + finally: + try: + client_task.cancel() + await client_task + except asyncio.CancelledError: + pass + + self.loop.run_until_complete(stop()) - srv.close() - self.loop.run_until_complete(srv.wait_closed()) @unittest.skipIf(skip_tests, "no aiohttp module") diff --git a/uvloop/server.pyx b/uvloop/server.pyx index fb4c4691..5e0100af 100644 --- a/uvloop/server.pyx +++ b/uvloop/server.pyx @@ -44,7 +44,10 @@ cdef class Server: @cython.iterable_coroutine async def wait_closed(self): - if self._waiters is None: + # Do not remove `self._servers is None` below + # because close() method only closes server sockets + # and existing client connections are left open. + if self._servers is None or self._waiters is None: return waiter = self._loop._new_future() self._waiters.append(waiter)