From d8d71b2cbbef1691d298624e8f1f8643af1e9fce Mon Sep 17 00:00:00 2001 From: Hao Guan <10684225+hguandl@users.noreply.github.com> Date: Fri, 20 Dec 2024 17:38:37 +0800 Subject: [PATCH] feat: Bearer auth for HTTP API (#746) * feat: Bearer auth for HTTP API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tools/api_client.py | 8 +++++++- tools/api_server.py | 32 ++++++++++++++++++++++++++++++-- tools/server/api_utils.py | 1 + 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tools/api_client.py b/tools/api_client.py index b47a10d5..84978d1d 100644 --- a/tools/api_client.py +++ b/tools/api_client.py @@ -119,6 +119,12 @@ def parse_args(): help="`None` means randomized inference, otherwise deterministic.\n" "It can't be used for fixing a timbre.", ) + parser.add_argument( + "--api_key", + type=str, + default="YOUR_API_KEY", + help="API key for authentication", + ) return parser.parse_args() @@ -173,7 +179,7 @@ def parse_args(): data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), stream=args.streaming, headers={ - "authorization": "Bearer YOUR_API_KEY", + "authorization": f"Bearer {args.api_key}", "content-type": "application/msgpack", }, ) diff --git a/tools/api_server.py b/tools/api_server.py index 7b5d26fc..d57c899a 100644 --- a/tools/api_server.py +++ b/tools/api_server.py @@ -2,8 +2,18 @@ import pyrootutils import uvicorn -from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes +from kui.asgi import ( + Depends, + FactoryClass, + HTTPException, + HttpRoute, + Kui, + OpenAPI, + Routes, +) +from kui.security import bearer_auth from loguru import logger +from typing_extensions import Annotated pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) @@ -31,7 +41,25 @@ def __init__(self): ("/v1/tts", TTSView), ("/v1/chat", ChatView), ] - self.routes = Routes([HttpRoute(path, view) for path, view in self.routes]) + + def api_auth(endpoint): + async def verify(token: Annotated[str, Depends(bearer_auth)]): + if token != self.args.api_key: + raise HTTPException(401, None, "Invalid token") + return await endpoint() + + async def passthrough(): + return await endpoint() + + if self.args.api_key is not None: + return verify + else: + return passthrough + + self.routes = Routes( + [HttpRoute(path, view) for path, view in self.routes], + http_middlewares=[api_auth], + ) self.openapi = OpenAPI( { diff --git a/tools/server/api_utils.py b/tools/server/api_utils.py index 5cfe4c3a..6f5cb938 100644 --- a/tools/server/api_utils.py +++ b/tools/server/api_utils.py @@ -32,6 +32,7 @@ def parse_args(): parser.add_argument("--max-text-length", type=int, default=0) parser.add_argument("--listen", type=str, default="127.0.0.1:8080") parser.add_argument("--workers", type=int, default=1) + parser.add_argument("--api-key", type=str, default=None) return parser.parse_args()