vits code refactoring #173

Merged · 3 commits · Nov 14, 2023
wetts/vits/data_utils.py (67 changes: 32 additions & 35 deletions)
@@ -5,8 +5,8 @@
 import torchaudio
 import torch.utils.data
 
-from mel_processing import spectrogram_torch
-from utils import load_filepaths_and_text
+from utils.mel_processing import spectrogram_torch
+from utils.task import load_filepaths_and_text
 
 
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
@@ -60,16 +60,13 @@ def _filter(self):
             src_sampling_rate = torchaudio.info(audiopath).sample_rate
             # filename|speaker|text
             text = item[2]
-            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
+            if self.min_text_len <= len(text) and len(
+                    text) <= self.max_text_len:
                 audiopaths_sid_text_new.append(item)
                 lengths.append(
                     int(
-                        os.path.getsize(audiopath)
-                        * self.sampling_rate
-                        / src_sampling_rate
-                    )
-                    // (2 * self.hop_length)
-                )
+                        os.path.getsize(audiopath) * self.sampling_rate /
+                        src_sampling_rate) // (2 * self.hop_length))
         self.audiopaths_sid_text = audiopaths_sid_text_new
         self.lengths = lengths
 

@@ -86,9 +83,8 @@ def get_audio(self, filename):
         audio, sampling_rate = torchaudio.load(filename, normalize=False)
         if sampling_rate != self.sampling_rate:
             audio = audio.to(torch.float)
-            audio = torchaudio.transforms.Resample(sampling_rate, self.sampling_rate)(
-                audio
-            )
+            audio = torchaudio.transforms.Resample(sampling_rate,
+                                                   self.sampling_rate)(audio)
             audio = audio.to(torch.int16)
         audio = audio[0]  # Get the first channel
         audio_norm = audio / self.max_wav_value
@@ -114,7 +110,8 @@ def get_sid(self, sid):
         return sid
 
     def __getitem__(self, index):
-        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
+        return self.get_audio_text_speaker_pair(
+            self.audiopaths_sid_text[index])
 
     def __len__(self):
         return len(self.audiopaths_sid_text)
@@ -133,9 +130,8 @@ def __call__(self, batch):
         batch: [text_normalized, spec_normalized, wav_normalized, sid]
         """
         # Right zero-pad all one-hot text sequences to max input length
-        _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
-        )
+        ids = torch.LongTensor([x[1].size(1) for x in batch])
+        _, ids_sorted_decreasing = torch.sort(ids, dim=0, descending=True)
 
         max_text_len = max([len(x[0]) for x in batch])
         max_spec_len = max([x[1].size(1) for x in batch])
@@ -147,7 +143,8 @@ def __call__(self, batch):
         sid = torch.LongTensor(len(batch))
 
         text_padded = torch.LongTensor(len(batch), max_text_len)
-        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
+        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0),
+                                        max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
         text_padded.zero_()
         spec_padded.zero_()
@@ -156,15 +153,15 @@
             row = batch[ids_sorted_decreasing[i]]
 
             text = row[0]
-            text_padded[i, : text.size(0)] = text
+            text_padded[i, :text.size(0)] = text
             text_lengths[i] = text.size(0)
 
             spec = row[1]
-            spec_padded[i, :, : spec.size(1)] = spec
+            spec_padded[i, :, :spec.size(1)] = spec
             spec_lengths[i] = spec.size(1)
 
             wav = row[2]
-            wav_padded[i, :, : wav.size(1)] = wav
+            wav_padded[i, :, :wav.size(1)] = wav
             wav_lengths[i] = wav.size(1)
 
             sid[i] = row[3]
@@ -191,7 +188,8 @@ def __call__(self, batch):
         )
 
 
-class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler
+                               ):
     """
     Maintain similar input lengths in a batch.
     Length groups are specified by boundaries.
@@ -212,7 +210,10 @@ def __init__(
         rank=None,
         shuffle=True,
     ):
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        super().__init__(dataset,
+                         num_replicas=num_replicas,
+                         rank=rank,
+                         shuffle=shuffle)
         self.lengths = dataset.lengths
         self.batch_size = batch_size
         self.boundaries = boundaries
@@ -238,9 +239,8 @@ def _create_buckets(self):
         for i in range(len(buckets)):
             len_bucket = len(buckets[i])
             total_batch_size = self.num_replicas * self.batch_size
-            rem = (
-                total_batch_size - (len_bucket % total_batch_size)
-            ) % total_batch_size
+            rem = (total_batch_size -
+                   (len_bucket % total_batch_size)) % total_batch_size
             num_samples_per_bucket.append(len_bucket + rem)
         return buckets, num_samples_per_bucket
 

Expand All @@ -252,7 +252,8 @@ def __iter__(self):
indices = []
if self.shuffle:
for bucket in self.buckets:
indices.append(torch.randperm(len(bucket), generator=g).tolist())
indices.append(
torch.randperm(len(bucket), generator=g).tolist())
else:
for bucket in self.buckets:
indices.append(list(range(len(bucket))))
@@ -266,22 +267,18 @@
 
             # add extra samples to make it evenly divisible
             rem = num_samples_bucket - len_bucket
-            ids_bucket = (
-                ids_bucket
-                + ids_bucket * (rem // len_bucket)
-                + ids_bucket[: (rem % len_bucket)]
-            )
+            ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) +
+                          ids_bucket[:(rem % len_bucket)])
 
             # subsample
-            ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+            ids_bucket = ids_bucket[self.rank::self.num_replicas]
 
             # batching
             for j in range(len(ids_bucket) // self.batch_size):
                 batch = [
                     bucket[idx]
-                    for idx in ids_bucket[
-                        j * self.batch_size : (j + 1) * self.batch_size
-                    ]
+                    for idx in ids_bucket[j * self.batch_size:(j + 1) *
+                                          self.batch_size]
                 ]
                 batches.append(batch)
 
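DistributedBucketSampler keeps batches length-homogeneous by binning samples between boundaries and padding each bucket up to a multiple of num_replicas * batch_size. A minimal standalone sketch of that arithmetic, with names and numbers invented for illustration:

    # Toy bucketing, mirroring _create_buckets above.
    lengths = [12, 45, 90, 33, 70, 51, 18]  # frame counts per sample
    boundaries = [0, 32, 64, 128]           # bucket edges
    num_replicas, batch_size = 2, 2

    buckets = [[] for _ in range(len(boundaries) - 1)]
    for idx, length in enumerate(lengths):
        for b in range(len(boundaries) - 1):
            if boundaries[b] < length <= boundaries[b + 1]:
                buckets[b].append(idx)
                break

    total_batch_size = num_replicas * batch_size
    for b, bucket in enumerate(buckets):
        # same padding formula as in the diff above
        rem = (total_batch_size -
               (len(bucket) % total_batch_size)) % total_batch_size
        print(f"bucket {b}: {len(bucket)} samples, pad {rem} extra")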
wetts/vits/export_onnx.py (47 changes: 30 additions & 17 deletions)
@@ -17,16 +17,18 @@
 
 import torch
 
-from models import SynthesizerTrn
-import utils
+from model.models import SynthesizerTrn
+from utils import task
 
 
 def get_args():
     parser = argparse.ArgumentParser(description="export onnx model")
     parser.add_argument("--checkpoint", required=True, help="checkpoint")
     parser.add_argument("--cfg", required=True, help="config file")
     parser.add_argument("--onnx_model", required=True, help="onnx model")
-    parser.add_argument("--phone_table", required=True, help="input phone dict")
+    parser.add_argument("--phone_table",
+                        required=True,
+                        help="input phone dict")
     parser.add_argument("--speaker_table", default=None, help="speaker table")
     parser.add_argument(
         "--providers",
@@ -43,20 +45,18 @@ def main():
     args = get_args()
     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
-    hps = utils.get_hparams_from_file(args.cfg)
+    hps = task.get_hparams_from_file(args.cfg)
     hps['model']['is_onnx'] = True
 
     phone_num = len(open(args.phone_table).readlines())
     num_speakers = len(open(args.speaker_table).readlines())
 
-    net_g = SynthesizerTrn(
-        phone_num,
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=num_speakers,
-        **hps.model
-    )
-    utils.load_checkpoint(args.checkpoint, net_g, None)
+    net_g = SynthesizerTrn(phone_num,
+                           hps.data.filter_length // 2 + 1,
+                           hps.train.segment_size // hps.data.hop_length,
+                           n_speakers=num_speakers,
+                           **hps.model)
+    task.load_checkpoint(args.checkpoint, net_g, None)
     net_g.flow.remove_weight_norm()
     net_g.dec.remove_weight_norm()
     net_g.forward = net_g.export_forward
@@ -77,11 +77,24 @@ def main():
         input_names=["input", "input_lengths", "scales", "sid"],
         output_names=["output"],
         dynamic_axes={
-            "input": {0: "batch", 1: "phonemes"},
-            "input_lengths": {0: "batch"},
-            "scales": {0: "batch"},
-            "sid": {0: "batch"},
-            "output": {0: "batch", 1: "audio", 2: "audio_length"},
+            "input": {
+                0: "batch",
+                1: "phonemes"
+            },
+            "input_lengths": {
+                0: "batch"
+            },
+            "scales": {
+                0: "batch"
+            },
+            "sid": {
+                0: "batch"
+            },
+            "output": {
+                0: "batch",
+                1: "audio",
+                2: "audio_length"
+            },
         },
         opset_version=13,
         verbose=False,
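After export, the graph can be sanity-checked with the onnx package; a minimal sketch, where "vits.onnx" stands in for whatever --onnx_model was given:

    import onnx

    model = onnx.load("vits.onnx")
    onnx.checker.check_model(model)  # raises if the graph is malformed
    print([i.name for i in model.graph.input])
    # expected: ['input', 'input_lengths', 'scales', 'sid']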
wetts/vits/inference.py (59 changes: 26 additions & 33 deletions)
@@ -21,21 +21,24 @@
 from scipy.io import wavfile
 import torch
 
-from models import SynthesizerTrn
-import utils
+from model.models import SynthesizerTrn
+from utils import task
 
 
 def get_args():
     parser = argparse.ArgumentParser(description="inference")
     parser.add_argument("--checkpoint", required=True, help="checkpoint")
     parser.add_argument("--cfg", required=True, help="config file")
     parser.add_argument("--outdir", required=True, help="ouput directory")
-    parser.add_argument("--phone_table", required=True, help="input phone dict")
+    parser.add_argument("--phone_table",
+                        required=True,
+                        help="input phone dict")
     parser.add_argument("--speaker_table", default=True, help="speaker table")
     parser.add_argument("--test_file", required=True, help="test file")
-    parser.add_argument(
-        "--gpu", type=int, default=-1, help="gpu id for this local rank, -1 for cpu"
-    )
+    parser.add_argument("--gpu",
+                        type=int,
+                        default=-1,
+                        help="gpu id for this local rank, -1 for cpu")
     args = parser.parse_args()
     return args
 
@@ -59,19 +62,17 @@ def main():
         arr = line.strip().split()
         assert len(arr) == 2
         speaker_dict[arr[0]] = int(arr[1])
-    hps = utils.get_hparams_from_file(args.cfg)
+    hps = task.get_hparams_from_file(args.cfg)
 
-    net_g = SynthesizerTrn(
-        len(phone_dict),
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=len(speaker_dict),
-        **hps.model
-    )
+    net_g = SynthesizerTrn(len(phone_dict),
+                           hps.data.filter_length // 2 + 1,
+                           hps.train.segment_size // hps.data.hop_length,
+                           n_speakers=len(speaker_dict),
+                           **hps.model)
     net_g = net_g.to(device)
 
     net_g.eval()
-    utils.load_checkpoint(args.checkpoint, net_g, None)
+    task.load_checkpoint(args.checkpoint, net_g, None)
 
     for line in open(args.test_file):
         audio_path, speaker, text = line.strip().split("|")
@@ -84,25 +85,17 @@
         x_length = torch.LongTensor([seq.size(0)]).to(device)
         sid = torch.LongTensor([sid]).to(device)
         st = time.time()
-        audio = (
-            net_g.infer(
-                x,
-                x_length,
-                sid=sid,
-                noise_scale=0.667,
-                noise_scale_w=0.8,
-                length_scale=1,
-            )[0][0, 0]
-            .data.cpu()
-            .float()
-            .numpy()
-        )
+        audio = (net_g.infer(
+            x,
+            x_length,
+            sid=sid,
+            noise_scale=0.667,
+            noise_scale_w=0.8,
+            length_scale=1,
+        )[0][0, 0].data.cpu().float().numpy())
         audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6
-        print(
-            "RTF {}".format(
-                (time.time() - st) / (audio.shape[0] / hps.data.sampling_rate)
-            )
-        )
+        print("RTF {}".format((time.time() - st) /
+                              (audio.shape[0] / hps.data.sampling_rate)))
         sys.stdout.flush()
         audio = np.clip(audio, -32767.0, 32767.0)
         wavfile.write(
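The normalization line above rescales the waveform so its peak sits at about 60% of int16 full scale; the max(0.01, ...) floor keeps a near-silent output from being amplified into clipping. A standalone illustration, with the signal invented for the example:

    import numpy as np

    audio = np.random.uniform(-0.3, 0.3, 16000).astype(np.float32)
    audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6
    audio = np.clip(audio, -32767.0, 32767.0)
    print(np.abs(audio).max())  # ~19660, i.e. 0.6 * 32767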
wetts/vits/inference_onnx.py (18 changes: 9 additions & 9 deletions)
@@ -19,23 +19,22 @@
 from scipy.io import wavfile
 import torch
 
-import utils
+from utils import task
 
 
 def to_numpy(tensor):
-    return (
-        tensor.detach().cpu().numpy()
-        if tensor.requires_grad
-        else tensor.detach().numpy()
-    )
+    return (tensor.detach().cpu().numpy()
+            if tensor.requires_grad else tensor.detach().numpy())
 
 
 def get_args():
     parser = argparse.ArgumentParser(description="inference")
     parser.add_argument("--onnx_model", required=True, help="onnx model")
     parser.add_argument("--cfg", required=True, help="config file")
     parser.add_argument("--outdir", required=True, help="ouput directory")
-    parser.add_argument("--phone_table", required=True, help="input phone dict")
+    parser.add_argument("--phone_table",
+                        required=True,
+                        help="input phone dict")
     parser.add_argument("--speaker_table", default=True, help="speaker table")
     parser.add_argument("--test_file", required=True, help="test file")
     parser.add_argument(
@@ -61,9 +60,10 @@ def main():
         arr = line.strip().split()
         assert len(arr) == 2
         speaker_dict[arr[0]] = int(arr[1])
-    hps = utils.get_hparams_from_file(args.cfg)
+    hps = task.get_hparams_from_file(args.cfg)
 
-    ort_sess = ort.InferenceSession(args.onnx_model, providers=[args.providers])
+    ort_sess = ort.InferenceSession(args.onnx_model,
+                                    providers=[args.providers])
     scales = torch.FloatTensor([0.667, 1.0, 0.8])
     # make triton dynamic shape happy
     scales = scales.unsqueeze(0)
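For reference, a sketch of a full forward pass through the exported model with onnxruntime, using the input/output names declared in export_onnx.py; the file name, phone ids, and dtypes are assumptions for illustration (LongTensor inputs map to int64, FloatTensor to float32):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("vits.onnx",
                                providers=["CPUExecutionProvider"])
    phone_ids = np.array([[3, 17, 42, 5]], dtype=np.int64)  # (batch, phonemes)
    audio = sess.run(
        ["output"],
        {
            "input": phone_ids,
            "input_lengths": np.array([phone_ids.shape[1]], dtype=np.int64),
            "scales": np.array([[0.667, 1.0, 0.8]], dtype=np.float32),
            "sid": np.array([0], dtype=np.int64),
        },
    )[0]
    print(audio.shape)  # (batch, audio, audio_length) per the dynamic axes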