From 95d90c84525703094e8b0d34b85dd7b8e63b17e2 Mon Sep 17 00:00:00 2001
From: Lengyue
Date: Wed, 20 Dec 2023 01:08:39 +0000
Subject: [PATCH] Optimize api server

---
 pyproject.toml      |  3 ++-
 tools/api_server.py | 56 +++++++++++++++++++++++++++++----------------
 2 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index fdab7199..8317bafc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,8 @@ dependencies = [
     "wandb",
     "tensorboard",
     "grpcio>=1.58.0",
-    "kui>=1.6.0"
+    "kui>=1.6.0",
+    "zibai-server>=0.9.0"
 ]
 
 [build-system]
diff --git a/tools/api_server.py b/tools/api_server.py
index 8ef177f2..57e596d7 100644
--- a/tools/api_server.py
+++ b/tools/api_server.py
@@ -3,7 +3,8 @@
 import time
 import traceback
 from http import HTTPStatus
-from typing import Annotated, Any, Literal, Optional
+from threading import Lock
+from typing import Annotated, Literal, Optional
 
 import numpy as np
 import soundfile as sf
@@ -82,9 +83,7 @@ def __init__(
         torch.cuda.synchronize()
 
         logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
-
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
 
         if self.compile:
             logger.info("Compiling model ...")
@@ -106,10 +105,9 @@ def __del__(self):
 
 
 class VQGANModel:
-    def __init__(self, config_name: str, checkpoint_path: str):
-        if self.cfg is None:
-            with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-                self.cfg = compose(config_name=config_name)
+    def __init__(self, config_name: str, checkpoint_path: str, device: str):
+        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
+            self.cfg = compose(config_name=config_name)
 
         self.model = instantiate(self.cfg.model)
         state_dict = torch.load(
@@ -120,8 +118,9 @@ def __init__(self, config_name: str, checkpoint_path: str):
             state_dict = state_dict["state_dict"]
         self.model.load_state_dict(state_dict, strict=True)
         self.model.eval()
-        self.model.cuda()
-        logger.info("Restored model from checkpoint")
+        self.model.to(device)
+
+        logger.info("Restored VQGAN model from checkpoint")
 
     def __del__(self):
         self.cfg = None
@@ -175,7 +174,6 @@ def sematic_to_wav(self, indices):
 class LoadLlamaModelRequest(BaseModel):
     config_name: str = "text2semantic_finetune"
     checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
-    device: str = "cuda"
     precision: Literal["float16", "bfloat16"] = "bfloat16"
     tokenizer: str = "fishaudio/speech-lm-v1"
     compile: bool = True
@@ -186,15 +184,20 @@ class LoadVQGANModelRequest(BaseModel):
     checkpoint_path: str = "checkpoints/vqgan-v1.pth"
 
 
+class LoadModelRequest(BaseModel):
+    device: str = "cuda"
+    llama: LoadLlamaModelRequest
+    vqgan: LoadVQGANModelRequest
+
+
 class LoadModelResponse(BaseModel):
     name: str
 
 
 @routes.http.put("/models/{name}")
-def load_model(
+def api_load_model(
     name: Annotated[str, Path("default")],
-    llama: Annotated[LoadLlamaModelRequest, Body()],
-    vqgan: Annotated[LoadVQGANModelRequest, Body()],
+    req: Annotated[LoadModelRequest, Body(exclusive=True)],
 ) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
     """
     Load model
@@ -203,12 +206,15 @@ def load_model(
     if name in MODELS:
         del MODELS[name]
 
+    llama = req.llama
+    vqgan = req.vqgan
+
     logger.info("Loading model ...")
     new_model = {
         "llama": LlamaModel(
             config_name=llama.config_name,
             checkpoint_path=llama.checkpoint_path,
-            device=llama.device,
+            device=req.device,
             precision=llama.precision,
             tokenizer_path=llama.tokenizer,
             compile=llama.compile,
@@ -216,7 +222,9 @@ def load_model(
         "vqgan": VQGANModel(
             config_name=vqgan.config_name,
             checkpoint_path=vqgan.checkpoint_path,
+            device=req.device,
         ),
+        "lock": Lock(),
     }
 
     MODELS[name] = new_model
@@ -225,7 +233,7 @@ def load_model(
 @routes.http.delete("/models/{name}")
-def delete_model(
+def api_delete_model(
     name: Annotated[str, Path("default")],
 ) -> JSONResponse[200, {}, dict]:
     """
     Delete model
     """
@@ -238,6 +246,8 @@ def delete_model(
             content="Model not found.",
         )
 
+    del MODELS[name]
+
     return JSONResponse(
         dict(message="Model deleted."),
         200,
@@ -245,7 +255,7 @@ def delete_model(
 
 
 @routes.http.get("/models")
-def list_models() -> JSONResponse[200, {}, dict]:
+def api_list_models() -> JSONResponse[200, {}, dict]:
     """
     List models
     """
@@ -271,7 +281,7 @@ class InvokeRequest(BaseModel):
 
 
 @routes.http.post("/models/{name}/invoke")
-def invoke_model(
+def api_invoke_model(
     name: Annotated[str, Path("default")],
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
@@ -289,6 +299,9 @@ def invoke_model(
     llama_model_manager = model["llama"]
     vqgan_model_manager = model["vqgan"]
 
+    # Lock
+    model["lock"].acquire()
+
     device = llama_model_manager.device
     seed = req.seed
     prompt_tokens = req.prompt_tokens
@@ -348,6 +361,9 @@ def invoke_model(
     codes = codes - 2
     assert (codes >= 0).all(), "Codes should be >= 0"
 
+    # Release lock
+    model["lock"].release()
+
     # --------------- llama end ------------
     audio, sr = vqgan_model_manager.sematic_to_wav(codes)
     # --------------- vqgan end ------------
@@ -358,8 +374,8 @@ def invoke_model(
 
     return StreamResponse(
         iterable=[buffer.getvalue()],
         headers={
-            "Content-Disposition": "attachment; filename=generated.wav",
-            "Content-Type": "audio/wav",
+            "Content-Disposition": "attachment; filename=audio.wav",
+            "Content-Type": "application/octet-stream",
         },
     )
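
Note on the API change: PUT /models/{name} now takes a single JSON body (a LoadModelRequest with a shared `device` plus nested `llama` and `vqgan` sections) instead of two separate Body parameters. A minimal client sketch follows; the host and port are assumptions, and the VQGAN `config_name` value is hypothetical since its default is not visible in this diff. The remaining values are the defaults from the request models above:

    import requests

    payload = {
        "device": "cuda",
        "llama": {
            "config_name": "text2semantic_finetune",
            "checkpoint_path": "checkpoints/text2semantic-400m-v0.2-4k.pth",
            "precision": "bfloat16",
            "tokenizer": "fishaudio/speech-lm-v1",
            "compile": True,
        },
        "vqgan": {
            # Hypothetical value: the default config_name is outside this diff.
            "config_name": "vqgan_pretrain",
            "checkpoint_path": "checkpoints/vqgan-v1.pth",
        },
    }

    # Assumed local server address; adjust to your deployment.
    resp = requests.put("http://127.0.0.1:8000/models/default", json=payload)
    resp.raise_for_status()
    print(resp.json())  # expected: {"name": "default"}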
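
Note on the new per-model lock: `model["lock"].acquire()` and `.release()` serialize only the llama stage, and the release sits on the success path. If anything between the two calls raises, the lock is never released and later invocations block forever. A sketch of an exception-safe equivalent, where `generate_codes` is a hypothetical stand-in for the llama block between acquire and release:

    # try/finally guarantees the release even when generation raises.
    model["lock"].acquire()
    try:
        codes = generate_codes(req)
    finally:
        model["lock"].release()

    # Equivalent and shorter: threading.Lock supports the
    # context-manager protocol.
    with model["lock"]:
        codes = generate_codes(req)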