diff --git a/wetts/vits/data_utils.py b/wetts/vits/data_utils.py
index 1037fbb..283ce0f 100644
--- a/wetts/vits/data_utils.py
+++ b/wetts/vits/data_utils.py
@@ -5,8 +5,8 @@
 import torchaudio
 import torch.utils.data
 
-from mel_processing import spectrogram_torch
-from utils import load_filepaths_and_text
+from utils.mel_processing import spectrogram_torch
+from utils.task import load_filepaths_and_text
 
 
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
@@ -60,16 +60,13 @@ def _filter(self):
             src_sampling_rate = torchaudio.info(audiopath).sample_rate
             # filename|speaker|text
             text = item[2]
-            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
+            if self.min_text_len <= len(text) and len(
+                    text) <= self.max_text_len:
                 audiopaths_sid_text_new.append(item)
                 lengths.append(
                     int(
-                        os.path.getsize(audiopath)
-                        * self.sampling_rate
-                        / src_sampling_rate
-                    )
-                    // (2 * self.hop_length)
-                )
+                        os.path.getsize(audiopath) * self.sampling_rate /
+                        src_sampling_rate) // (2 * self.hop_length))
         self.audiopaths_sid_text = audiopaths_sid_text_new
         self.lengths = lengths
 
@@ -86,9 +83,8 @@ def get_audio(self, filename):
         audio, sampling_rate = torchaudio.load(filename, normalize=False)
         if sampling_rate != self.sampling_rate:
             audio = audio.to(torch.float)
-            audio = torchaudio.transforms.Resample(sampling_rate, self.sampling_rate)(
-                audio
-            )
+            audio = torchaudio.transforms.Resample(sampling_rate,
+                                                   self.sampling_rate)(audio)
             audio = audio.to(torch.int16)
         audio = audio[0]  # Get the first channel
         audio_norm = audio / self.max_wav_value
@@ -114,7 +110,8 @@ def get_sid(self, sid):
         return sid
 
     def __getitem__(self, index):
-        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
+        return self.get_audio_text_speaker_pair(
+            self.audiopaths_sid_text[index])
 
     def __len__(self):
         return len(self.audiopaths_sid_text)
@@ -133,9 +130,8 @@ def __call__(self, batch):
             batch: [text_normalized, spec_normalized, wav_normalized, sid]
         """
         # Right zero-pad all one-hot text sequences to max input length
-        _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
-        )
+        ids = torch.LongTensor([x[1].size(1) for x in batch])
+        _, ids_sorted_decreasing = torch.sort(ids, dim=0, descending=True)
 
         max_text_len = max([len(x[0]) for x in batch])
         max_spec_len = max([x[1].size(1) for x in batch])
@@ -147,7 +143,8 @@ def __call__(self, batch):
         sid = torch.LongTensor(len(batch))
 
         text_padded = torch.LongTensor(len(batch), max_text_len)
-        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
+        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0),
+                                        max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
         text_padded.zero_()
         spec_padded.zero_()
@@ -156,15 +153,15 @@ def __call__(self, batch):
             row = batch[ids_sorted_decreasing[i]]
 
             text = row[0]
-            text_padded[i, : text.size(0)] = text
+            text_padded[i, :text.size(0)] = text
             text_lengths[i] = text.size(0)
 
             spec = row[1]
-            spec_padded[i, :, : spec.size(1)] = spec
+            spec_padded[i, :, :spec.size(1)] = spec
             spec_lengths[i] = spec.size(1)
 
             wav = row[2]
-            wav_padded[i, :, : wav.size(1)] = wav
+            wav_padded[i, :, :wav.size(1)] = wav
             wav_lengths[i] = wav.size(1)
 
             sid[i] = row[3]
@@ -191,7 +188,8 @@ def __call__(self, batch):
         )
 
 
-class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler
+                               ):
     """
     Maintain similar input lengths in a batch.
     Length groups are specified by boundaries.
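For reference, DistributedBucketSampler keeps each batch length-homogeneous: it buckets utterances by the frame lengths computed in _filter above and hands every replica evenly divisible index lists, so padding waste stays low. A minimal usage sketch under single-process assumptions; the boundary and batch-size values are illustrative, and the collate class name (TextAudioSpeakerCollate) follows upstream VITS since it is not visible in the hunks here:

    from torch.utils.data import DataLoader

    dataset = TextAudioSpeakerLoader("train.txt", hps.data)  # hps assumed
    sampler = DistributedBucketSampler(dataset,
                                       batch_size=16,
                                       boundaries=[32, 300, 400, 500],
                                       num_replicas=1,
                                       rank=0,
                                       shuffle=True)
    loader = DataLoader(dataset,
                        batch_sampler=sampler,  # one index list per batch
                        collate_fn=TextAudioSpeakerCollate(),
                        num_workers=4)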
@@ -212,7 +210,10 @@ def __init__(
         rank=None,
         shuffle=True,
     ):
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        super().__init__(dataset,
+                         num_replicas=num_replicas,
+                         rank=rank,
+                         shuffle=shuffle)
         self.lengths = dataset.lengths
         self.batch_size = batch_size
         self.boundaries = boundaries
@@ -238,9 +239,8 @@ def _create_buckets(self):
         for i in range(len(buckets)):
             len_bucket = len(buckets[i])
             total_batch_size = self.num_replicas * self.batch_size
-            rem = (
-                total_batch_size - (len_bucket % total_batch_size)
-            ) % total_batch_size
+            rem = (total_batch_size -
+                   (len_bucket % total_batch_size)) % total_batch_size
             num_samples_per_bucket.append(len_bucket + rem)
         return buckets, num_samples_per_bucket
 
@@ -252,7 +252,8 @@ def __iter__(self):
         indices = []
         if self.shuffle:
             for bucket in self.buckets:
-                indices.append(torch.randperm(len(bucket), generator=g).tolist())
+                indices.append(
+                    torch.randperm(len(bucket), generator=g).tolist())
         else:
             for bucket in self.buckets:
                 indices.append(list(range(len(bucket))))
@@ -266,22 +267,18 @@ def __iter__(self):
 
             # add extra samples to make it evenly divisible
             rem = num_samples_bucket - len_bucket
-            ids_bucket = (
-                ids_bucket
-                + ids_bucket * (rem // len_bucket)
-                + ids_bucket[: (rem % len_bucket)]
-            )
+            ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) +
+                          ids_bucket[:(rem % len_bucket)])
 
             # subsample
-            ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+            ids_bucket = ids_bucket[self.rank::self.num_replicas]
 
             # batching
             for j in range(len(ids_bucket) // self.batch_size):
                 batch = [
                     bucket[idx]
-                    for idx in ids_bucket[
-                        j * self.batch_size : (j + 1) * self.batch_size
-                    ]
+                    for idx in ids_bucket[j * self.batch_size:(j + 1) *
+                                          self.batch_size]
                 ]
                 batches.append(batch)
diff --git a/wetts/vits/export_onnx.py b/wetts/vits/export_onnx.py
index 110e081..d499a78 100644
--- a/wetts/vits/export_onnx.py
+++ b/wetts/vits/export_onnx.py
@@ -17,8 +17,8 @@
 
 import torch
 
-from models import SynthesizerTrn
-import utils
+from model.models import SynthesizerTrn
+from utils import task
 
 
 def get_args():
@@ -26,7 +26,9 @@ def get_args():
     parser.add_argument("--checkpoint", required=True, help="checkpoint")
     parser.add_argument("--cfg", required=True, help="config file")
     parser.add_argument("--onnx_model", required=True, help="onnx model")
-    parser.add_argument("--phone_table", required=True, help="input phone dict")
+    parser.add_argument("--phone_table",
+                        required=True,
+                        help="input phone dict")
     parser.add_argument("--speaker_table", default=None, help="speaker table")
     parser.add_argument(
         "--providers",
@@ -43,20 +45,18 @@ def main():
     args = get_args()
     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
-    hps = utils.get_hparams_from_file(args.cfg)
+    hps = task.get_hparams_from_file(args.cfg)
     hps['model']['is_onnx'] = True
     phone_num = len(open(args.phone_table).readlines())
     num_speakers = len(open(args.speaker_table).readlines())
-    net_g = SynthesizerTrn(
-        phone_num,
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=num_speakers,
-        **hps.model
-    )
-    utils.load_checkpoint(args.checkpoint, net_g, None)
+    net_g = SynthesizerTrn(phone_num,
+                           hps.data.filter_length // 2 + 1,
+                           hps.train.segment_size // hps.data.hop_length,
+                           n_speakers=num_speakers,
+                           **hps.model)
+    task.load_checkpoint(args.checkpoint, net_g, None)
     net_g.flow.remove_weight_norm()
     net_g.dec.remove_weight_norm()
     net_g.forward = net_g.export_forward
@@ -77,11 +77,24 @@ def main():
         input_names=["input", "input_lengths", "scales", "sid"],
output_names=["output"], dynamic_axes={ - "input": {0: "batch", 1: "phonemes"}, - "input_lengths": {0: "batch"}, - "scales": {0: "batch"}, - "sid": {0: "batch"}, - "output": {0: "batch", 1: "audio", 2: "audio_length"}, + "input": { + 0: "batch", + 1: "phonemes" + }, + "input_lengths": { + 0: "batch" + }, + "scales": { + 0: "batch" + }, + "sid": { + 0: "batch" + }, + "output": { + 0: "batch", + 1: "audio", + 2: "audio_length" + }, }, opset_version=13, verbose=False, diff --git a/wetts/vits/inference.py b/wetts/vits/inference.py index c7d44fc..18e36f6 100644 --- a/wetts/vits/inference.py +++ b/wetts/vits/inference.py @@ -21,8 +21,8 @@ from scipy.io import wavfile import torch -from models import SynthesizerTrn -import utils +from model.models import SynthesizerTrn +from utils import task def get_args(): @@ -30,12 +30,15 @@ def get_args(): parser.add_argument("--checkpoint", required=True, help="checkpoint") parser.add_argument("--cfg", required=True, help="config file") parser.add_argument("--outdir", required=True, help="ouput directory") - parser.add_argument("--phone_table", required=True, help="input phone dict") + parser.add_argument("--phone_table", + required=True, + help="input phone dict") parser.add_argument("--speaker_table", default=True, help="speaker table") parser.add_argument("--test_file", required=True, help="test file") - parser.add_argument( - "--gpu", type=int, default=-1, help="gpu id for this local rank, -1 for cpu" - ) + parser.add_argument("--gpu", + type=int, + default=-1, + help="gpu id for this local rank, -1 for cpu") args = parser.parse_args() return args @@ -59,19 +62,17 @@ def main(): arr = line.strip().split() assert len(arr) == 2 speaker_dict[arr[0]] = int(arr[1]) - hps = utils.get_hparams_from_file(args.cfg) + hps = task.get_hparams_from_file(args.cfg) - net_g = SynthesizerTrn( - len(phone_dict), - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=len(speaker_dict), - **hps.model - ) + net_g = SynthesizerTrn(len(phone_dict), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=len(speaker_dict), + **hps.model) net_g = net_g.to(device) net_g.eval() - utils.load_checkpoint(args.checkpoint, net_g, None) + task.load_checkpoint(args.checkpoint, net_g, None) for line in open(args.test_file): audio_path, speaker, text = line.strip().split("|") @@ -84,25 +85,17 @@ def main(): x_length = torch.LongTensor([seq.size(0)]).to(device) sid = torch.LongTensor([sid]).to(device) st = time.time() - audio = ( - net_g.infer( - x, - x_length, - sid=sid, - noise_scale=0.667, - noise_scale_w=0.8, - length_scale=1, - )[0][0, 0] - .data.cpu() - .float() - .numpy() - ) + audio = (net_g.infer( + x, + x_length, + sid=sid, + noise_scale=0.667, + noise_scale_w=0.8, + length_scale=1, + )[0][0, 0].data.cpu().float().numpy()) audio *= 32767 / max(0.01, np.max(np.abs(audio))) * 0.6 - print( - "RTF {}".format( - (time.time() - st) / (audio.shape[0] / hps.data.sampling_rate) - ) - ) + print("RTF {}".format((time.time() - st) / + (audio.shape[0] / hps.data.sampling_rate))) sys.stdout.flush() audio = np.clip(audio, -32767.0, 32767.0) wavfile.write( diff --git a/wetts/vits/inference_onnx.py b/wetts/vits/inference_onnx.py index 10167ec..d6dd726 100644 --- a/wetts/vits/inference_onnx.py +++ b/wetts/vits/inference_onnx.py @@ -19,15 +19,12 @@ from scipy.io import wavfile import torch -import utils +from utils import task def to_numpy(tensor): - return ( - tensor.detach().cpu().numpy() - if tensor.requires_grad - 
-        else tensor.detach().numpy()
-    )
+    return (tensor.detach().cpu().numpy()
+            if tensor.requires_grad else tensor.detach().numpy())
 
 
 def get_args():
@@ -35,7 +32,9 @@ def get_args():
     parser.add_argument("--onnx_model", required=True, help="onnx model")
     parser.add_argument("--cfg", required=True, help="config file")
     parser.add_argument("--outdir", required=True, help="output directory")
-    parser.add_argument("--phone_table", required=True, help="input phone dict")
+    parser.add_argument("--phone_table",
+                        required=True,
+                        help="input phone dict")
     parser.add_argument("--speaker_table", default=True, help="speaker table")
     parser.add_argument("--test_file", required=True, help="test file")
     parser.add_argument(
@@ -61,9 +60,10 @@ def main():
         arr = line.strip().split()
         assert len(arr) == 2
         speaker_dict[arr[0]] = int(arr[1])
-    hps = utils.get_hparams_from_file(args.cfg)
+    hps = task.get_hparams_from_file(args.cfg)
 
-    ort_sess = ort.InferenceSession(args.onnx_model, providers=[args.providers])
+    ort_sess = ort.InferenceSession(args.onnx_model,
+                                    providers=[args.providers])
     scales = torch.FloatTensor([0.667, 1.0, 0.8])
     # make triton dynamic shape happy
     scales = scales.unsqueeze(0)
diff --git a/wetts/vits/losses.py b/wetts/vits/losses.py
index b1b263e..470abec 100644
--- a/wetts/vits/losses.py
+++ b/wetts/vits/losses.py
@@ -19,7 +19,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
         dr = dr.float()
         dg = dg.float()
-        r_loss = torch.mean((1 - dr) ** 2)
+        r_loss = torch.mean((1 - dr)**2)
         g_loss = torch.mean(dg**2)
         loss += r_loss + g_loss
         r_losses.append(r_loss.item())
@@ -33,7 +33,7 @@ def generator_loss(disc_outputs):
     gen_losses = []
     for dg in disc_outputs:
         dg = dg.float()
-        l = torch.mean((1 - dg) ** 2)
+        l = torch.mean((1 - dg)**2)
         gen_losses.append(l)
         loss += l
 
@@ -52,7 +52,7 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
     z_mask = z_mask.float()
 
     kl = logs_p - logs_q - 0.5
-    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2.0 * logs_p)
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     return l
diff --git a/wetts/vits/attentions.py b/wetts/vits/model/attentions.py
similarity index 81%
rename from wetts/vits/attentions.py
rename to wetts/vits/model/attentions.py
index 68ca0b4..1648d95 100644
--- a/wetts/vits/attentions.py
+++ b/wetts/vits/model/attentions.py
@@ -4,22 +4,21 @@
 from torch import nn
 from torch.nn import functional as F
 
-import commons
-from modules import LayerNorm
+from model.normalization import LayerNorm
+from utils import commons
 
 
 class Encoder(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size=1,
-        p_dropout=0.0,
-        window_size=4,
-        **kwargs
-    ):
+
+    def __init__(self,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size=1,
+                 p_dropout=0.0,
+                 window_size=4,
+                 **kwargs):
         super().__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -42,8 +41,7 @@ def __init__(
                     n_heads,
                     p_dropout=p_dropout,
                     window_size=window_size,
-                )
-            )
+                ))
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
@@ -52,8 +50,7 @@ def __init__(
                     filter_channels,
                     kernel_size,
                     p_dropout=p_dropout,
-                )
-            )
+                ))
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
     def forward(self, x, x_mask):
@@ -72,18 +69,17 @@ def forward(self, x, x_mask):
 
 
 class Decoder(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size=1,
-        p_dropout=0.0,
-        proximal_bias=False,
-        proximal_init=True,
-        **kwargs
-    ):
+
+    def __init__(self,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size=1,
+                 p_dropout=0.0,
+                 proximal_bias=False,
+                 proximal_init=True,
+                 **kwargs):
         super().__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -110,14 +106,13 @@ def __init__(
                     p_dropout=p_dropout,
                     proximal_bias=proximal_bias,
                     proximal_init=proximal_init,
-                )
-            )
+                ))
             self.norm_layers_0.append(LayerNorm(hidden_channels))
             self.encdec_attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
-                )
-            )
+                MultiHeadAttention(hidden_channels,
+                                   hidden_channels,
+                                   n_heads,
+                                   p_dropout=p_dropout))
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
@@ -127,8 +122,7 @@ def __init__(
                     kernel_size,
                     p_dropout=p_dropout,
                     causal=True,
-                )
-            )
+                ))
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
     def forward(self, x, x_mask, h, h_mask):
@@ -137,8 +131,7 @@ def forward(self, x, x_mask, h, h_mask):
         h: encoder output
         """
         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-            device=x.device, dtype=x.dtype
-        )
+            device=x.device, dtype=x.dtype)
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
         for i in range(self.n_layers):
@@ -158,6 +151,7 @@ def forward(self, x, x_mask, h, h_mask):
 
 
 class MultiHeadAttention(nn.Module):
+
     def __init__(
         self,
         channels,
@@ -196,12 +190,10 @@ def __init__(
             rel_stddev = self.k_channels**-0.5
             self.emb_rel_k = nn.Parameter(
                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
+                * rel_stddev)
             self.emb_rel_v = nn.Parameter(
                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
+                * rel_stddev)
 
         nn.init.xavier_uniform_(self.conv_q.weight)
         nn.init.xavier_uniform_(self.conv_k.weight)
@@ -224,51 +216,50 @@ def forward(self, x, c, attn_mask=None):
     def attention(self, query, key, value, mask=None):
         # reshape [b, d, t] -> [b, n_h, t, d_k]
         b, d, t_s, t_t = (*key.size(), query.size(2))
-        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        query = query.view(b, self.n_heads, self.k_channels,
+                           t_t).transpose(2, 3)
         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels,
+                           t_s).transpose(2, 3)
 
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        scores = torch.matmul(query / math.sqrt(self.k_channels),
+                              key.transpose(-2, -1))
         if self.window_size is not None:
             msg = "Relative attention is only available for self-attention."
             assert t_s == t_t, msg
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            key_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_k, t_s)
             rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
+                query / math.sqrt(self.k_channels), key_relative_embeddings)
+            scores_local = self._relative_position_to_absolute_position(
+                rel_logits)
             scores = scores + scores_local
         if self.proximal_bias:
             msg = "Proximal bias is only available for self-attention."
             assert t_s == t_t, msg
             scores = scores + self._attention_bias_proximal(t_s).to(
-                device=scores.device, dtype=scores.dtype
-            )
+                device=scores.device, dtype=scores.dtype)
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
             if self.block_length is not None:
                 msg = "Local attention is only available for self-attention."
                 assert t_s == t_t, msg
                 block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
+                    torch.ones_like(scores).triu(-self.block_length).tril(
+                        self.block_length))
                 scores = scores.masked_fill(block_mask == 0, -1e4)
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
         if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            relative_weights = self._absolute_position_to_relative_position(
+                p_attn)
             value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
+                self.emb_rel_v, t_s)
             output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = (
-            output.transpose(2, 3).contiguous().view(b, d, t_t)
-        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+                relative_weights, value_relative_embeddings)
+        output = (output.transpose(2, 3).contiguous().view(b, d, t_t)
+                  )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         return output, p_attn
 
     def _matmul_with_relative_values(self, x, y):
@@ -298,13 +289,14 @@ def _get_relative_embeddings(self, relative_embeddings, length):
         if pad_length > 0:
             padded_relative_embeddings = F.pad(
                 relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+                commons.convert_pad_shape([[0, 0], [pad_length, pad_length],
+                                           [0, 0]]),
             )
         else:
             padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
+        used_relative_embeddings = padded_relative_embeddings[:,
+                                                              slice_start_position:
+                                                              slice_end_position]
         return used_relative_embeddings
 
     def _relative_position_to_absolute_position(self, x):
@@ -314,18 +306,18 @@ def _relative_position_to_absolute_position(self, x):
         """
         batch, heads, length, _ = x.size()
         # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
+                                                                         1]]))
 
         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
         x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-        )
+            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0,
+                                                                length - 1]]))
 
         # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
+        x_final = x_flat.view([batch, heads, length + 1,
+                               2 * length - 1])[:, :, :length, length - 1:]
         return x_final
 
     def _absolute_position_to_relative_position(self, x):
@@ -336,11 +328,13 @@ def _absolute_position_to_relative_position(self, x):
         batch, heads, length, _ = x.size()
         # pad along column
         x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-        )
+            x,
+            commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
+                                                                length - 1]]))
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_flat = F.pad(
+            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         return x_final
 
@@ -353,10 +347,12 @@ def _attention_bias_proximal(self, length):
         """
         r = torch.arange(length, dtype=torch.float32)
         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+        return torch.unsqueeze(
+            torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
 
 
 class FFN(nn.Module):
+
     def __init__(
         self,
         in_channels,
diff --git a/wetts/vits/model/decoder.py b/wetts/vits/model/decoder.py
new file mode 100644
index 0000000..36dab19
--- /dev/null
+++ b/wetts/vits/model/decoder.py
@@ -0,0 +1,303 @@
+import torch
+from torch import nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torchaudio.transforms import InverseSpectrogram
+
+from model.modules import LRELU_SLOPE
+from model.normalization import LayerNorm
+from utils.commons import init_weights, get_padding
+from utils.stft import OnnxSTFT
+
+
+class Generator(nn.Module):
+
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(initial_channel,
+                               upsample_initial_channel,
+                               7,
+                               1,
+                               padding=3)
+        resblock = ResBlock1 if resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2**(i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2**(i + 1))
+            for j, (k, d) in enumerate(
+                    zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+class ResBlock1(torch.nn.Module):
+
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList([
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[0],
+                    padding=get_padding(kernel_size, dilation[0]),
+                )),
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[1],
+                    padding=get_padding(kernel_size, dilation[1]),
+                )),
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[2],
+                    padding=get_padding(kernel_size, dilation[2]),
+                )),
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=1,
+                    padding=get_padding(kernel_size, 1),
+                )),
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=1,
+                    padding=get_padding(kernel_size, 1),
+                )),
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=1,
+                    padding=get_padding(kernel_size, 1),
+                )),
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.convs = nn.ModuleList([
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[0],
+                    padding=get_padding(kernel_size, dilation[0]),
+                )),
+            weight_norm(
+                Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[1],
+                    padding=get_padding(kernel_size, dilation[1]),
+                )),
+        ])
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class ConvNeXtLayer(nn.Module):
+
+    def __init__(self, channels, h_channels, scale):
+        super().__init__()
+        self.dw_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=3,
+            padding=1,
+            groups=channels,
+        )
+        self.norm = LayerNorm(channels)
+        self.pw_conv1 = nn.Conv1d(channels, h_channels, 1)
+        self.pw_conv2 = nn.Conv1d(h_channels, channels, 1)
+        self.scale = nn.Parameter(torch.full(size=(1, channels, 1),
+                                             fill_value=scale),
+                                  requires_grad=True)
+
+    def forward(self, x):
+        res = x
+        x = self.dw_conv(x)
+        x = self.norm(x)
+        x = self.pw_conv1(x)
+        x = F.gelu(x)
+        x = self.pw_conv2(x)
+        x = self.scale * x
+        x = res + x
+        return x
+
+
+class VocosGenerator(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 channels,
+                 h_channels,
+                 out_channels,
+                 num_layers,
+                 istft_config,
+                 gin_channels,
+                 is_onnx=False):
+        super().__init__()
+
+        self.pad = nn.ReflectionPad1d([1, 0])
+        self.in_conv = nn.Conv1d(in_channels,
+                                 channels,
+                                 kernel_size=1,
+                                 padding=0)
+        self.cond = Conv1d(gin_channels, channels, 1)
+        self.norm_pre = LayerNorm(channels)
+        scale = 1 / num_layers
+        self.layers = nn.ModuleList([
+            ConvNeXtLayer(channels, h_channels, scale)
+            for _ in range(num_layers)
+        ])
+        self.norm_post = LayerNorm(channels)
+        self.out_conv = nn.Conv1d(channels, out_channels, kernel_size=1)
+        self.is_onnx = is_onnx
+
+        if self.is_onnx:
+            self.stft = OnnxSTFT(filter_length=istft_config['n_fft'],
+                                 hop_length=istft_config['hop_length'],
+                                 win_length=istft_config['win_length'])
+        else:
+            self.istft = InverseSpectrogram(**istft_config)
+
+    def forward(self, x, g=None):
+        x = self.pad(x)
+        x = self.in_conv(x) + self.cond(g)
+        x = self.norm_pre(x)
+        for layer in self.layers:
+            x = layer(x)
+        x = self.norm_post(x)
+        x = self.out_conv(x)
+        mag, phase = x.chunk(2, dim=1)
+        mag = mag.exp().clamp_max(max=1e2)
+        if self.is_onnx:
+            o = self.stft.inverse(mag, phase).to(x.device)
+        else:
+            s = mag * (phase.cos() + 1j * phase.sin())
+            o = self.istft(s).unsqueeze(1)
+        return o
+
+    def remove_weight_norm(self):
+        pass
diff --git a/wetts/vits/model/discriminators.py b/wetts/vits/model/discriminators.py
new file mode 100644
index 0000000..7b387ac
--- /dev/null
+++ b/wetts/vits/model/discriminators.py
@@ -0,0 +1,144 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import Conv1d, Conv2d
+from torch.nn.utils import spectral_norm
+from torch.nn.utils.parametrizations import weight_norm
+
+from model.modules import LRELU_SLOPE
+from utils.commons import get_padding
+
+
+class DiscriminatorP(nn.Module):
+
+    def __init__(self,
+                 period,
+                 kernel_size=5,
+                 stride=3,
+                 use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(
+                Conv2d(
+                    1,
+                    32,
+                    (kernel_size, 1),
+                    (stride, 1),
+                    padding=(get_padding(kernel_size, 1), 0),
+                )),
+            norm_f(
+                Conv2d(
+                    32,
+                    128,
+                    (kernel_size, 1),
+                    (stride, 1),
+                    padding=(get_padding(kernel_size, 1), 0),
+                )),
+            norm_f(
+                Conv2d(
+                    128,
+                    512,
+                    (kernel_size, 1),
+                    (stride, 1),
+                    padding=(get_padding(kernel_size, 1), 0),
+                )),
+            norm_f(
+                Conv2d(
+                    512,
+                    1024,
+                    (kernel_size, 1),
+                    (stride, 1),
+                    padding=(get_padding(kernel_size, 1), 0),
+                )),
+            norm_f(
+                Conv2d(
+                    1024,
+                    1024,
+                    (kernel_size, 1),
+                    1,
+                    padding=(get_padding(kernel_size, 1), 0),
+                )),
+        ])
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+            norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+            norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+            norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+
+    def __init__(self, use_spectral_norm=False):
+        super(MultiPeriodDiscriminator, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP(i, use_spectral_norm=use_spectral_norm)
+            for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
diff --git a/wetts/vits/model/duration_predictors.py b/wetts/vits/model/duration_predictors.py
new file mode 100644
index 0000000..cf82b68
--- /dev/null
+++ b/wetts/vits/model/duration_predictors.py
@@ -0,0 +1,303 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from model.modules import Flip
+from model.normalization import LayerNorm
+from utils.transforms import piecewise_rational_quadratic_transform
+
+
+class DDSConv(nn.Module):
+    """
+    Dilated and Depth-Separable Convolution
+    """
+
+    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+        super().__init__()
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+
+        self.drop = nn.Dropout(p_dropout)
+        self.convs_sep = nn.ModuleList()
+        self.convs_1x1 = nn.ModuleList()
+        self.norms_1 = nn.ModuleList()
+        self.norms_2 = nn.ModuleList()
+        for i in range(n_layers):
+            dilation = kernel_size**i
+            padding = (kernel_size * dilation - dilation) // 2
+            self.convs_sep.append(
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    groups=channels,
+                    dilation=dilation,
+                    padding=padding,
+                ))
+            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+            self.norms_1.append(LayerNorm(channels))
+            self.norms_2.append(LayerNorm(channels))
+
+    def forward(self, x, x_mask, g=None):
+        if g is not None:
+            x = x + g
+        for i in range(self.n_layers):
+            y = self.convs_sep[i](x * x_mask)
+            y = self.norms_1[i](y)
+            y = F.gelu(y)
+            y = self.convs_1x1[i](y)
+            y = self.norms_2[i](y)
+            y = F.gelu(y)
+            y = self.drop(y)
+            x = x + y
+        return x * x_mask
+
+
+class ConvFlow(nn.Module):
+
+    def __init__(
+        self,
+        in_channels,
+        filter_channels,
+        kernel_size,
+        n_layers,
+        num_bins=10,
+        tail_bound=5.0,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.num_bins = num_bins
+        self.tail_bound = tail_bound
+        self.half_channels = in_channels // 2
+
+        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+        self.convs = DDSConv(filter_channels,
+                             kernel_size,
+                             n_layers,
+                             p_dropout=0.0)
+        self.proj = nn.Conv1d(filter_channels,
+                              self.half_channels * (num_bins * 3 - 1), 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0)
+        h = self.convs(h, x_mask, g=g)
+        h = self.proj(h) * x_mask
+
+        b, c, t = x0.shape
+        h = h.reshape(b, c, -1, t).permute(0, 1, 3,
+                                           2)  # [b, cx?, t] -> [b, c, t, ?]
+
+        unnormalized_widths = h[..., :self.num_bins] / math.sqrt(
+            self.filter_channels)
+        unnormalized_heights = h[...,
+                                 self.num_bins:2 * self.num_bins] / math.sqrt(
+                                     self.filter_channels)
+        unnormalized_derivatives = h[..., 2 * self.num_bins:]
+
+        x1, logabsdet = piecewise_rational_quadratic_transform(
+            x1,
+            unnormalized_widths,
+            unnormalized_heights,
+            unnormalized_derivatives,
+            inverse=reverse,
+            tails="linear",
+            tail_bound=self.tail_bound,
+        )
+
+        x = torch.cat([x0, x1], 1) * x_mask
+        logdet = torch.sum(logabsdet * x_mask, [1, 2])
+        if not reverse:
+            return x, logdet
+        else:
+            return x
+
+
+class ElementwiseAffine(nn.Module):
+
+    def __init__(self, channels):
+        super().__init__()
+        self.channels = channels
+        self.m = nn.Parameter(torch.zeros(channels, 1))
+        self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+    def forward(self, x, x_mask, reverse=False, **kwargs):
+        if not reverse:
+            y = self.m + torch.exp(self.logs) * x
+            y = y * x_mask
+            logdet = torch.sum(self.logs * x_mask, [1, 2])
+            return y, logdet
+        else:
+            x = (x - self.m) * torch.exp(-self.logs) * x_mask
+            return x
+
+
+class Log(nn.Module):
+
+    def forward(self, x, x_mask, reverse=False, **kwargs):
+        if not reverse:
+            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+            logdet = torch.sum(-y, [1, 2])
+            return y, logdet
+        else:
+            x = torch.exp(x) * x_mask
+            return x
+
+
+class StochasticDurationPredictor(nn.Module):
+
+    def __init__(
+        self,
+        in_channels,
+        filter_channels,
+        kernel_size,
+        p_dropout,
+        n_flows=4,
+        gin_channels=256,
+    ):
+        super().__init__()
+        filter_channels = in_channels  # it needs to be removed in a future version.
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.log_flow = Log()
+        self.flows = nn.ModuleList()
+        self.flows.append(ElementwiseAffine(2))
+        for i in range(n_flows):
+            self.flows.append(
+                ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+            self.flows.append(Flip())
+
+        self.post_pre = nn.Conv1d(1, filter_channels, 1)
+        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
+        self.post_convs = DDSConv(filter_channels,
+                                  kernel_size,
+                                  n_layers=3,
+                                  p_dropout=p_dropout)
+        self.post_flows = nn.ModuleList()
+        self.post_flows.append(ElementwiseAffine(2))
+        for i in range(4):
+            self.post_flows.append(
+                ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+            self.post_flows.append(Flip())
+
+        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
+        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
+        self.convs = DDSConv(filter_channels,
+                             kernel_size,
+                             n_layers=3,
+                             p_dropout=p_dropout)
+        self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
+
+    def forward(self,
+                x,
+                x_mask,
+                w=None,
+                g=None,
+                reverse=False,
+                noise_scale=1.0):
+        x = torch.detach(x)
+        x = self.pre(x)
+        g = torch.detach(g)
+        x = x + self.cond(g)
+        x = self.convs(x, x_mask)
+        x = self.proj(x) * x_mask
+
+        if not reverse:
+            flows = self.flows
+            assert w is not None
+
+            logdet_tot_q = 0
+            h_w = self.post_pre(w)
+            h_w = self.post_convs(h_w, x_mask)
+            h_w = self.post_proj(h_w) * x_mask
+            e_q = (torch.randn(w.size(0), 2, w.size(2)).to(
+                device=x.device, dtype=x.dtype) * x_mask)
+            z_q = e_q
+            for flow in self.post_flows:
+                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+                logdet_tot_q += logdet_q
+            z_u, z1 = torch.split(z_q, [1, 1], 1)
+            u = torch.sigmoid(z_u) * x_mask
+            z0 = (w - u) * x_mask
+            logdet_tot_q += torch.sum(
+                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
+            logq = (
+                torch.sum(-0.5 * (math.log(2 * math.pi) +
+                                  (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
+
+            logdet_tot = 0
+            z0, logdet = self.log_flow(z0, x_mask)
+            logdet_tot += logdet
+            z = torch.cat([z0, z1], 1)
+            for flow in flows:
+                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
+                logdet_tot = logdet_tot + logdet
+            nll = (torch.sum(0.5 * (math.log(2 * math.pi) +
+                                    (z**2)) * x_mask, [1, 2]) - logdet_tot)
+            return nll + logq  # [b]
+        else:
+            flows = list(reversed(self.flows))
+            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
+            z = (torch.randn(x.size(0), 2, x.size(2)).to(
+                device=x.device, dtype=x.dtype) * noise_scale)
+            for flow in flows:
+                z = flow(z, x_mask, g=x, reverse=reverse)
+            z0, z1 = torch.split(z, [1, 1], 1)
+            logw = z0
+            return logw
+
+
+class DurationPredictor(nn.Module):
+
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout,
+                 gin_channels):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.gin_channels = gin_channels
+
+        self.drop = nn.Dropout(p_dropout)
+        self.conv_1 = nn.Conv1d(in_channels,
+                                filter_channels,
+                                kernel_size,
+                                padding=kernel_size // 2)
+        self.norm_1 = LayerNorm(filter_channels)
+        self.conv_2 = nn.Conv1d(filter_channels,
+                                filter_channels,
+                                kernel_size,
+                                padding=kernel_size // 2)
+        self.norm_2 = LayerNorm(filter_channels)
+        self.proj = nn.Conv1d(filter_channels, 1, 1)
+        self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+    def forward(self, x, x_mask, g=None):
+        x = torch.detach(x)
+        g = torch.detach(g)
+        x = x + self.cond(g)
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
diff --git a/wetts/vits/model/encoders.py b/wetts/vits/model/encoders.py
new file mode 100644
index 0000000..e5da606
--- /dev/null
+++ b/wetts/vits/model/encoders.py
@@ -0,0 +1,94 @@
+import math
+
+import torch
+from torch import nn
+
+import model.attentions as attentions
+from model.modules import WN
+from utils import commons
+
+
+class TextEncoder(nn.Module):
+
+    def __init__(
+        self,
+        n_vocab,
+        out_channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+    ):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.emb = nn.Embedding(n_vocab, hidden_channels)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+        self.encoder = attentions.Encoder(hidden_channels, filter_channels,
+                                          n_heads, n_layers, kernel_size,
+                                          p_dropout)
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths):
+        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
+                                 1).to(x.dtype)
+
+        x = self.encoder(x * x_mask, x_mask)
+        stats = self.proj(x) * x_mask
+
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return x, m, logs, x_mask
+
+
+class PosteriorEncoder(nn.Module):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.enc = WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            gin_channels=gin_channels,
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths, g=None):
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
+                                 1).to(x.dtype)
+        x = self.pre(x) * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+        return z, m, logs, x_mask
diff --git a/wetts/vits/model/flows.py b/wetts/vits/model/flows.py
new file mode 100644
index 0000000..7abc0bb
--- /dev/null
+++ b/wetts/vits/model/flows.py
@@ -0,0 +1,116 @@
+import torch
+from torch import nn
+
+from model.modules import Flip, WN
+
+
+class ResidualCouplingBlock(nn.Module):
+
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        n_flows=4,
+        gin_channels=256,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+        for i in range(n_flows):
+            self.flows.append(
+                ResidualCouplingLayer(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation_rate,
+                    n_layers,
+                    gin_channels=gin_channels,
+                    mean_only=True,
+                ))
+            self.flows.append(Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+    def remove_weight_norm(self):
+        for i, l in enumerate(self.flows):
+            if i % 2 == 0:
+                l.remove_weight_norm()
+
+
+class ResidualCouplingLayer(nn.Module):
+
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        p_dropout=0,
+        gin_channels=256,
+        mean_only=False,
+    ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            p_dropout=p_dropout,
+            gin_channels=gin_channels,
+        )
+        self.post = nn.Conv1d(hidden_channels,
+                              self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+    def remove_weight_norm(self):
+        self.enc.remove_weight_norm()
diff --git a/wetts/vits/model/models.py b/wetts/vits/model/models.py
new file mode 100644
index 0000000..a372fad
--- /dev/null
+++ b/wetts/vits/model/models.py
@@ -0,0 +1,264 @@
+import math
+import time
+
+import torch
+from torch import nn
+import monotonic_align
+
+from model.decoder import Generator, VocosGenerator
+from model.duration_predictors import StochasticDurationPredictor, DurationPredictor
+from model.encoders import TextEncoder, PosteriorEncoder
+from model.flows import ResidualCouplingBlock
+from utils import commons
+
+
+class SynthesizerTrn(nn.Module):
+    """
+    Synthesizer for Training
+    """
+
+    def __init__(self,
+                 n_vocab,
+                 spec_channels,
+                 segment_size,
+                 inter_channels=192,
+                 hidden_channels=192,
+                 filter_channels=768,
+                 n_heads=2,
+                 n_layers=6,
+                 kernel_size=3,
+                 p_dropout=0.1,
+                 resblock="1",
+                 resblock_kernel_sizes=[3, 7, 11],
+                 resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                 upsample_rates=[8, 8, 2, 2],
+                 upsample_initial_channel=512,
+                 upsample_kernel_sizes=[16, 16, 4, 4],
+                 n_speakers=1,
+                 gin_channels=256,
+                 use_sdp=True,
+                 vocoder_type="hifigan",
+                 vocos_channels=512,
+                 vocos_h_channels=1536,
+                 vocos_out_channels=1026,
+                 vocos_num_layers=8,
+                 vocos_istft_config={
+                     "n_fft": 1024,
+                     "hop_length": 256,
+                     "win_length": 1024,
+                     "center": True,
+                 },
+                 is_onnx=False,
+                 **kwargs):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.spec_channels = spec_channels
+        self.inter_channels = inter_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.resblock = resblock
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.upsample_rates = upsample_rates
+        self.upsample_initial_channel = upsample_initial_channel
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.segment_size = segment_size
+        self.n_speakers = n_speakers
+        self.gin_channels = gin_channels
+        self.use_sdp = use_sdp
+
+        self.enc_p = TextEncoder(
+            n_vocab,
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+        )
+        if vocoder_type == "vocos":
+            self.dec = VocosGenerator(
+                inter_channels,
+                vocos_channels,
+                vocos_h_channels,
+                vocos_out_channels,
+                vocos_num_layers,
+                vocos_istft_config,
+                gin_channels,
+                is_onnx,
+            )
+        else:
+            self.dec = Generator(
+                inter_channels,
+                resblock,
+                resblock_kernel_sizes,
+                resblock_dilation_sizes,
+                upsample_rates,
+                upsample_initial_channel,
+                upsample_kernel_sizes,
+                gin_channels=gin_channels,
+            )
+        self.enc_q = PosteriorEncoder(
+            spec_channels,
+            inter_channels,
+            hidden_channels,
+            5,
+            1,
+            16,
+            gin_channels=gin_channels,
+        )
+        self.flow = ResidualCouplingBlock(inter_channels,
+                                          hidden_channels,
+                                          5,
+                                          1,
+                                          4,
+                                          gin_channels=gin_channels)
+
+        if use_sdp:
+            self.dp = StochasticDurationPredictor(hidden_channels,
+                                                  192,
+                                                  3,
+                                                  0.5,
+                                                  4,
+                                                  gin_channels=gin_channels)
+        else:
+            self.dp = DurationPredictor(hidden_channels,
+                                        256,
+                                        3,
+                                        0.5,
+                                        gin_channels=gin_channels)
+
+        self.emb_g = nn.Embedding(n_speakers, gin_channels)
+
+    def forward(self, x, x_lengths, y, y_lengths, sid=None):
+        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+
+        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
+        z_p = self.flow(z, y_mask, g=g)
+
+        with torch.no_grad():
+            # negative cross-entropy
+            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
+            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1],
+                                  keepdim=True)  # [b, 1, t_s]
+            neg_cent2 = torch.matmul(
+                -0.5 * (z_p**2).transpose(1, 2),
+                s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+            neg_cent3 = torch.matmul(
+                z_p.transpose(1, 2),
+                (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+            neg_cent4 = torch.sum(-0.5 * (m_p**2) * s_p_sq_r, [1],
+                                  keepdim=True)  # [b, 1, t_s]
+            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+
+            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(
+                y_mask, -1)
+            attn = (monotonic_align.maximum_path(
+                neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach())
+
+        w = attn.sum(2)
+        if self.use_sdp:
+            l_length = self.dp(x, x_mask, w, g=g)
+            l_length = l_length / torch.sum(x_mask)
+        else:
+            logw_ = torch.log(w + 1e-6) * x_mask
+            logw = self.dp(x, x_mask, g=g)
+            l_length = torch.sum(
+                (logw - logw_)**2, [1, 2]) / torch.sum(x_mask)  # for averaging
+
+        # expand prior
+        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1,
+                                                          2)).transpose(1, 2)
+        logs_p = torch.matmul(attn.squeeze(1),
+                              logs_p.transpose(1, 2)).transpose(1, 2)
+
+        z_slice, ids_slice = commons.rand_slice_segments(
+            z, y_lengths, self.segment_size)
+        o = self.dec(z_slice, g=g)
+        return (
+            o,
+            l_length,
+            attn,
+            ids_slice,
+            x_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+        )
+
+    def infer(
+        self,
+        x,
+        x_lengths,
+        sid=None,
+        noise_scale=1,
+        length_scale=1,
+        noise_scale_w=1.0,
+        max_len=None,
+    ):
+        t1 = time.time()
+        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+        g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+        t2 = time.time()
+        if self.use_sdp:
+            logw = self.dp(x,
+                           x_mask,
+                           g=g,
+                           reverse=True,
+                           noise_scale=noise_scale_w)
+        else:
+            logw = self.dp(x, x_mask, g=g)
+        t3 = time.time()
+        w = torch.exp(logw) * x_mask * length_scale
+        w_ceil = torch.ceil(w)
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None),
+                                 1).to(x_mask.dtype)
+        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        attn = commons.generate_path(w_ceil, attn_mask)
+
+        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
+            1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(
+            1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
+
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+        t4 = time.time()
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        t5 = time.time()
+        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
+        t6 = time.time()
+        print("TextEncoder: {}s DurationPredictor: {}s Flow: {}s Decoder: {}s".
+              format(
+                  round(t2 - t1, 3),
+                  round(t3 - t2, 3),
+                  round(t5 - t4, 3),
+                  round(t6 - t5, 3),
+              ))
+        return o, attn, y_mask, (z, z_p, m_p, logs_p)
+
+    def export_forward(self, x, x_lengths, scales, sid):
+        # shape of scales: Bx3, make triton happy
+        audio, *_ = self.infer(
+            x,
+            x_lengths,
+            sid,
+            noise_scale=scales[0][0],
+            length_scale=scales[0][1],
+            noise_scale_w=scales[0][2],
+        )
+        return audio
+
+    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
+        g_src = self.emb_g(sid_src).unsqueeze(-1)
+        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
+        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
+        z_p = self.flow(z, y_mask, g=g_src)
+        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
+        return o_hat, y_mask, (z, z_p, z_hat)
diff --git a/wetts/vits/model/modules.py b/wetts/vits/model/modules.py
new file mode 100644
index 0000000..27a63dc
--- /dev/null
+++ b/wetts/vits/model/modules.py
@@ -0,0 +1,100 @@
+import torch
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+
+from utils import commons
+
+LRELU_SLOPE = 0.1
+
+
+class WN(nn.Module):
+
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels,
+        p_dropout=0,
+    ):
+        super(WN, self).__init__()
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        self.kernel_size = (kernel_size, )
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+        self.p_dropout = p_dropout
+
+        self.in_layers = nn.ModuleList()
+        self.res_skip_layers = nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+        self.cond_layer = weight_norm(cond_layer, name="weight")
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = nn.Conv1d(
+                hidden_channels,
+                2 * hidden_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = weight_norm(in_layer, name="weight")
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_channels
+            else:
+                res_skip_channels = hidden_channels
+
+            res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
+            res_skip_layer = weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None, **kwargs):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        g = self.cond_layer(g)
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            cond_offset = i * 2 * self.hidden_channels
+            g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+
+            acts = commons.fused_add_tanh_sigmoid_multiply(
+                x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, :self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels:, :]
+            else:
+                output = output + res_skip_acts
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.cond_layer)
+        for l in self.in_layers:
+            nn.utils.remove_weight_norm(l)
+        for l in self.res_skip_layers:
+            nn.utils.remove_weight_norm(l)
+
+
+class Flip(nn.Module):
+
+    def forward(self, x, *args, reverse=False, **kwargs):
+        x = torch.flip(x, [1])
+        if not reverse:
+            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+            return x, logdet
+        else:
+            return x
diff --git a/wetts/vits/model/normalization.py b/wetts/vits/model/normalization.py
new file mode 100644
index 0000000..7bc33af
--- /dev/null
+++ b/wetts/vits/model/normalization.py
@@ -0,0 +1,19 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class LayerNorm(nn.Module):
+
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels, ), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
diff --git a/wetts/vits/models.py b/wetts/vits/models.py
deleted file mode 100644
index 085ce9e..0000000
--- a/wetts/vits/models.py
+++ /dev/null
@@ -1,842 +0,0 @@
-import math
-import time
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.nn import Conv1d, ConvTranspose1d, Conv2d
-from torch.nn.utils import remove_weight_norm, spectral_norm
-from torch.nn.utils.parametrizations import weight_norm
-from torchaudio.transforms import InverseSpectrogram
-import monotonic_align
-
-import commons
-import modules
-import attentions
-from commons import init_weights, get_padding
-from stft import OnnxSTFT
-
-
-class StochasticDurationPredictor(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        filter_channels,
-        kernel_size,
-        p_dropout,
-        n_flows=4,
-        gin_channels=256,
-    ):
-        super().__init__()
-        filter_channels = in_channels  # it needs to be removed from future version.
-        self.in_channels = in_channels
-        self.filter_channels = filter_channels
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.log_flow = modules.Log()
-        self.flows = nn.ModuleList()
-        self.flows.append(modules.ElementwiseAffine(2))
-        for i in range(n_flows):
-            self.flows.append(
-                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-            )
-            self.flows.append(modules.Flip())
-
-        self.post_pre = nn.Conv1d(1, filter_channels, 1)
-        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.post_convs = modules.DDSConv(
-            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-        )
-        self.post_flows = nn.ModuleList()
-        self.post_flows.append(modules.ElementwiseAffine(2))
-        for i in range(4):
-            self.post_flows.append(
-                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-            )
-            self.post_flows.append(modules.Flip())
-
-        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
-        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.convs = modules.DDSConv(
-            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-        )
-        self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
-
-    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
-        x = torch.detach(x)
-        x = self.pre(x)
-        g = torch.detach(g)
-        x = x + self.cond(g)
-        x = self.convs(x, x_mask)
-        x = self.proj(x) * x_mask
-
-        if not reverse:
-            flows = self.flows
-            assert w is not None
-
-            logdet_tot_q = 0
-            h_w = self.post_pre(w)
-            h_w = self.post_convs(h_w, x_mask)
-            h_w = self.post_proj(h_w) * x_mask
-            e_q = (
-                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
-                * x_mask
-            )
-            z_q = e_q
-            for flow in self.post_flows:
-                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
-                logdet_tot_q += logdet_q
-            z_u, z1 = torch.split(z_q, [1, 1], 1)
1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum( - (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] - ) - logq = ( - torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - - logdet_tot_q - ) - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in flows: - z, logdet = flow(z, x_mask, g=x, reverse=reverse) - logdet_tot = logdet_tot + logdet - nll = ( - torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - - logdet_tot - ) - return nll + logq # [b] - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = ( - torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) - * noise_scale - ) - for flow in flows: - z = flow(z, x_mask, g=x, reverse=reverse) - z0, z1 = torch.split(z, [1, 1], 1) - logw = z0 - return logw - - -class DurationPredictor(nn.Module): - def __init__( - self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels - ): - super().__init__() - - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.gin_channels = gin_channels - - self.drop = nn.Dropout(p_dropout) - self.conv_1 = nn.Conv1d( - in_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.norm_1 = modules.LayerNorm(filter_channels) - self.conv_2 = nn.Conv1d( - filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 - ) - self.norm_2 = modules.LayerNorm(filter_channels) - self.proj = nn.Conv1d(filter_channels, 1, 1) - self.cond = nn.Conv1d(gin_channels, in_channels, 1) - - def forward(self, x, x_mask, g=None): - x = torch.detach(x) - g = torch.detach(g) - x = x + self.cond(g) - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class TextEncoder(nn.Module): - def __init__( - self, - n_vocab, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - ): - super().__init__() - self.n_vocab = n_vocab - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - - self.encoder = attentions.Encoder( - hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths): - x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) - - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return x, m, logs, x_mask - - -class ResidualCouplingBlock(nn.Module): - def __init__( - self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=256, - ): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - 
self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append( - modules.ResidualCouplingLayer( - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - mean_only=True, - ) - ) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x - - def remove_weight_norm(self): - for i, l in enumerate(self.flows): - if i % 2 == 0: - l.remove_weight_norm() - - -class PosteriorEncoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=gin_channels, - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class Generator(torch.nn.Module): - def __init__( - self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels, - ): - super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d( - initial_channel, upsample_initial_channel, 7, 1, padding=3 - ) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - def forward(self, x, g=None): - x = self.conv_pre(x) - x = x + self.cond(g) - - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - for l in self.ups: - 
remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - -class ConvNeXtLayer(nn.Module): - def __init__(self, channels, h_channels, scale): - super().__init__() - self.dw_conv = nn.Conv1d( - channels, - channels, - kernel_size=3, - padding=1, - groups=channels, - ) - self.norm = modules.LayerNorm(channels) - self.pw_conv1 = nn.Conv1d(channels, h_channels, 1) - self.pw_conv2 = nn.Conv1d(h_channels, channels, 1) - self.scale = nn.Parameter( - torch.full(size=(1, channels, 1), fill_value=scale), requires_grad=True - ) - - def forward(self, x): - res = x - x = self.dw_conv(x) - x = self.norm(x) - x = self.pw_conv1(x) - x = F.gelu(x) - x = self.pw_conv2(x) - x = self.scale * x - x = res + x - return x - -class VocosGenerator(nn.Module): - def __init__( - self, - in_channels, - channels, - h_channels, - out_channels, - num_layers, - istft_config, - gin_channels, - is_onnx=False - ): - super().__init__() - - self.pad = nn.ReflectionPad1d([1, 0]) - self.in_conv = nn.Conv1d(in_channels, channels, kernel_size=1, padding=0) - self.cond = Conv1d(gin_channels, channels, 1) - self.norm_pre = modules.LayerNorm(channels) - scale = 1 / num_layers - self.layers = nn.ModuleList( - [ConvNeXtLayer(channels, h_channels, scale) for _ in range(num_layers)] - ) - self.norm_post = modules.LayerNorm(channels) - self.out_conv = nn.Conv1d(channels, out_channels, kernel_size=1) - self.is_onnx = is_onnx - - if self.is_onnx: - self.stft = OnnxSTFT(filter_length=istft_config['n_fft'], - hop_length=istft_config['hop_length'], - win_length=istft_config['win_length']) - else: - self.istft = InverseSpectrogram(**istft_config) - - def forward(self, x, g=None): - x = self.pad(x) - x = self.in_conv(x) + self.cond(g) - x = self.norm_pre(x) - for layer in self.layers: - x = layer(x) - x = self.norm_post(x) - x = self.out_conv(x) - mag, phase = x.chunk(2, dim=1) - mag = mag.exp().clamp_max(max=1e2) - if self.is_onnx: - o = self.stft.inverse(mag, phase).to(x.device) - else: - s = mag * (phase.cos() + 1j * phase.sin()) - o = self.istft(s).unsqueeze(1) - return o - - def remove_weight_norm(self): - pass - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f( - Conv2d( - 1, - 32, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 32, - 128, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 128, - 512, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 512, - 1024, - (kernel_size, 1), - (stride, 1), - padding=(get_padding(kernel_size, 1), 0), - ) - ), - norm_f( - Conv2d( - 1024, - 1024, - (kernel_size, 1), - 1, - padding=(get_padding(kernel_size, 1), 0), - ) - ), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) 
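
# A minimal sketch of the "1d to 2d" fold used by DiscriminatorP above (the
# class moves to model/discriminators.py in this PR): reflect padding makes t
# a multiple of the period, and the view lines up samples one period apart so
# every (kernel_size, 1) conv mixes them. Shapes here are hypothetical.
import torch
import torch.nn.functional as F

period = 3
x = torch.randn(1, 1, 8)                 # (b, c, t) with t % period != 0
n_pad = period - (x.shape[-1] % period)  # same "pad first" step as above
x = F.pad(x, (0, n_pad), "reflect")
x2d = x.view(1, 1, x.shape[-1] // period, period)  # (b, c, t // period, period)
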
- - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminator, self).__init__() - periods = [2, 3, 5, 7, 11] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [ - DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods - ] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__( - self, - n_vocab, - spec_channels, - segment_size, - inter_channels=192, - hidden_channels=192, - filter_channels=768, - n_heads=2, - n_layers=6, - kernel_size=3, - p_dropout=0.1, - resblock="1", - resblock_kernel_sizes=[3, 7, 11], - resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - upsample_rates=[8, 8, 2, 2], - upsample_initial_channel=512, - upsample_kernel_sizes=[16, 16, 4, 4], - n_speakers=1, - gin_channels=256, - use_sdp=True, - vocoder_type="hifigan", - vocos_channels=512, - vocos_h_channels=1536, - vocos_out_channels=1026, - vocos_num_layers=8, - vocos_istft_config={ - "n_fft": 1024, - "hop_length": 256, - "win_length": 1024, - "center": True, - }, - is_onnx=False, - **kwargs - ): - super().__init__() - self.n_vocab = n_vocab - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.n_speakers = n_speakers - self.gin_channels = gin_channels - self.use_sdp = use_sdp - - self.enc_p = TextEncoder( - n_vocab, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - ) - if vocoder_type == "vocos": - self.dec = VocosGenerator( - inter_channels, - vocos_channels, - vocos_h_channels, - vocos_out_channels, - vocos_num_layers, - vocos_istft_config, - gin_channels, - is_onnx, - ) - else: - self.dec = Generator( - inter_channels, - resblock, - 
resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - ) - self.enc_q = PosteriorEncoder( - spec_channels, - inter_channels, - hidden_channels, - 5, - 1, - 16, - gin_channels=gin_channels, - ) - self.flow = ResidualCouplingBlock( - inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels - ) - - if use_sdp: - self.dp = StochasticDurationPredictor( - hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels - ) - else: - self.dp = DurationPredictor( - hidden_channels, 256, 3, 0.5, gin_channels=gin_channels - ) - - self.emb_g = nn.Embedding(n_speakers, gin_channels) - - def forward(self, x, x_lengths, y, y_lengths, sid=None): - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - - with torch.no_grad(): - # negative cross-entropy - s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] - neg_cent1 = torch.sum( - -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True - ) # [b, 1, t_s] - neg_cent2 = torch.matmul( - -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r - ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent3 = torch.matmul( - z_p.transpose(1, 2), (m_p * s_p_sq_r) - ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent4 = torch.sum( - -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True - ) # [b, 1, t_s] - neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 - - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = ( - monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)) - .unsqueeze(1) - .detach() - ) - - w = attn.sum(2) - if self.use_sdp: - l_length = self.dp(x, x_mask, w, g=g) - l_length = l_length / torch.sum(x_mask) - else: - logw_ = torch.log(w + 1e-6) * x_mask - logw = self.dp(x, x_mask, g=g) - l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum( - x_mask - ) # for averaging - - # expand prior - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) - - z_slice, ids_slice = commons.rand_slice_segments( - z, y_lengths, self.segment_size - ) - o = self.dec(z_slice, g=g) - return ( - o, - l_length, - attn, - ids_slice, - x_mask, - y_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) - - def infer( - self, - x, - x_lengths, - sid=None, - noise_scale=1, - length_scale=1, - noise_scale_w=1.0, - max_len=None, - ): - t1 = time.time() - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - t2 = time.time() - if self.use_sdp: - logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) - else: - logw = self.dp(x, x_mask, g=g) - t3 = time.time() - w = torch.exp(logw) * x_mask * length_scale - w_ceil = torch.ceil(w) - y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to( - x_mask.dtype - ) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = commons.generate_path(w_ceil, attn_mask) - - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose( - 1, 2 - ) # [b, t', t], [b, t, d] -> [b, d, t'] - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose( - 1, 2 - ) # [b, t', t], [b, t, d] -> [b, d, t'] - - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - t4 = time.time() - z = self.flow(z_p, y_mask, g=g, 
reverse=True) - t5 = time.time() - o = self.dec((z * y_mask)[:, :, :max_len], g=g) - t6 = time.time() - print( - "TextEncoder: {}s DurationPredictor: {}s Flow: {}s Decoder: {}s".format( - round(t2 - t1, 3), - round(t3 - t2, 3), - round(t5 - t4, 3), - round(t6 - t5, 3), - ) - ) - return o, attn, y_mask, (z, z_p, m_p, logs_p) - - def export_forward(self, x, x_lengths, scales, sid): - # shape of scales: Bx3, make triton happy - audio, *_ = self.infer( - x, - x_lengths, - sid, - noise_scale=scales[0][0], - length_scale=scales[0][1], - noise_scale_w=scales[0][2], - ) - return audio - - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): - g_src = self.emb_g(sid_src).unsqueeze(-1) - g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) - z_p = self.flow(z, y_mask, g=g_src) - z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) - o_hat = self.dec(z_hat * y_mask, g=g_tgt) - return o_hat, y_mask, (z, z_p, z_hat) diff --git a/wetts/vits/modules.py b/wetts/vits/modules.py deleted file mode 100644 index 8007c0e..0000000 --- a/wetts/vits/modules.py +++ /dev/null @@ -1,511 +0,0 @@ -import math - -import torch -from torch import nn -from torch.nn import functional as F -from torch.nn import Conv1d -from torch.nn.utils import remove_weight_norm -from torch.nn.utils.parametrizations import weight_norm - -import commons -from commons import init_weights, get_padding -from transforms import piecewise_rational_quadratic_transform - -LRELU_SLOPE = 0.1 - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - - -class ConvReluNorm(nn.Module): - def __init__( - self, - in_channels, - hidden_channels, - out_channels, - kernel_size, - n_layers, - p_dropout, - ): - super().__init__() - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - assert n_layers > 1, "Number of layers should be larger than 0." 
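
# Note on the assert just above: the condition tests n_layers > 1 while the
# message says "larger than 0". Wherever ConvReluNorm lands in this refactor,
# the message presumably wants to match the check:
# assert n_layers > 1, "Number of layers should be larger than 1."
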
- - self.conv_layers = nn.ModuleList() - self.norm_layers = nn.ModuleList() - self.conv_layers.append( - nn.Conv1d( - in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 - ) - ) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) - for _ in range(n_layers - 1): - self.conv_layers.append( - nn.Conv1d( - hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2, - ) - ) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask): - x_org = x - for i in range(self.n_layers): - x = self.conv_layers[i](x * x_mask) - x = self.norm_layers[i](x) - x = self.relu_drop(x) - x = x_org + self.proj(x) - return x * x_mask - - -class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size**i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append( - nn.Conv1d( - channels, - channels, - kernel_size, - groups=channels, - dilation=dilation, - padding=padding, - ) - ) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask - - -class WN(torch.nn.Module): - def __init__( - self, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels, - p_dropout=0, - ): - super(WN, self).__init__() - assert kernel_size % 2 == 1 - self.hidden_channels = hidden_channels - self.kernel_size = (kernel_size,) - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - self.p_dropout = p_dropout - - self.in_layers = torch.nn.ModuleList() - self.res_skip_layers = torch.nn.ModuleList() - self.drop = nn.Dropout(p_dropout) - - cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) - self.cond_layer = weight_norm(cond_layer, name="weight") - - for i in range(n_layers): - dilation = dilation_rate**i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d( - hidden_channels, - 2 * hidden_channels, - kernel_size, - dilation=dilation, - padding=padding, - ) - in_layer = weight_norm(in_layer, name="weight") - self.in_layers.append(in_layer) - - # last one is not necessary - if i < n_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels - - res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = weight_norm(res_skip_layer, name="weight") - self.res_skip_layers.append(res_skip_layer) - - def forward(self, x, x_mask, g=None, **kwargs): - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) - - g = self.cond_layer(g) - 
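
# For reference, commons.fused_add_tanh_sigmoid_multiply (relocated to
# utils/commons.py by this PR) implements the WaveNet-style gate applied in
# the loop below. Roughly, a sketch of what it computes per layer:
import torch

def _gate_sketch(x_in, g_l, n_channels):
    n = int(n_channels[0])                  # hidden_channels
    acts = x_in + g_l                       # (b, 2 * hidden, t)
    return torch.tanh(acts[:, :n, :]) * torch.sigmoid(acts[:, n:, :])
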
- for i in range(self.n_layers): - x_in = self.in_layers[i](x) - cond_offset = i * 2 * self.hidden_channels - g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] - - acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) - acts = self.drop(acts) - - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.n_layers - 1: - res_acts = res_skip_acts[:, : self.hidden_channels, :] - x = (x + res_acts) * x_mask - output = output + res_skip_acts[:, self.hidden_channels :, :] - else: - output = output + res_skip_acts - return output * x_mask - - def remove_weight_norm(self): - torch.nn.utils.remove_weight_norm(self.cond_layer) - for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) - for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) - - -class ResBlock1(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock1, self).__init__() - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x, x_mask=None): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - if x_mask is not None: - xt = xt * x_mask - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - if x_mask is not None: - xt = xt * x_mask - xt = c2(xt) - x = xt + x - if x_mask is not None: - x = x * x_mask - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class ResBlock2(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3)): - super(ResBlock2, self).__init__() - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - self.convs.apply(init_weights) - - def forward(self, x, x_mask=None): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - if x_mask is not None: - xt = xt * x_mask - xt = c(xt) - x = xt + x - if x_mask is not None: - x = x * x_mask - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) - - -class Log(nn.Module): - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - - -class Flip(nn.Module): - def forward(self, x, *args, 
reverse=False, **kwargs): - x = torch.flip(x, [1]) - if not reverse: - logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) - return x, logdet - else: - return x - - -class ElementwiseAffine(nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - self.m = nn.Parameter(torch.zeros(channels, 1)) - self.logs = nn.Parameter(torch.zeros(channels, 1)) - - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1, 2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x - - -class ResidualCouplingLayer(nn.Module): - def __init__( - self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - p_dropout=0, - gin_channels=256, - mean_only=False, - ): - assert channels % 2 == 0, "channels should be divisible by 2" - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.half_channels = channels // 2 - self.mean_only = mean_only - - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN( - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - p_dropout=p_dropout, - gin_channels=gin_channels, - ) - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - self.post.weight.data.zero_() - self.post.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) * x_mask - h = self.enc(h, x_mask, g=g) - stats = self.post(h) * x_mask - if not self.mean_only: - m, logs = torch.split(stats, [self.half_channels] * 2, 1) - else: - m = stats - logs = torch.zeros_like(m) - - if not reverse: - x1 = m + x1 * torch.exp(logs) * x_mask - x = torch.cat([x0, x1], 1) - logdet = torch.sum(logs, [1, 2]) - return x, logdet - else: - x1 = (x1 - m) * torch.exp(-logs) * x_mask - x = torch.cat([x0, x1], 1) - return x - - def remove_weight_norm(self): - self.enc.remove_weight_norm() - - -class ConvFlow(nn.Module): - def __init__( - self, - in_channels, - filter_channels, - kernel_size, - n_layers, - num_bins=10, - tail_bound=5.0, - ): - super().__init__() - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.num_bins = num_bins - self.tail_bound = tail_bound - self.half_channels = in_channels // 2 - - self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) - self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) - self.proj = nn.Conv1d( - filter_channels, self.half_channels * (num_bins * 3 - 1), 1 - ) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) - h = self.convs(h, x_mask, g=g) - h = self.proj(h) * x_mask - - b, c, t = x0.shape - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
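
# Shape bookkeeping for the split below (a sketch with hypothetical sizes;
# num_bins defaults to 10 above): h carries num_bins * 3 - 1 = 29 spline
# parameters per position, consumed as widths, heights, then derivatives.
import torch

b, c, t, num_bins = 2, 1, 50, 10
h = torch.randn(b, c, t, num_bins * 3 - 1)
widths = h[..., :num_bins]                # (2, 1, 50, 10) unnormalized bin widths
heights = h[..., num_bins:2 * num_bins]   # (2, 1, 50, 10) unnormalized bin heights
derivs = h[..., 2 * num_bins:]            # (2, 1, 50, 9) interior derivatives
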
- - unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( - self.filter_channels - ) - unnormalized_derivatives = h[..., 2 * self.num_bins :] - - x1, logabsdet = piecewise_rational_quadratic_transform( - x1, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=reverse, - tails="linear", - tail_bound=self.tail_bound, - ) - - x = torch.cat([x0, x1], 1) * x_mask - logdet = torch.sum(logabsdet * x_mask, [1, 2]) - if not reverse: - return x, logdet - else: - return x diff --git a/wetts/vits/train.py b/wetts/vits/train.py index 50a43af..5b661b0 100644 --- a/wetts/vits/train.py +++ b/wetts/vits/train.py @@ -8,26 +8,23 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.cuda.amp import autocast, GradScaler -import commons -import utils from data_utils import ( TextAudioSpeakerLoader, TextAudioSpeakerCollate, DistributedBucketSampler, ) -from models import ( - SynthesizerTrn, - MultiPeriodDiscriminator, -) +from model.discriminators import MultiPeriodDiscriminator +from model.models import SynthesizerTrn from losses import generator_loss, discriminator_loss, feature_loss, kl_loss -from mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from utils import commons, task +from utils.mel_processing import mel_spectrogram_torch, spec_to_mel_torch torch.backends.cudnn.benchmark = False global_step = 0 def main(): - hps = utils.get_hparams() + hps = task.get_hparams() torch.manual_seed(hps.train.seed) global global_step world_size = int(os.environ.get('WORLD_SIZE', 1)) @@ -36,10 +33,11 @@ def main(): torch.torch.cuda.set_device(local_rank) dist.init_process_group("nccl") if rank == 0: - logger = utils.get_logger(hps.model_dir) + logger = task.get_logger(hps.model_dir) logger.info(hps) writer = SummaryWriter(log_dir=hps.model_dir) - writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) + writer_eval = SummaryWriter( + log_dir=os.path.join(hps.model_dir, "eval")) train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) train_sampler = DistributedBucketSampler( train_dataset, @@ -59,7 +57,8 @@ def main(): batch_sampler=train_sampler, ) if rank == 0: - eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data) + eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, + hps.data) eval_loader = DataLoader( eval_dataset, num_workers=8, @@ -70,13 +69,11 @@ def main(): collate_fn=collate_fn, ) - net_g = SynthesizerTrn( - hps.data.num_phones, - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model - ).cuda(rank) + net_g = SynthesizerTrn(hps.data.num_phones, + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model).cuda(rank) net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) optim_g = torch.optim.AdamW( net_g.parameters(), @@ -94,23 +91,21 @@ def main(): net_d = DDP(net_d, device_ids=[rank]) try: - _, _, _, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g - ) - _, _, _, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d - ) + _, _, _, epoch_str = task.load_checkpoint( + task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, + optim_g) + _, _, _, epoch_str = task.load_checkpoint( + 
task.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, + optim_d) global_step = (epoch_str - 1) * len(train_loader) except Exception as e: epoch_str = 1 global_step = 0 scheduler_g = torch.optim.lr_scheduler.ExponentialLR( - optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 - ) + optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) scheduler_d = torch.optim.lr_scheduler.ExponentialLR( - optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 - ) + optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) scaler = GradScaler(enabled=hps.train.fp16_run) @@ -145,9 +140,8 @@ def main(): scheduler_d.step() -def train_and_evaluate( - rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers -): +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, + loaders, logger, writers): net_g, net_d = nets optim_g, optim_d = optims scheduler_g, scheduler_d = schedulers @@ -161,23 +155,21 @@ def train_and_evaluate( net_g.train() net_d.train() for batch_idx, ( - x, - x_lengths, - spec, - spec_lengths, - y, - y_lengths, - speakers, + x, + x_lengths, + spec, + spec_lengths, + y, + y_lengths, + speakers, ) in enumerate(train_loader): x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( - rank, non_blocking=True - ) - spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( - rank, non_blocking=True - ) + rank, non_blocking=True) + spec, spec_lengths = spec.cuda( + rank, non_blocking=True), spec_lengths.cuda(rank, + non_blocking=True) y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( - rank, non_blocking=True - ) + rank, non_blocking=True) speakers = speakers.cuda(rank, non_blocking=True) with autocast(enabled=hps.train.fp16_run): @@ -198,8 +190,7 @@ def train_and_evaluate( hps.data.sampling_rate, ) y_mel = commons.slice_segments( - mel, ids_slice, hps.train.segment_size // hps.data.hop_length - ) + mel, ids_slice, hps.train.segment_size // hps.data.hop_length) y_hat_mel = mel_spectrogram_torch( y_hat.squeeze(1), hps.data.filter_length, @@ -209,16 +200,14 @@ def train_and_evaluate( hps.data.win_length, ) - y = commons.slice_segments( - y, ids_slice * hps.data.hop_length, hps.train.segment_size - ) # slice + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, + hps.train.segment_size) # slice # Discriminator y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) with autocast(enabled=False): loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( - y_d_hat_r, y_d_hat_g - ) + y_d_hat_r, y_d_hat_g) loss_disc_all = loss_disc optim_d.zero_grad() scaler.scale(loss_disc_all).backward() @@ -232,7 +221,8 @@ def train_and_evaluate( with autocast(enabled=False): loss_dur = torch.sum(l_length.float()) loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, + z_mask) * hps.train.c_kl loss_fm = feature_loss(fmap_r, fmap_g) loss_gen, losses_gen = generator_loss(y_d_hat_g) @@ -247,13 +237,13 @@ def train_and_evaluate( if rank == 0: if global_step % hps.train.log_interval == 0: lr = optim_g.param_groups[0]["lr"] - losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl] - logger.info( - "Train Epoch: {} [{:.0f}%]".format( - epoch, 100.0 * batch_idx / len(train_loader) - ) - ) - logger.info([round(x.item(), 5) for x in losses] + [global_step, lr]) + losses = [ + loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl + ] + logger.info("Train Epoch: {} [{:.0f}%]".format( 
+ epoch, 100.0 * batch_idx / len(train_loader))) + logger.info([round(x.item(), 5) + for x in losses] + [global_step, lr]) scalar_dict = { "loss/g/total": loss_gen_all, @@ -262,39 +252,39 @@ def train_and_evaluate( "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g, } - scalar_dict.update( - { - "loss/g/fm": loss_fm, - "loss/g/mel": loss_mel, - "loss/g/dur": loss_dur, - "loss/g/kl": loss_kl, - } - ) - - scalar_dict.update( - {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} - ) - scalar_dict.update( - {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} - ) - scalar_dict.update( - {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} - ) + scalar_dict.update({ + "loss/g/fm": loss_fm, + "loss/g/mel": loss_mel, + "loss/g/dur": loss_dur, + "loss/g/kl": loss_kl, + }) + + scalar_dict.update({ + "loss/g/{}".format(i): v + for i, v in enumerate(losses_gen) + }) + scalar_dict.update({ + "loss/d_r/{}".format(i): v + for i, v in enumerate(losses_disc_r) + }) + scalar_dict.update({ + "loss/d_g/{}".format(i): v + for i, v in enumerate(losses_disc_g) + }) image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy( - y_mel[0].data.cpu().numpy() - ), - "slice/mel_gen": utils.plot_spectrogram_to_numpy( - y_hat_mel[0].data.cpu().numpy() - ), - "all/mel": utils.plot_spectrogram_to_numpy( - mel[0].data.cpu().numpy() - ), - "all/attn": utils.plot_alignment_to_numpy( - attn[0, 0].data.cpu().numpy() - ), + "slice/mel_org": + task.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy()), + "slice/mel_gen": + task.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy()), + "all/mel": + task.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), + "all/attn": + task.plot_alignment_to_numpy(attn[0, + 0].data.cpu().numpy()), } - utils.summarize( + task.summarize( writer=writer, global_step=global_step, images=image_dict, @@ -303,19 +293,21 @@ def train_and_evaluate( if global_step % hps.train.eval_interval == 0: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint( + task.save_checkpoint( net_g, optim_g, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), + os.path.join(hps.model_dir, + "G_{}.pth".format(global_step)), ) - utils.save_checkpoint( + task.save_checkpoint( net_d, optim_d, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), + os.path.join(hps.model_dir, + "D_{}.pth".format(global_step)), ) global_step += 1 @@ -327,13 +319,13 @@ def evaluate(hps, generator, eval_loader, writer_eval): generator.eval() with torch.no_grad(): for batch_idx, ( - x, - x_lengths, - spec, - spec_lengths, - y, - y_lengths, - speakers, + x, + x_lengths, + spec, + spec_lengths, + y, + y_lengths, + speakers, ) in enumerate(eval_loader): x, x_lengths = x.cuda(0), x_lengths.cuda(0) spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0) @@ -349,9 +341,10 @@ def evaluate(hps, generator, eval_loader, writer_eval): y_lengths = y_lengths[:1] speakers = speakers[:1] break - y_hat, attn, mask, *_ = generator.module.infer( - x, x_lengths, speakers, max_len=1000 - ) + y_hat, attn, mask, *_ = generator.module.infer(x, + x_lengths, + speakers, + max_len=1000) y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length mel = spec_to_mel_torch( @@ -369,16 +362,15 @@ def evaluate(hps, generator, eval_loader, writer_eval): hps.data.win_length, ) image_dict = { - "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) + "gen/mel": 
task.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) } - audio_dict = {"gen/audio": y_hat[0, :, : y_hat_lengths[0]]} + audio_dict = {"gen/audio": y_hat[0, :, :y_hat_lengths[0]]} if global_step == 0: image_dict.update( - {"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())} - ) - audio_dict.update({"gt/audio": y[0, :, : y_lengths[0]]}) + {"gt/mel": task.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) + audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]}) - utils.summarize( + task.summarize( writer=writer_eval, global_step=global_step, images=image_dict, diff --git a/wetts/vits/commons.py b/wetts/vits/utils/commons.py similarity index 82% rename from wetts/vits/commons.py rename to wetts/vits/utils/commons.py index 1a30a33..943889c 100644 --- a/wetts/vits/commons.py +++ b/wetts/vits/utils/commons.py @@ -22,9 +22,8 @@ def convert_pad_shape(pad_shape): def kl_divergence(m_p, logs_p, m_q, logs_q): """KL(P||Q)""" kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) + kl += (0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q)**2)) * + torch.exp(-2.0 * logs_q)) return kl @@ -53,20 +52,23 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4): if x_lengths is None: x_lengths = t ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ids_str = (torch.rand([b]).to(device=x.device) * + ids_str_max).to(dtype=torch.long) ret = slice_segments(x, ids_str, segment_size) return ret, ids_str -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): +def get_timing_signal_1d(length, + channels, + min_timescale=1.0, + max_timescale=1.0e4): position = torch.arange(length, dtype=torch.float) num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale)) / (num_timescales - 1) inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) + torch.arange(num_timescales, dtype=torch.float) * + -log_timescale_increment) scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) signal = F.pad(signal, [0, 0, 0, channels % 2]) @@ -76,13 +78,15 @@ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + signal = get_timing_signal_1d(length, channels, min_timescale, + max_timescale) return x + signal.to(dtype=x.dtype, device=x.device) def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + signal = get_timing_signal_1d(length, channels, min_timescale, + max_timescale) return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) @@ -126,7 +130,8 @@ def generate_path(duration, mask): cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0] + ]))[:, :-1] path 
= path.unsqueeze(1).transpose(2, 3) * mask return path @@ -142,8 +147,8 @@ def clip_grad_value_(parameters, clip_value, norm_type=2): total_norm = 0 for p in parameters: param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type + total_norm += param_norm.item()**norm_type if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1.0 / norm_type) + total_norm = total_norm**(1.0 / norm_type) return total_norm diff --git a/wetts/vits/mel_processing.py b/wetts/vits/utils/mel_processing.py similarity index 76% rename from wetts/vits/mel_processing.py rename to wetts/vits/utils/mel_processing.py index 7655c9d..28bab73 100644 --- a/wetts/vits/mel_processing.py +++ b/wetts/vits/utils/mel_processing.py @@ -38,7 +38,12 @@ def spectral_de_normalize_torch(magnitudes): hann_window = {} -def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): +def spectrogram_torch(y, + n_fft, + sampling_rate, + hop_size, + win_size, + center=False): if torch.min(y) < -1.0: print("min value is ", torch.min(y)) if torch.max(y) > 1.0: @@ -49,8 +54,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) wnsize_dtype_device = str(win_size) + "_" + dtype_device if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( - dtype=y.dtype, device=y.device - ) + dtype=y.dtype, device=y.device) y = F.pad( y.unsqueeze(1), @@ -82,18 +86,22 @@ def spec_to_mel_torch(spec, n_fft, n_mels, sampling_rate): dtype_device = str(spec.dtype) + "_" + str(spec.device) if dtype_device not in mel_basis: mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels) - mel_basis[dtype_device] = torch.from_numpy(mel).to( - dtype=spec.dtype, device=spec.device - ) + mel_basis[dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, + device=spec.device) spec = torch.matmul(mel_basis[dtype_device], spec) spec = spectral_normalize_torch(spec) return spec -def mel_spectrogram_torch( - y, n_fft, n_mels, sampling_rate, hop_size, win_size, center=False -): - spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) +def mel_spectrogram_torch(y, + n_fft, + n_mels, + sampling_rate, + hop_size, + win_size, + center=False): + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, + center) spec = spec_to_mel_torch(spec, n_fft, n_mels, sampling_rate) return spec diff --git a/wetts/vits/stft.py b/wetts/vits/utils/stft.py similarity index 79% rename from wetts/vits/stft.py rename to wetts/vits/utils/stft.py index 1c94c48..f05aa28 100644 --- a/wetts/vits/stft.py +++ b/wetts/vits/utils/stft.py @@ -74,22 +74,25 @@ def window_sumsquare( # Compute the squared window at the desired length win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 win_sq = librosa_util.pad_center(win_sq, n_fft) # Fill the envelope for i in range(n_frames): sample = i * hop_length - x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + x[sample:min(n, sample + + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] return x class STFT(torch.nn.Module): """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - def __init__( - self, filter_length=800, hop_length=200, win_length=800, window="hann" - ): + def __init__(self, + filter_length=800, + hop_length=200, + win_length=800, + window="hann"): super(STFT, 
self).__init__() self.filter_length = filter_length self.hop_length = hop_length @@ -100,14 +103,14 @@ def __init__( fourier_basis = np.fft.fft(np.eye(self.filter_length)) cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack( - [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] - ) + fourier_basis = np.vstack([ + np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :]) + ]) forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :] - ) + np.linalg.pinv(scale * fourier_basis).T[:, None, :]) if window is not None: assert filter_length >= win_length @@ -150,14 +153,15 @@ def transform(self, input_data): imag_part = forward_transform[:, cutoff:, :] magnitude = torch.sqrt(real_part**2 + imag_part**2) - phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) return magnitude, phase def inverse(self, magnitude, phase): recombine_magnitude_phase = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 - ) + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], + dim=1) inverse_transform = F.conv_transpose1d( recombine_magnitude_phase, @@ -177,25 +181,21 @@ def inverse(self, magnitude, phase): ) # remove modulation effects approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0] - ) - window_sum = torch.autograd.Variable( - torch.from_numpy(window_sum), requires_grad=False - ) - window_sum = ( - window_sum.to(inverse_transform.device()) - if magnitude.is_cuda - else window_sum - ) + np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), + requires_grad=False) + window_sum = (window_sum.to(inverse_transform.device()) + if magnitude.is_cuda else window_sum) inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ - approx_nonzero_indices - ] + approx_nonzero_indices] # scale by hop ratio inverse_transform *= float(self.filter_length) / self.hop_length - inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] - inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, + int(self.filter_length / 2):] + inverse_transform = inverse_transform[:, :, :-int(self.filter_length / + 2):] return inverse_transform @@ -206,16 +206,18 @@ def forward(self, input_data): class TorchSTFT(torch.nn.Module): - def __init__( - self, filter_length=800, hop_length=200, win_length=800, window="hann" - ): + + def __init__(self, + filter_length=800, + hop_length=200, + win_length=800, + window="hann"): super().__init__() self.filter_length = filter_length self.hop_length = hop_length self.win_length = win_length self.window = torch.from_numpy( - get_window(window, win_length, fftbins=True).astype(np.float32) - ) + get_window(window, win_length, fftbins=True).astype(np.float32)) def transform(self, input_data): forward_transform = torch.stft( @@ -250,7 +252,11 @@ def forward(self, input_data): class OnnxSTFT(torch.nn.Module): """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - def __init__(self, filter_length=800, hop_length=200, win_length=800, + + def __init__(self, + filter_length=800, + hop_length=200, + win_length=800, window='hann'): super(OnnxSTFT, self).__init__() self.filter_length = filter_length @@ -262,15 +268,17 @@ def __init__(self, 
filter_length=800, hop_length=200, win_length=800,
         fourier_basis = np.fft.fft(np.eye(self.filter_length))
 
         cutoff = int((self.filter_length / 2 + 1))
-        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
-                                   np.imag(fourier_basis[:cutoff, :])])
+        fourier_basis = np.vstack([
+            np.real(fourier_basis[:cutoff, :]),
+            np.imag(fourier_basis[:cutoff, :])
+        ])
 
         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
         inverse_basis = torch.FloatTensor(
             np.linalg.pinv(scale * fourier_basis).T[:, None, :])
 
         if window is not None:
-            assert(filter_length >= win_length)
+            assert (filter_length >= win_length)
             # get window and zero center pad it to filter_length
             fft_window = get_window(window, win_length, fftbins=True)
             fft_window = pad_center(fft_window, filter_length)
@@ -297,11 +305,11 @@ def transform(self, input_data):
             mode='reflect')
         input_data = input_data.squeeze(1)
 
-        forward_transform = F.conv1d(
-            input_data,
-            Variable(self.forward_basis, requires_grad=False),
-            stride=self.hop_length,
-            padding=0)
+        forward_transform = F.conv1d(input_data,
+                                     Variable(self.forward_basis,
+                                              requires_grad=False),
+                                     stride=self.hop_length,
+                                     padding=0)
 
         cutoff = int((self.filter_length / 2) + 1)
         real_part = forward_transform[:, :cutoff, :]
@@ -315,16 +323,19 @@ def transform(self, input_data):
 
     def inverse(self, magnitude, phase):
         recombine_magnitude_phase = torch.cat(
-            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
-
-        inverse_transform = F.conv_transpose1d(
-            recombine_magnitude_phase,
-            Variable(self.inverse_basis, requires_grad=False),
-            stride=self.hop_length,
-            padding=0)
-
-        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2):]
-        inverse_transform = inverse_transform[:, :, :-int(self.filter_length / 2):]
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
+            dim=1)
+
+        inverse_transform = F.conv_transpose1d(recombine_magnitude_phase,
+                                               Variable(self.inverse_basis,
+                                                        requires_grad=False),
+                                               stride=self.hop_length,
+                                               padding=0)
+
+        inverse_transform = inverse_transform[:, :,
+                                              int(self.filter_length / 2):]
+        inverse_transform = inverse_transform[:, :, :-int(self.filter_length /
+                                                          2):]
 
         return inverse_transform
 
diff --git a/wetts/vits/utils.py b/wetts/vits/utils/task.py
similarity index 84%
rename from wetts/vits/utils.py
rename to wetts/vits/utils/task.py
index 774008f..afae44c 100644
--- a/wetts/vits/utils.py
+++ b/wetts/vits/utils/task.py
@@ -8,7 +8,6 @@
 
 MATPLOTLIB_FLAG = False
 
-
 logging.basicConfig(
     format="%(asctime)s %(levelname)-8s %(message)s",
     level=logging.INFO,
@@ -40,18 +39,16 @@ def load_checkpoint(checkpoint_path, model, optimizer=None):
         model.module.load_state_dict(new_state_dict)
     else:
         model.load_state_dict(new_state_dict)
-    logger.info(
-        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
-    )
+    logger.info("Loaded checkpoint '{}' (iteration {})".format(
+        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
 
 
-def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+def save_checkpoint(model, optimizer, learning_rate, iteration,
+                    checkpoint_path):
     logger.info(
         "Saving model and optimizer state at iteration {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+            iteration, checkpoint_path))
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
@@ -106,7 +103,10 @@ def plot_spectrogram_to_numpy(spectrogram):
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    im = ax.imshow(spectrogram,
+                   aspect="auto",
+                   origin="lower",
+                   interpolation="none")
     plt.colorbar(im, ax=ax)
     plt.xlabel("Frames")
     plt.ylabel("Channels")
@@ -114,7 +114,7 @@ def plot_spectrogram_to_numpy(spectrogram):
 
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
     plt.close()
     return data
 
@@ -132,9 +132,10 @@ def plot_alignment_to_numpy(alignment, info=None):
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(6, 4))
-    im = ax.imshow(
-        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-    )
+    im = ax.imshow(alignment.transpose(),
+                   aspect="auto",
+                   origin="lower",
+                   interpolation="none")
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
     if info is not None:
@@ -145,7 +146,7 @@ def plot_alignment_to_numpy(alignment, info=None):
 
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
     plt.close()
     return data
 
@@ -165,10 +166,20 @@ def get_hparams(init=True):
         default="./configs/base.json",
         help="JSON file for configuration",
     )
-    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
-    parser.add_argument("--train_data", type=str, required=True, help="train data")
+    parser.add_argument("-m",
+                        "--model",
+                        type=str,
+                        required=True,
+                        help="Model name")
+    parser.add_argument("--train_data",
+                        type=str,
+                        required=True,
+                        help="train data")
     parser.add_argument("--val_data", type=str, required=True, help="val data")
-    parser.add_argument("--phone_table", type=str, required=True, help="phone table")
+    parser.add_argument("--phone_table",
+                        type=str,
+                        required=True,
+                        help="phone table")
     parser.add_argument(
         "--speaker_table",
         type=str,
@@ -196,10 +207,14 @@
     config["data"]["training_files"] = args.train_data
     config["data"]["validation_files"] = args.val_data
     config["data"]["phone_table"] = args.phone_table
-    phones = [line for line in open(args.phone_table).readlines() if line.strip()]
+    phones = [
+        line for line in open(args.phone_table).readlines() if line.strip()
+    ]
     config["data"]["num_phones"] = len(phones)
     config["data"]["speaker_table"] = args.speaker_table
-    speakers = [line for line in open(args.speaker_table).readlines() if line.strip()]
+    speakers = [
+        line for line in open(args.speaker_table).readlines() if line.strip()
+    ]
     config["data"]["n_speakers"] = len(speakers)
 
     hparams = HParams(**config)
@@ -230,7 +245,8 @@ def get_logger(model_dir, filename="train.log"):
     logger = logging.getLogger(os.path.basename(model_dir))
     logger.setLevel(logging.INFO)
 
-    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+    formatter = logging.Formatter(
+        "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
     h = logging.FileHandler(os.path.join(model_dir, filename))
@@ -241,6 +257,7 @@
 
 
 class HParams:
+
     def __init__(self, **kwargs):
         for k, v in kwargs.items():
             if type(v) == dict:
diff --git a/wetts/vits/transforms.py b/wetts/vits/utils/transforms.py
similarity index 84%
rename from wetts/vits/transforms.py
rename to wetts/vits/utils/transforms.py
index 164939b..1b99fd7 100644
--- a/wetts/vits/transforms.py
+++ b/wetts/vits/utils/transforms.py
@@ -35,8 +35,7 @@ def piecewise_rational_quadratic_transform(
         min_bin_width=min_bin_width,
         min_bin_height=min_bin_height,
         min_derivative=min_derivative,
-        **spline_kwargs
-    )
+        **spline_kwargs)
     return outputs, logabsdet
 
 
@@ -67,7 +66,9 @@ def unconstrained_rational_quadratic_spline(
         unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
         constant = np.log(np.exp(1 - min_derivative) - 1)
         unnormalized_derivatives[..., 0] = constant
-        unnormalized_derivatives[..., unnormalized_derivatives.size(-1) - 1] = constant
+        unnormalized_derivatives[...,
+                                 unnormalized_derivatives.size(-1) -
+                                 1] = constant
 
         outputs[outside_interval_mask] = inputs[outside_interval_mask]
         logabsdet[outside_interval_mask] = 0
@@ -81,7 +82,8 @@ def unconstrained_rational_quadratic_spline(
         inputs=inputs[inside_interval_mask],
         unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
         unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
-        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+        unnormalized_derivatives=unnormalized_derivatives[
+            inside_interval_mask, :],
         inverse=inverse,
         left=-tail_bound,
         right=tail_bound,
@@ -152,17 +154,17 @@ def rational_quadratic_spline(
     input_delta = delta.gather(-1, bin_idx)[..., 0]
 
     input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
-    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[...,
+                                                                          0]
 
     input_heights = heights.gather(-1, bin_idx)[..., 0]
 
     if inverse:
         a = (inputs - input_cumheights) * (
-            input_derivatives + input_derivatives_plus_one - 2 * input_delta
-        ) + input_heights * (input_delta - input_derivatives)
+            input_derivatives + input_derivatives_plus_one - 2 *
+            input_delta) + input_heights * (input_delta - input_derivatives)
         b = input_heights * input_derivatives - (inputs - input_cumheights) * (
-            input_derivatives + input_derivatives_plus_one - 2 * input_delta
-        )
+            input_derivatives + input_derivatives_plus_one - 2 * input_delta)
         c = -input_delta * (inputs - input_cumheights)
 
         discriminant = b.pow(2) - 4 * a * c
@@ -174,34 +176,31 @@ def rational_quadratic_spline(
         theta_one_minus_theta = root * (1 - root)
         denominator = input_delta + (
             (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
-            * theta_one_minus_theta
-        )
+            * theta_one_minus_theta)
         derivative_numerator = input_delta.pow(2) * (
-            input_derivatives_plus_one * root.pow(2)
-            + 2 * input_delta * theta_one_minus_theta
-            + input_derivatives * (1 - root).pow(2)
-        )
-        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+            input_derivatives_plus_one * root.pow(2) +
+            2 * input_delta * theta_one_minus_theta + input_derivatives *
+            (1 - root).pow(2))
+        logabsdet = torch.log(
+            derivative_numerator) - 2 * torch.log(denominator)
 
         return outputs, -logabsdet
     else:
         theta = (inputs - input_cumwidths) / input_bin_widths
         theta_one_minus_theta = theta * (1 - theta)
 
-        numerator = input_heights * (
-            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
-        )
+        numerator = input_heights * (input_delta * theta.pow(2) +
+                                     input_derivatives * theta_one_minus_theta)
         denominator = input_delta + (
             (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
-            * theta_one_minus_theta
-        )
+            * theta_one_minus_theta)
         outputs = input_cumheights + numerator / denominator
 
         derivative_numerator = input_delta.pow(2) * (
-            input_derivatives_plus_one * theta.pow(2)
-            + 2 * input_delta * theta_one_minus_theta
-            + input_derivatives * (1 - theta).pow(2)
-        )
-        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+            input_derivatives_plus_one * theta.pow(2) +
+            2 * input_delta * theta_one_minus_theta + input_derivatives *
+            (1 - theta).pow(2))
+        logabsdet = torch.log(
+            derivative_numerator) - 2 * torch.log(denominator)
 
         return outputs, logabsdet