diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 25c5d2d9b..8aebea62c 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -1,5 +1,6 @@ +from uvicorn.protocols.http.h11 import H11Protocol +from uvicorn.protocols.http.httptools import HttpToolsProtocol import asyncio -from uvicorn.protocols.http import H11Protocol, HttpToolsProtocol import h11 import pytest @@ -21,7 +22,8 @@ async def __call__(self, receive, send) -> None: "type": "http.response.start", "status": self.status_code, "headers": [ - [key.encode(), value.encode()] for key, value in self.headers.items() + [key.encode(), value.encode()] + for key, value in self.headers.items() ], } ) @@ -64,7 +66,7 @@ def set_content_type(self): b"Content-Type: text/plain", b"Content-Length: 100000", b"", - b'x' * 100000, + b"x" * 100000, ] ) @@ -156,13 +158,14 @@ def test_post_request(protocol_cls): class App: def __init__(self, scope): self.scope = scope + async def __call__(self, receive, send): - body = b'' + body = b"" more_body = True while more_body: message = await receive() - body += message.get('body', b'') - more_body = message.get('more_body', False) + body += message.get("body", b"") + more_body = message.get("more_body", False) response = Response(b"Body: " + body, media_type="text/plain") await response(receive, send) @@ -200,7 +203,9 @@ def app(scope): @pytest.mark.parametrize("protocol_cls", [HttpToolsProtocol, H11Protocol]) def test_chunked_encoding(protocol_cls): def app(scope): - return Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) + return Response( + b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} + ) protocol = get_connected_protocol(app, protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) @@ -274,7 +279,7 @@ def app(scope): def test_invalid_http(protocol_cls): app = lambda scope: None protocol = get_connected_protocol(app, protocol_cls) - protocol.data_received(b'x' * 100000) + protocol.data_received(b"x" * 100000) assert protocol.transport.is_closing() @@ -283,6 +288,7 @@ def test_app_exception(protocol_cls): class App: def __init__(self, scope): self.scope = scope + async def __call__(self, receive, send): raise Exception() @@ -435,9 +441,11 @@ def __init__(self, scope): async def __call__(self, receive, send): nonlocal got_disconnect_event - message = await receive() - while message['type'] != 'http.disconnect': - continue + while True: + message = await receive() + if message["type"] == "http.disconnect": + break + got_disconnect_event = True protocol = get_connected_protocol(App, protocol_cls) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index cc9a36993..0d310a95c 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,3 +1,5 @@ +from uvicorn.protocols.http.h11 import H11Protocol +from uvicorn.protocols.http.httptools import HttpToolsProtocol import asyncio import functools import threading @@ -5,7 +7,6 @@ import pytest import websockets from contextlib import contextmanager -from uvicorn.protocols.http import HttpToolsProtocol, H11Protocol class WebSocketResponse: @@ -229,7 +230,9 @@ def test_subprotocols(protocol_cls, acceptable_subprotocol): class App(WebSocketResponse): async def websocket_connect(self, message): if acceptable_subprotocol in self.scope["subprotocols"]: - await self.send({"type": "websocket.accept", "subprotocol": acceptable_subprotocol}) + await self.send( + {"type": "websocket.accept", "subprotocol": acceptable_subprotocol} + ) else: await self.send({"type": "websocket.close"}) diff --git a/tests/raise_import_error.py b/tests/raise_import_error.py new file mode 100644 index 000000000..24158d1da --- /dev/null +++ b/tests/raise_import_error.py @@ -0,0 +1,5 @@ +# Used by test_importer.py + +myattr = 123 + +import does_not_exist diff --git a/tests/test_auto_detection.py b/tests/test_auto_detection.py new file mode 100644 index 000000000..8b85d006d --- /dev/null +++ b/tests/test_auto_detection.py @@ -0,0 +1,18 @@ +from uvicorn.protocols.http.httptools import HttpToolsProtocol +from uvicorn.protocols.http.auto import AutoHTTPProtocol +from uvicorn.loops.auto import auto_loop_setup +import asyncio +import uvloop + +# TODO: Add pypy to our testing matrix, and assert we get the correct classes +# dependent on the platform we're running the tests under. + + +def test_http_auto(): + protocol = AutoHTTPProtocol(app=None) + assert isinstance(protocol, HttpToolsProtocol) + + +def test_loop_auto(): + loop = auto_loop_setup() + assert isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy) diff --git a/tests/test_importer.py b/tests/test_importer.py new file mode 100644 index 000000000..bec6a768a --- /dev/null +++ b/tests/test_importer.py @@ -0,0 +1,44 @@ +from uvicorn.importer import import_from_string, ImportFromStringError +import pytest +import os +import sys + + +def test_invalid_format(): + with pytest.raises(ImportFromStringError) as exc: + import_from_string("example:") + expected = 'Import string "example:" must be in format ":".' + assert expected in str(exc) + + +def test_invalid_module(): + with pytest.raises(ImportFromStringError) as exc: + import_from_string("module_does_not_exist:myattr") + expected = 'Could not import module "module_does_not_exist".' + assert expected in str(exc) + + +def test_invalid_attr(): + with pytest.raises(ImportFromStringError) as exc: + import_from_string("tempfile:attr_does_not_exist") + expected = 'Attribute "attr_does_not_exist" not found in module "tempfile".' + assert expected in str(exc) + + +def test_internal_import_error(): + with pytest.raises(ImportError) as exc: + import_from_string("tests.raise_import_error:myattr") + + +def test_valid_import(): + instance = import_from_string("tempfile:TemporaryFile") + from tempfile import TemporaryFile + + assert instance == TemporaryFile + + +def test_no_import_needed(): + from tempfile import TemporaryFile + + instance = import_from_string(TemporaryFile) + assert instance == TemporaryFile diff --git a/uvicorn/importer.py b/uvicorn/importer.py new file mode 100644 index 000000000..5a3613002 --- /dev/null +++ b/uvicorn/importer.py @@ -0,0 +1,37 @@ +import importlib + + +class ImportFromStringError(Exception): + pass + + +def import_from_string(import_str): + if not isinstance(import_str, str): + return import_str + + module_str, _, attrs_str = import_str.partition(":") + if not module_str or not attrs_str: + message = ( + 'Import string "{import_str}" must be in format ":".' + ) + raise ImportFromStringError(message.format(import_str=import_str)) + + try: + module = importlib.import_module(module_str) + except ImportError as exc: + if exc.name != module_str: + raise + message = 'Could not import module "{module_str}".' + raise ImportFromStringError(message.format(module_str=module_str)) + + instance = module + try: + for attr_str in attrs_str.split("."): + instance = getattr(instance, attr_str) + except AttributeError: + message = 'Attribute "{attrs_str}" not found in module "{module_str}".' + raise ImportFromStringError( + message.format(attrs_str=attrs_str, module_str=module_str) + ) + + return instance diff --git a/uvicorn/loops/asyncio.py b/uvicorn/loops/asyncio.py new file mode 100644 index 000000000..ba56b252b --- /dev/null +++ b/uvicorn/loops/asyncio.py @@ -0,0 +1,5 @@ +import asyncio + + +def asyncio_setup(): + return asyncio.get_event_loop() diff --git a/uvicorn/loops/auto.py b/uvicorn/loops/auto.py new file mode 100644 index 000000000..41ff9dda0 --- /dev/null +++ b/uvicorn/loops/auto.py @@ -0,0 +1,11 @@ +def auto_loop_setup(): + try: + import uvloop + except ImportError: # pragma: no cover + from uvicorn.loops.asyncio import asyncio_setup + + return asyncio_setup() + else: + from uvicorn.loops.uvloop import uvloop_setup + + return uvloop_setup() diff --git a/uvicorn/loops/uvloop.py b/uvicorn/loops/uvloop.py new file mode 100644 index 000000000..4dd937bc5 --- /dev/null +++ b/uvicorn/loops/uvloop.py @@ -0,0 +1,8 @@ +import asyncio +import uvloop + + +def uvloop_setup(): + asyncio.get_event_loop().close() + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + return asyncio.get_event_loop() diff --git a/uvicorn/main.py b/uvicorn/main.py index 6b31bce6a..8d4b501bc 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -1,8 +1,6 @@ -from uvicorn.protocols.http import H11Protocol, HttpToolsProtocol - +from uvicorn.importer import import_from_string, ImportFromStringError import asyncio import click -import importlib import signal import os import logging @@ -10,9 +8,6 @@ import sys -LOOP_CHOICES = click.Choice(["uvloop", "asyncio"]) -LEVEL_CHOICES = click.Choice(["debug", "info", "warning", "error", "critical"]) -HTTP_CHOICES = click.Choice(["httptools", "h11"]) LOG_LEVELS = { "critical": logging.CRITICAL, "error": logging.ERROR, @@ -20,94 +15,59 @@ "info": logging.INFO, "debug": logging.DEBUG, } -HTTP_PROTOCOLS = {"h11": H11Protocol, "httptools": HttpToolsProtocol} - +HTTP_PROTOCOLS = { + "auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol", + "h11": "uvicorn.protocols.http.h11:H11Protocol", + "httptools": "uvicorn.protocols.http.httptools:HttpToolsProtocol", +} +LOOP_SETUPS = { + "auto": "uvicorn.loops.auto:auto_loop_setup", + "asyncio": "uvicorn.loops.asyncio:asyncio_setup", + "uvloop": "uvicorn.loops.uvloop:uvloop_setup", +} -if platform.python_implementation() == 'PyPy': - DEFAULT_LOOP = 'asyncio' - DEFAULT_PARSER = 'h11' -elif platform.system() == 'Windows' or platform.system().startswith('CYGWIN'): - DEFAULT_LOOP = 'asyncio' - DEFAULT_PARSER = 'h11' -else: - DEFAULT_LOOP = 'uvloop' - DEFAULT_PARSER = 'httptools' +LEVEL_CHOICES = click.Choice(LOG_LEVELS.keys()) +HTTP_CHOICES = click.Choice(HTTP_PROTOCOLS.keys()) +LOOP_CHOICES = click.Choice(LOOP_SETUPS.keys()) @click.command() @click.argument("app") @click.option("--host", type=str, default="127.0.0.1", help="Host") @click.option("--port", type=int, default=8000, help="Port") -@click.option("--loop", type=LOOP_CHOICES, default=DEFAULT_LOOP, help="Event loop") -@click.option("--http", type=HTTP_CHOICES, default=DEFAULT_PARSER, help="HTTP Handler") +@click.option("--loop", type=LOOP_CHOICES, default="auto", help="Event loop") +@click.option("--http", type=HTTP_CHOICES, default="auto", help="HTTP Handler") @click.option("--workers", type=int, default=1, help="Number of worker processes") @click.option("--log-level", type=LEVEL_CHOICES, default="info", help="Log level") def main(app, host: str, port: int, loop: str, http: str, workers: int, log_level: str): - log_level = LOG_LEVELS[log_level] - logging.basicConfig(format="%(levelname)s: %(message)s", level=log_level) - logger = logging.getLogger() - loop = get_event_loop(loop) - sys.path.insert(0, ".") - app = load_app(app) - protocol_class = HTTP_PROTOCOLS[http] + try: + app = import_from_string(app) + except ImportFromStringError as exc: + click.error("Error loading ASGI app. %s" % exc) if workers != 1: raise click.UsageError( - 'Not yet available. For multiple worker processes, use gunicorn. ' + "Not yet available. For multiple worker processes, use gunicorn. " 'eg. "gunicorn -w 4 -k uvicorn.workers.UvicornWorker".' ) - server = Server(app, host, port, loop, logger, protocol_class) - server.run() + run(app, host, port, http, loop, log_level) -def run(app, host="127.0.0.1", port=8000, log_level="info"): +def run(app, host="127.0.0.1", port=8000, loop="auto", http="auto", log_level="info"): log_level = LOG_LEVELS[log_level] logging.basicConfig(format="%(levelname)s: %(message)s", level=log_level) - - loop = get_event_loop(DEFAULT_LOOP) logger = logging.getLogger() - protocol_class = {'httptools': HttpToolsProtocol, 'h11': H11Protocol}[DEFAULT_PARSER] - server = Server(app, host, port, loop, logger, protocol_class) - server.run() - - -def load_app(app): - if not isinstance(app, str): - return app - - if ":" not in app: - message = 'Invalid app string "{app}". Must be in format ":".' - raise click.UsageError(message.format(app=app)) - - module_str, attrs = app.split(":", 1) - try: - module = importlib.import_module(module_str) - except ModuleNotFoundError as exc: - if exc.name != module_str: - raise - message = 'Error loading ASGI app. Could not import module "{module_str}".' - raise click.UsageError(message.format(module_str=module_str)) - - try: - for attr in attrs.split('.'): - asgi_app = getattr(module, attr) - except AttributeError: - message = 'Error loading ASGI app. No app "{attrs}" found in module "{module_str}".' - raise click.UsageError(message.format(attrs=attrs, module_str=module_str)) - - return asgi_app + app = import_from_string(app) + loop_setup = import_from_string(LOOP_SETUPS[loop]) + protocol_class = import_from_string(HTTP_PROTOCOLS[http]) + loop = loop_setup() -def get_event_loop(loop): - if loop == "uvloop": - import uvloop - - asyncio.get_event_loop().close() - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - return asyncio.get_event_loop() + server = Server(app, host, port, loop, logger, protocol_class) + server.run() class Server: @@ -132,8 +92,8 @@ def __init__( def set_signal_handlers(self): handled = ( - signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. - signal.SIGTERM, # Unix signal 15. Sent by `kill `. + signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. + signal.SIGTERM, # Unix signal 15. Sent by `kill `. ) try: for sig in handled: @@ -154,7 +114,7 @@ def run(self): self.loop.run_forever() def handle_exit(self, sig, frame): - if hasattr(sig, 'name'): + if hasattr(sig, "name"): msg = "Received signal %s. Shutting down." % sig.name else: msg = "Received signal. Shutting down." diff --git a/uvicorn/protocols/http/__init__.py b/uvicorn/protocols/http/__init__.py index c2ceea215..e69de29bb 100644 --- a/uvicorn/protocols/http/__init__.py +++ b/uvicorn/protocols/http/__init__.py @@ -1,5 +0,0 @@ -from uvicorn.protocols.http.h11 import H11Protocol -from uvicorn.protocols.http.httptools import HttpToolsProtocol - - -__all__ = ["H11Protocol", "HttpToolsProtocol"] diff --git a/uvicorn/protocols/http/auto.py b/uvicorn/protocols/http/auto.py new file mode 100644 index 000000000..f9d87fb69 --- /dev/null +++ b/uvicorn/protocols/http/auto.py @@ -0,0 +1,11 @@ +def AutoHTTPProtocol(*args, **kwargs): + try: + import httptools + except ImportError: # pragma: no cover + from uvicorn.protocols.http.h11 import H11Protocol + + return H11Protocol(*args, **kwargs) + else: + from uvicorn.protocols.http.httptools import HttpToolsProtocol + + return HttpToolsProtocol(*args, **kwargs) diff --git a/uvicorn/protocols/http/httptools.py b/uvicorn/protocols/http/httptools.py index 20d90f056..4c05c947f 100644 --- a/uvicorn/protocols/http/httptools.py +++ b/uvicorn/protocols/http/httptools.py @@ -35,11 +35,24 @@ def _get_status_line(status_code): class HttpToolsProtocol(asyncio.Protocol): __slots__ = ( - 'app', 'loop', 'state', 'logger', 'access_logs', 'parser', - 'transport', 'server', 'client', 'scheme', - 'scope', 'headers', 'cycle', 'client_event', - 'readable', 'writable', 'writable_event', - 'pipeline' + "app", + "loop", + "state", + "logger", + "access_logs", + "parser", + "transport", + "server", + "client", + "scheme", + "scope", + "headers", + "cycle", + "client_event", + "readable", + "writable", + "writable_event", + "pipeline", ) def __init__(self, app, loop=None, state=None, logger=None): @@ -189,9 +202,17 @@ def resume_writing(self): class RequestResponseCycle: __slots__ = ( - 'scope', 'protocol', 'disconnected', 'done_callback', - 'body', 'more_body', - 'response_started', 'response_complete', 'keep_alive', 'chunked_encoding', 'expected_content_length' + "scope", + "protocol", + "disconnected", + "done_callback", + "body", + "more_body", + "response_started", + "response_complete", + "keep_alive", + "chunked_encoding", + "expected_content_length", ) def __init__(self, scope, protocol): diff --git a/uvicorn/protocols/websockets/websockets.py b/uvicorn/protocols/websockets/websockets.py index e9c23ebad..c498076a8 100644 --- a/uvicorn/protocols/websockets/websockets.py +++ b/uvicorn/protocols/websockets/websockets.py @@ -8,15 +8,15 @@ def __init__(self, raw_headers): self.raw_headers = raw_headers def get(self, key, default=None): - get_key = key.lower().encode('latin-1') + get_key = key.lower().encode("latin-1") for raw_key, raw_value in self.raw_headers: if raw_key == get_key: - return raw_value.decode('latin-1') + return raw_value.decode("latin-1") return default def __setitem__(self, key, value): - set_key = key.lower().encode('latin-1') - set_value = value.encode('latin-1') + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") for idx, (raw_key, raw_value) in enumerate(self.raw_headers): if raw_key == set_key: self.raw_headers[idx] = set_value @@ -39,8 +39,8 @@ def websocket_upgrade(http): # Retrieve any subprotocols to be negotiated with the consumer later subprotocols = [ - subprotocol.strip() for subprotocol in - request_headers.get("sec-websocket-protocol", "").split(",") + subprotocol.strip() + for subprotocol in request_headers.get("sec-websocket-protocol", "").split(",") ] http.scope.update({"type": "websocket", "subprotocols": subprotocols}) asgi_instance = http.app(http.scope)