Merge pull request #146 from deiteris/rmvpe-improvements
RMVPE improvements
deiteris authored Jul 23, 2024
2 parents ac838b3 + 00403b5 commit d356c30
Showing 5 changed files with 35 additions and 59 deletions.
server/downloader/Downloader.py (3 changes: 2 additions & 1 deletion)
@@ -98,7 +98,8 @@ async def download(params: dict):
     if expected_hash is not None and hash != expected_hash:
         raise DownloadVerificationException(saveTo, hash, expected_hash)
 
-    write_file_entry(saveTo, hash)
+    if expected_hash is not None:
+        write_file_entry(saveTo, hash)
 
 def write_file_entry(saveTo: str, hash: str):
     global lock, files
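Note: with the guard above, a checksum entry is only recorded when the download actually carried an expected hash to verify against; files fetched without a known hash are no longer marked as verified. A minimal, self-contained sketch of that flow; compute_md5, verified_files, and record_if_verified are illustrative stand-ins, not the module's real hashing code or its lock/files globals:

import hashlib

# Illustrative in-memory checksum registry, standing in for Downloader.py's
# `files` dict guarded by `lock`.
verified_files: dict[str, str] = {}

def compute_md5(path: str) -> str:
    # Stream in 1 MiB chunks so large model weights never sit in memory whole.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    return md5.hexdigest()

def record_if_verified(path: str, expected_hash: str | None) -> None:
    hash = compute_md5(path)
    if expected_hash is not None and hash != expected_hash:
        raise ValueError(f'{path}: got {hash}, expected {expected_hash}')
    # Only downloads that carried an expected hash get a registry entry.
    if expected_hash is not None:
        verified_files[path] = hash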
server/downloader/WeightDownloader.py (18 changes: 9 additions & 9 deletions)
@@ -21,47 +21,47 @@ async def downloadWeight(params: ServerSettings):
# "hash": "fed21bfb71a38df821cf9ae43e5da8b3",
# },
{
"url": "https://huggingface.co/wok000/weights/resolve/main/crepe/onnx/full.onnx",
"url": "https://huggingface.co/wok000/weights/resolve/7a376af24f4f21f9d9e24160ed9d858e8a33bf93/crepe/onnx/full.onnx",
"saveTo": params.crepe_onnx_full,
"hash": "e9bb11eb5d3557805715077b30aefebc",
},
{
"url": "https://huggingface.co/wok000/weights/resolve/main/crepe/onnx/tiny.onnx",
"url": "https://huggingface.co/wok000/weights/resolve/7a376af24f4f21f9d9e24160ed9d858e8a33bf93/crepe/onnx/tiny.onnx",
"saveTo": params.crepe_onnx_tiny,
"hash": "b509427f6d223152e57ff2aeb1b48300",
},
{
"url": "https://github.com/maxrmorrison/torchcrepe/raw/master/torchcrepe/assets/full.pth",
"url": "https://github.com/maxrmorrison/torchcrepe/raw/745670a18bf8c5f1a2f08c910c72433badde3e08/torchcrepe/assets/full.pth",
"saveTo": params.crepe_full,
"hash": "2ab425d128692f27ad5b765f13752333",
},
{
"url": "https://github.com/maxrmorrison/torchcrepe/raw/master/torchcrepe/assets/tiny.pth",
"url": "https://github.com/maxrmorrison/torchcrepe/raw/745670a18bf8c5f1a2f08c910c72433badde3e08/torchcrepe/assets/tiny.pth",
"saveTo": params.crepe_tiny,
"hash": "eec11d7661587b6b90da7823cf409340",
},
{
"url": "https://huggingface.co/wok000/weights_gpl/resolve/main/content-vec/contentvec-f.onnx",
"url": "https://huggingface.co/wok000/weights_gpl/resolve/c2f3e4a8884dba0995347dfe24dc0ad40acb9eb7/content-vec/contentvec-f.onnx",
"saveTo": params.content_vec_500_onnx,
"hash": "ab288ca5b540a4a15909a40edf875d1e",
},
{
"url": "https://huggingface.co/wok000/weights/resolve/main/rmvpe/rmvpe_20231006.pt",
"url": "https://huggingface.co/wok000/weights/resolve/4a9dbeb086b66721378b4fb29c84bf94d3e076ec/rmvpe/rmvpe_20231006.pt",
"saveTo": params.rmvpe,
"hash": "7989809b6b54fb33653818e357bcb643",
},
{
"url": "https://huggingface.co/deiteris/weights/resolve/main/rmvpe.onnx",
"url": "https://huggingface.co/deiteris/weights/resolve/5040af391eb55d6415a209bfeb3089a866491670/rmvpe_upd.onnx",
"saveTo": params.rmvpe_onnx,
"hash": "9d8ae16af5ac4d9a200e4723de35b30b",
"hash": "9c6d7712f84d487ae781b0d7435c269b",
},
{
"url": "https://github.com/CNChTu/FCPE/raw/819765c8db719c457f53aaee3238879ab98ed0cd/torchfcpe/assets/fcpe_c_v001.pt",
"saveTo": params.fcpe,
"hash": "933f1b588409b3945389381a2ab98014",
},
{
"url": "https://huggingface.co/deiteris/weights/resolve/main/fcpe.onnx",
"url": "https://huggingface.co/deiteris/weights/resolve/6abbb0285b1fc154e112b3c002ae63e1c1733d53/fcpe.onnx",
"saveTo": params.fcpe_onnx,
"hash": "6a7b11db05def00053102920d039760f",
},
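Note: every weight URL above is repinned from a floating main (or master) ref to a fixed commit, so the bytes behind each URL can no longer drift and the MD5 hashes in this table stay valid even if the upstream Hugging Face or GitHub repositories update their files. The rmvpe_upd.onnx entry also swaps in an updated model file, which is why its hash changes as well. A small sketch of the URL pattern; hf_resolve_url is our name for it, not a function in this codebase:

def hf_resolve_url(repo: str, revision: str, path: str) -> str:
    # Hugging Face serves raw files at /<repo>/resolve/<revision>/<path>;
    # a branch name floats with HEAD, a commit SHA always serves the same bytes.
    return f'https://huggingface.co/{repo}/resolve/{revision}/{path}'

floating = hf_resolve_url('wok000/weights', 'main', 'crepe/onnx/full.onnx')
pinned = hf_resolve_url('wok000/weights', '7a376af24f4f21f9d9e24160ed9d858e8a33bf93', 'crepe/onnx/full.onnx')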
server/utils/rmvpe_onnx.py (7 changes: 4 additions & 3 deletions)
@@ -333,10 +333,11 @@ def decode(self, hidden: torch.Tensor, threshold: float) -> torch.Tensor:
         product_sum = torch.sum(weights * self.idx_cents, dim=2)  # [B, T]
         weight_sum = torch.sum(weights, dim=2)  # [B, T]
         cents = product_sum / (weight_sum + (weight_sum == 0))  # avoid dividing by zero, [B, T]
-        cents[cents <= threshold] = 0
-        return 10 * 2 ** (cents / 1200)
+        f0 = 10 * 2 ** (cents / 1200)
+        uv = hidden.max(dim=2)[0] < threshold  # [B, T]
+        return f0 * ~uv
 
-def convert(pt_model: torch.nn.Module, input_names: list[str], inputs: tuple[torch.Tensor], output_names: list[str], dynamic_axes: dict, convert_to_fp16: bool) -> onnx.ModelProto:
+def convert(pt_model: torch.nn.Module, input_names: list[str], inputs: tuple[torch.Tensor], output_names: list[str], dynamic_axes: dict) -> onnx.ModelProto:
     with BytesIO() as io:
         torch.onnx.export(
             pt_model,
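Note: the old decode compared cent values (on the order of thousands) against the salience threshold (around 0.05), so in practice only exact zeros were masked, and even those then mapped to 10 * 2^(0/1200) = 10 Hz rather than silence. The new code keeps the threshold in salience space: a frame is unvoiced when its peak activation falls below the threshold, and those frames now return exactly 0. The convert() helper also drops its convert_to_fp16 parameter. A self-contained sketch of the decoding, assuming the [B, T, 360] cent-bin salience layout used here:

import torch

def decode_f0(hidden: torch.Tensor, threshold: float = 0.05) -> torch.Tensor:
    # hidden: [B, T, 360] salience over 20-cent bins -> f0 in Hz, [B, T]
    idx = torch.arange(360, device=hidden.device)[None, None, :]
    idx_cents = idx * 20 + 1997.3794084376191        # bin centers in cents
    center = torch.argmax(hidden, dim=2, keepdim=True)
    mask = (idx >= torch.clip(center - 4, min=0)) & (idx < torch.clip(center + 5, max=360))
    weights = hidden * mask                          # 9-bin window around the peak
    weight_sum = weights.sum(dim=2)
    cents = (weights * idx_cents).sum(dim=2) / (weight_sum + (weight_sum == 0))
    f0 = 10 * 2 ** (cents / 1200)                    # cents -> Hz
    uv = hidden.max(dim=2)[0] < threshold            # unvoiced: weak peak salience
    return f0 * ~uv                                  # exactly 0 Hz when unvoiced

With 360 bins spaced 20 cents apart starting at about 1997.4 cents, the decoded range works out to roughly 32 Hz to 2 kHz; for example, decode_f0(torch.rand(1, 100, 360)) returns a [1, 100] tensor in that range.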
@@ -20,4 +20,4 @@ def extract(
         sr: int,
         window: int,
     ) -> torch.Tensor:
-        return self.rmvpe.infer_from_audio_t(audio)
+        return self.rmvpe.infer_from_audio_t(audio).squeeze()
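Note: infer_from_audio_t now keeps the batch dimension all the way through decode (see the rmvpe.py hunk below), so it returns f0 shaped [1, T]; the added .squeeze() restores the flat [T] tensor this extractor returned before. A one-line illustration:

import torch

f0 = torch.rand(1, 100)                # shape infer_from_audio_t now returns: [B, T], B = 1
assert f0.squeeze().shape == (100,)    # back to [T], what downstream pitch code expects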
server/voice_changer/common/rmvpe/rmvpe.py (64 changes: 19 additions & 45 deletions)
@@ -9,7 +9,6 @@
 from librosa.filters import mel
 
 
-
 class BiGRU(nn.Module):
     def __init__(self, input_features, hidden_features, num_layers):
         super(BiGRU, self).__init__()
@@ -328,10 +327,7 @@ def forward(self, audio, keyshift=0, speed=1, center=True):


 class RMVPE:
-    def __init__(self, model_path: str, is_half: bool, device: torch.device | str = None):
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
+    def __init__(self, model_path: str, is_half: bool, device: torch.device):
         model = E2E(4, 1, (2, 2))
         if model_path.endswith('.safetensors'):
             with safe_open(model_path, 'pt', device=str(device) if device.type == 'cuda' else 'cpu') as cpt:
@@ -345,54 +341,32 @@ def __init__(self, model_path: str, is_half: bool, device: torch.device | str =
             model = model.half()
         self.model = model
 
+        self.device = device
         self.mel_extractor = MelSpectrogram(
             is_half, 128, 16000, 1024, 160, None, 30, 8000
         ).to(device)
-        cents_mapping = 20 * torch.arange(360, device=device) + 1997.3794084376191
-        self.cents_mapping = F.pad(cents_mapping, (4, 4))
+        self.idx = torch.arange(360, device=device)[None, None, :]
+        self.idx_cents = self.idx * 20 + 1997.3794084376191

     @torch.no_grad()
     def mel2hidden(self, mel: torch.Tensor) -> torch.Tensor:
         n_frames = mel.shape[-1]
-        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
-        if n_pad > 0:
-            mel = F.pad(mel, (0, n_pad), mode="reflect")
-        hidden = self.model(mel)
-        return hidden[:, :n_frames]
+        mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect')
+        return self.model(mel)[:, :n_frames]

     def decode(self, hidden: torch.Tensor, threshold: float):
-        cents_pred = self.to_local_average_cents(hidden, threshold=threshold)
-        f0 = 10 * (2 ** (cents_pred / 1200))
-        f0[f0 == 10] = 0
-        # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
-        return f0
+        center = torch.argmax(hidden, dim=2, keepdim=True)  # [B, T, 1]
+        start = torch.clip(center - 4, min=0)  # [B, T, 1]
+        end = torch.clip(center + 5, max=360)  # [B, T, 1]
+        idx_mask = (self.idx >= start) & (self.idx < end)  # [B, T, N]
+        weights = hidden * idx_mask  # [B, T, N]
+        product_sum = torch.sum(weights * self.idx_cents, dim=2)  # [B, T]
+        weight_sum = torch.sum(weights, dim=2)  # [B, T]
+        cents = product_sum / (weight_sum + (weight_sum == 0))  # avoid dividing by zero, [B, T]
+        f0 = 10 * 2 ** (cents / 1200)
+        uv = hidden.max(dim=2)[0] < threshold  # [B, T]
+        return f0 * ~uv

     @torch.no_grad()
     def infer_from_audio_t(self, audio: torch.Tensor, threshold: float = 0.05) -> torch.Tensor:
         mel: torch.Tensor = self.mel_extractor(audio.unsqueeze(0), center=True)
-        hidden = self.mel2hidden(mel).squeeze(0)
-        f0 = self.decode(hidden, threshold=threshold)
-        return f0
-
-    def to_local_average_cents(self, hidden: torch.Tensor, threshold: float):
-        # t0 = ttime()
-        starts = torch.argmax(hidden, dim=1)  # [n_frames]; peak bin index
-        hidden = F.pad(hidden, (8, 0))  # [n_frames, 368]
-        # t1 = ttime()
-        hidden_avgs: list[torch.Tensor] = []
-        cents_mapping_avgs: list[torch.Tensor] = []
-        center = starts + 4
-        ends = center + 5
-        for idx in range(hidden.shape[0]):
-            hidden_avgs.append(hidden[:, starts[idx] : ends[idx]][idx])  # NOQA
-            cents_mapping_avgs.append(self.cents_mapping[starts[idx] : ends[idx]])  # NOQA
-        # t2 = ttime()
-        hidden_avgs = torch.stack(hidden_avgs)  # [n_frames, 9]
-        cents_mapping_avgs = torch.stack(cents_mapping_avgs)  # [n_frames, 9]
-        division = torch.sum(hidden_avgs * cents_mapping_avgs, 1) / torch.sum(hidden_avgs, 1)  # [n_frames]
-        # t3 = ttime()
-        maxx = torch.max(hidden, dim=1)  # [n_frames]
-        division[maxx.values <= threshold] = 0
-        # t4 = ttime()
-        # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-        return division
+        hidden = self.mel2hidden(mel)
+        return self.decode(hidden, threshold)
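Note: the PyTorch path now mirrors the ONNX one. The constructor requires an explicit torch.device instead of falling back to CUDA detection; mel2hidden folds the pad-to-a-multiple-of-32 arithmetic into one expression (the model's strided downsampling is presumably why frame counts must be divisible by 32; the diff only restructures the math); and the per-frame Python loop in to_local_average_cents is replaced by the same masked, batched reduction used in rmvpe_onnx.py, including the salience-space unvoiced masking. A side-by-side sketch of the two decoding styles; the loop variant approximates the removed code with a symmetric (4, 4) pad, so edge frames may differ slightly from the clipped-mask version:

import torch
import torch.nn.functional as F

N_BINS = 360
CENTS_START = 1997.3794084376191       # cents of bin 0; bins step by 20 cents

def pad_frames(mel: torch.Tensor) -> torch.Tensor:
    # mel: [B, n_mels, T]; round T up to the next multiple of 32, as mel2hidden does.
    n_frames = mel.shape[-1]
    return F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect')

def decode_loop(hidden: torch.Tensor) -> torch.Tensor:
    # Old style: slice a 9-bin window per frame in a Python loop. [T, 360] -> cents [T].
    cents_mapping = F.pad(20 * torch.arange(N_BINS) + CENTS_START, (4, 4))
    padded = F.pad(hidden, (4, 4))     # so the window never runs off the bin axis
    starts = torch.argmax(hidden, dim=1)
    out = []
    for t in range(hidden.shape[0]):
        w = padded[t, starts[t]: starts[t] + 9]      # 9 bins centered on the peak
        c = cents_mapping[starts[t]: starts[t] + 9]
        out.append((w * c).sum() / w.sum().clamp_min(1e-9))
    return torch.stack(out)

def decode_vectorized(hidden: torch.Tensor) -> torch.Tensor:
    # New style: one masked reduction, no loop. [B, T, 360] -> cents [B, T].
    idx = torch.arange(N_BINS)[None, None, :]
    idx_cents = idx * 20 + CENTS_START
    center = torch.argmax(hidden, dim=2, keepdim=True)
    mask = (idx >= torch.clip(center - 4, min=0)) & (idx < torch.clip(center + 5, max=N_BINS))
    weights = hidden * mask
    weight_sum = weights.sum(dim=2)
    return (weights * idx_cents).sum(dim=2) / (weight_sum + (weight_sum == 0))

The payoff is that decoding runs as a handful of tensor ops on the device instead of one Python iteration per frame, which matters for a realtime voice changer.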
