diff --git a/tools/api_server.py b/tools/api_server.py index 57e596d7..b5740313 100644 --- a/tools/api_server.py +++ b/tools/api_server.py @@ -6,6 +6,7 @@ from threading import Lock from typing import Annotated, Literal, Optional +import librosa import numpy as np import soundfile as sf import torch @@ -29,6 +30,7 @@ from transformers import AutoTokenizer import tools.llama.generate +from fish_speech.models.vqgan.utils import sequence_mask from tools.llama.generate import encode_tokens, generate, load_model @@ -133,7 +135,6 @@ def __del__(self): logger.info("The vqgan model is removed from memory.") @torch.no_grad() - @torch.autocast(device_type="cuda", enabled=True) def sematic_to_wav(self, indices): model = self.model indices = indices.to(model.device).long() @@ -170,6 +171,57 @@ def sematic_to_wav(self, indices): return fake_audio, model.sampling_rate + @torch.no_grad() + def wav_to_semantic(self, audio): + model = self.model + # Load audio + audio, _ = librosa.load( + audio, + sr=model.sampling_rate, + mono=True, + ) + audios = torch.from_numpy(audio).to(model.device)[None, None, :] + logger.info( + f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds" + ) + + # VQ Encoder + audio_lengths = torch.tensor( + [audios.shape[2]], device=model.device, dtype=torch.long + ) + + features = gt_mels = model.mel_transform( + audios, sample_rate=model.sampling_rate + ) + + if model.downsample is not None: + features = model.downsample(features) + + mel_lengths = audio_lengths // model.hop_length + feature_lengths = ( + audio_lengths + / model.hop_length + / (model.downsample.total_strides if model.downsample is not None else 1) + ).long() + + feature_masks = torch.unsqueeze( + sequence_mask(feature_lengths, features.shape[2]), 1 + ).to(gt_mels.dtype) + + # vq_features is 50 hz, need to convert to true mel size + text_features = model.mel_encoder(features, feature_masks) + _, indices, _ = model.vq_encoder(text_features, feature_masks) + + if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1: + indices = indices[:, 0, :, 0] + else: + logger.error(f"Unknown indices shape: {indices.shape}") + return + + logger.info(f"Generated indices of shape {indices.shape}") + + return indices + class LoadLlamaModelRequest(BaseModel): config_name: str = "text2semantic_finetune" @@ -275,6 +327,7 @@ class InvokeRequest(BaseModel): top_p: float = 0.5 repetition_penalty: float = 1.5 temperature: float = 0.7 + order: str = "zh,jp,en" use_g2p: bool = True seed: Optional[int] = None speaker: Optional[str] = None @@ -299,19 +352,27 @@ def api_invoke_model( llama_model_manager = model["llama"] vqgan_model_manager = model["vqgan"] - # Lock - model["lock"].acquire() - device = llama_model_manager.device seed = req.seed prompt_tokens = req.prompt_tokens logger.info(f"Device: {device}") - prompt_tokens = ( - torch.from_numpy(np.load(prompt_tokens)).to(device) - if prompt_tokens is not None - else None - ) + if prompt_tokens is not None and prompt_tokens.endswith(".npy"): + prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device) + elif prompt_tokens is not None and prompt_tokens.endswith(".wav"): + prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens) + elif prompt_tokens is not None: + logger.error(f"Unknown prompt tokens: {prompt_tokens}") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + content="Unknown prompt tokens, it should be either .npy or .wav file.", + ) + else: + prompt_tokens = None + + # Lock + model["lock"].acquire() + encoded = encode_tokens( llama_model_manager.tokenizer, req.text, @@ -321,6 +382,7 @@ def api_invoke_model( device=device, use_g2p=req.use_g2p, speaker=req.speaker, + order=req.order, ) prompt_length = encoded.size(1) logger.info(f"Encoded prompt shape: {encoded.shape}") diff --git a/tools/llama/generate.py b/tools/llama/generate.py index c84a69be..3a02bd6a 100644 --- a/tools/llama/generate.py +++ b/tools/llama/generate.py @@ -268,12 +268,14 @@ def encode_tokens( prompt_tokens=None, use_g2p=False, speaker=None, + order="zh,jp,en", ): if prompt_text is not None: string = prompt_text + " " + string if use_g2p: - prompt = g2p(string) + order = order.split(",") + prompt = g2p(string, order=order) prompt = [ (f"" if i not in pu_symbols and i != pad_symbol else i) for _, i in prompt @@ -382,6 +384,7 @@ def load_model(config_name, checkpoint_path, device, precision): @click.option("--use-g2p/--no-g2p", default=True) @click.option("--seed", type=int, default=42) @click.option("--speaker", type=str, default=None) +@click.option("--order", type=str, default="zh,jp,en") @click.option("--half/--no-half", default=False) def main( text: str, @@ -400,6 +403,7 @@ def main( use_g2p: bool, seed: int, speaker: Optional[str], + order: str, half: bool, ) -> None: device = "cuda" @@ -429,6 +433,7 @@ def main( device=device, use_g2p=use_g2p, speaker=speaker, + order=order, ) prompt_length = encoded.size(1) logger.info(f"Encoded prompt shape: {encoded.shape}")