diff --git a/asgi_correlation_id/middleware.py b/asgi_correlation_id/middleware.py index 1eef03e..72a6c70 100644 --- a/asgi_correlation_id/middleware.py +++ b/asgi_correlation_id/middleware.py @@ -46,7 +46,7 @@ async def __call__(self, scope: 'Scope', receive: 'Receive', send: 'Send') -> No """ Load request ID from headers if present. Generate one otherwise. """ - if scope['type'] != 'http': + if scope['type'] not in ('http', 'websocket'): await self.app(scope, receive, send) return diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 7581995..e7ccc83 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -120,6 +120,9 @@ async def test_websocket_request(caplog, app): @app.websocket_route('/ws') async def websocket(websocket: 'WebSocket'): + # Check we get the right headers back + assert websocket.headers.get('x-request-id') is not None + await websocket.accept() await websocket.send_json({'msg': 'Hello WebSocket'}) await websocket.close()