Skip to content

Commit

Permalink
vits code refactoring
Browse files (browse the repository at this point in the history)
  • Loading branch information
pengzhendong committed Nov 14, 2023
1 parent 7c2e94b commit 455de16
Show file tree
Hide file tree
Showing 21 changed files with 1,372 additions and 1,390 deletions.
4 changes: 2 additions & 2 deletions wetts/vits/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torchaudio
import torch.utils.data

-from mel_processing import spectrogram_torch
-from utils import load_filepaths_and_text
+from utils.mel_processing import spectrogram_torch
+from utils.task import load_filepaths_and_text


class TextAudioSpeakerLoader(torch.utils.data.Dataset):
Expand Down
8 changes: 4 additions & 4 deletions wetts/vits/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import torch

-from models import SynthesizerTrn
-import utils
+from model.models import SynthesizerTrn
+from utils import task


def get_args():
Expand All @@ -43,7 +43,7 @@ def main():
args = get_args()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

-hps = utils.get_hparams_from_file(args.cfg)
+hps = task.get_hparams_from_file(args.cfg)
hps['model']['is_onnx'] = True

phone_num = len(open(args.phone_table).readlines())
Expand All @@ -56,7 +56,7 @@ def main():
n_speakers=num_speakers,
**hps.model
)
-utils.load_checkpoint(args.checkpoint, net_g, None)
+task.load_checkpoint(args.checkpoint, net_g, None)
net_g.flow.remove_weight_norm()
net_g.dec.remove_weight_norm()
net_g.forward = net_g.export_forward
Expand Down
8 changes: 4 additions & 4 deletions wetts/vits/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from scipy.io import wavfile
import torch

-from models import SynthesizerTrn
-import utils
+from model.models import SynthesizerTrn
+from utils import task


def get_args():
Expand Down Expand Up @@ -59,7 +59,7 @@ def main():
arr = line.strip().split()
assert len(arr) == 2
speaker_dict[arr[0]] = int(arr[1])
-hps = utils.get_hparams_from_file(args.cfg)
+hps = task.get_hparams_from_file(args.cfg)

net_g = SynthesizerTrn(
len(phone_dict),
Expand All @@ -71,7 +71,7 @@ def main():
net_g = net_g.to(device)

net_g.eval()
-utils.load_checkpoint(args.checkpoint, net_g, None)
+task.load_checkpoint(args.checkpoint, net_g, None)

for line in open(args.test_file):
audio_path, speaker, text = line.strip().split("|")
Expand Down
4 changes: 2 additions & 2 deletions wetts/vits/inference_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from scipy.io import wavfile
import torch

-import utils
+from utils import task


def to_numpy(tensor):
Expand Down Expand Up @@ -61,7 +61,7 @@ def main():
arr = line.strip().split()
assert len(arr) == 2
speaker_dict[arr[0]] = int(arr[1])
-hps = utils.get_hparams_from_file(args.cfg)
+hps = task.get_hparams_from_file(args.cfg)

ort_sess = ort.InferenceSession(args.onnx_model, providers=[args.providers])
scales = torch.FloatTensor([0.667, 1.0, 0.8])
Expand Down
4 changes: 2 additions & 2 deletions wetts/vits/attentions.py → wetts/vits/model/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from torch import nn
from torch.nn import functional as F

-import commons
-from modules import LayerNorm
+from model.normalization import LayerNorm
+from utils import commons


class Encoder(nn.Module):
Expand Down
Loading

0 comments on commit 455de16

Please sign in to comment.