Skip to content

Commit

Permalink
refactor: openapi doc (#770)
Browse files Browse the repository at this point in the history
* refactor: openapi doc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: spicysama <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent d8d71b2 commit 0b48e78
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 224 deletions.
37 changes: 11 additions & 26 deletions tools/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
OpenAPI,
Routes,
)
from kui.cors import CORSConfig
from kui.openapi.specification import Info
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated
Expand All @@ -20,27 +22,13 @@
from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import (
ASRView,
ChatView,
HealthView,
TTSView,
VQGANDecodeView,
VQGANEncodeView,
)
from tools.server.views import routes


class API(ExceptionHandler):
def __init__(self):
self.args = parse_args()
self.routes = [
("/v1/health", HealthView),
("/v1/vqgan/encode", VQGANEncodeView),
("/v1/vqgan/decode", VQGANDecodeView),
("/v1/asr", ASRView),
("/v1/tts", TTSView),
("/v1/chat", ChatView),
]
self.routes = routes

def api_auth(endpoint):
async def verify(token: Annotated[str, Depends(bearer_auth)]):
Expand All @@ -56,16 +44,13 @@ async def passthrough():
else:
return passthrough

self.routes = Routes(
[HttpRoute(path, view) for path, view in self.routes],
http_middlewares=[api_auth],
)

self.openapi = OpenAPI(
{
"title": "Fish Speech API",
"version": "1.5.0",
},
Info(
{
"title": "Fish Speech API",
"version": "1.5.0",
}
),
).routes

# Initialize the app
Expand All @@ -76,7 +61,7 @@ async def passthrough():
Exception: self.other_exception_handler,
},
factory_class=FactoryClass(http=MsgPackRequest),
cors_config={},
cors_config=CORSConfig(),
)

# Add the state variables
Expand Down
3 changes: 2 additions & 1 deletion tools/schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal
from typing import Literal

import torch
from pydantic import BaseModel, Field, conint, conlist
from pydantic.functional_validators import SkipValidation
from typing_extensions import Annotated

from fish_speech.conversation import Message, TextPart, VQPart

Expand Down
Loading

0 comments on commit 0b48e78

Please sign in to comment.