Add g2p config and semantic convert to api server
leng-yue committed Dec 20, 2023
1 parent c163ea5 commit 3a08434
Showing 2 changed files with 77 additions and 10 deletions.
80 changes: 71 additions & 9 deletions tools/api_server.py
@@ -6,6 +6,7 @@
from threading import Lock
from typing import Annotated, Literal, Optional

import librosa
import numpy as np
import soundfile as sf
import torch
@@ -29,6 +30,7 @@
from transformers import AutoTokenizer

import tools.llama.generate
from fish_speech.models.vqgan.utils import sequence_mask
from tools.llama.generate import encode_tokens, generate, load_model


@@ -133,7 +135,6 @@ def __del__(self):
logger.info("The vqgan model is removed from memory.")

@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=True)
def sematic_to_wav(self, indices):
model = self.model
indices = indices.to(model.device).long()
@@ -170,6 +171,57 @@ def sematic_to_wav(self, indices):

return fake_audio, model.sampling_rate

@torch.no_grad()
def wav_to_semantic(self, audio):
model = self.model
# Load audio
audio, _ = librosa.load(
audio,
sr=model.sampling_rate,
mono=True,
)
audios = torch.from_numpy(audio).to(model.device)[None, None, :]
logger.info(
f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
)

# VQ Encoder
audio_lengths = torch.tensor(
[audios.shape[2]], device=model.device, dtype=torch.long
)

features = gt_mels = model.mel_transform(
audios, sample_rate=model.sampling_rate
)

if model.downsample is not None:
features = model.downsample(features)

mel_lengths = audio_lengths // model.hop_length
feature_lengths = (
audio_lengths
/ model.hop_length
/ (model.downsample.total_strides if model.downsample is not None else 1)
).long()

feature_masks = torch.unsqueeze(
sequence_mask(feature_lengths, features.shape[2]), 1
).to(gt_mels.dtype)

# vq_features is 50 hz, need to convert to true mel size
text_features = model.mel_encoder(features, feature_masks)
_, indices, _ = model.vq_encoder(text_features, feature_masks)

if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
indices = indices[:, 0, :, 0]
else:
logger.error(f"Unknown indices shape: {indices.shape}")
return

logger.info(f"Generated indices of shape {indices.shape}")

return indices


class LoadLlamaModelRequest(BaseModel):
config_name: str = "text2semantic_finetune"
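
The wav_to_semantic method added above makes it possible to derive prompt tokens from a reference recording instead of a precomputed .npy file. A minimal usage sketch, assuming an already-initialised VQGAN model manager (the same kind of object api_invoke_model pulls out of model["vqgan"]; constructing it is not shown here):

# Sketch: encode a reference .wav into semantic indices once and cache them
# as .npy, so later requests can reuse them via the existing np.load branch.
# `vqgan_model_manager` is assumed to be a loaded VQGAN model manager.
import numpy as np

indices = vqgan_model_manager.wav_to_semantic("reference.wav")
np.save("prompt.npy", indices.cpu().numpy())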
@@ -275,6 +327,7 @@ class InvokeRequest(BaseModel):
top_p: float = 0.5
repetition_penalty: float = 1.5
temperature: float = 0.7
order: str = "zh,jp,en"
use_g2p: bool = True
seed: Optional[int] = None
speaker: Optional[str] = None
@@ -299,19 +352,27 @@ def api_invoke_model(
llama_model_manager = model["llama"]
vqgan_model_manager = model["vqgan"]

# Lock
model["lock"].acquire()

device = llama_model_manager.device
seed = req.seed
prompt_tokens = req.prompt_tokens
logger.info(f"Device: {device}")

prompt_tokens = (
torch.from_numpy(np.load(prompt_tokens)).to(device)
if prompt_tokens is not None
else None
)
if prompt_tokens is not None and prompt_tokens.endswith(".npy"):
prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)
elif prompt_tokens is not None and prompt_tokens.endswith(".wav"):
prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens)
elif prompt_tokens is not None:
logger.error(f"Unknown prompt tokens: {prompt_tokens}")
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
content="Unknown prompt tokens, it should be either .npy or .wav file.",
)
else:
prompt_tokens = None

# Lock
model["lock"].acquire()

encoded = encode_tokens(
llama_model_manager.tokenizer,
req.text,
@@ -321,6 +382,7 @@
device=device,
use_g2p=req.use_g2p,
speaker=req.speaker,
order=req.order,
)
prompt_length = encoded.size(1)
logger.info(f"Encoded prompt shape: {encoded.shape}")
7 changes: 6 additions & 1 deletion tools/llama/generate.py
@@ -268,12 +268,14 @@ def encode_tokens(
prompt_tokens=None,
use_g2p=False,
speaker=None,
order="zh,jp,en",
):
if prompt_text is not None:
string = prompt_text + " " + string

if use_g2p:
prompt = g2p(string)
order = order.split(",")
prompt = g2p(string, order=order)
prompt = [
(f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
for _, i in prompt
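
The new order argument is parsed from a comma-separated string into the list handed to g2p, which appears to control the order in which the language front ends are applied. A tiny illustration of the parsing (the g2p call itself is only sketched, since its return format is not fully visible in this diff):

# Illustration of the new `order` handling: "en,zh,jp" -> ["en", "zh", "jp"],
# which encode_tokens then forwards to g2p as its language order.
order = "en,zh,jp"
order_list = order.split(",")  # ["en", "zh", "jp"]
# prompt = g2p(string, order=order_list)  # as in encode_tokens above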
@@ -382,6 +384,7 @@ def load_model(config_name, checkpoint_path, device, precision):
@click.option("--use-g2p/--no-g2p", default=True)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--order", type=str, default="zh,jp,en")
@click.option("--half/--no-half", default=False)
def main(
text: str,
@@ -400,6 +403,7 @@
use_g2p: bool,
seed: int,
speaker: Optional[str],
order: str,
half: bool,
) -> None:
device = "cuda"
@@ -429,6 +433,7 @@
device=device,
use_g2p=use_g2p,
speaker=speaker,
order=order,
)
prompt_length = encoded.size(1)
logger.info(f"Encoded prompt shape: {encoded.shape}")
