
Commit

remove unused commons functions
blaisewf authored Dec 29, 2024
1 parent 1f47fbb commit 4f0dfb8
Showing 1 changed file with 1 addition and 91 deletions.
92 changes: 1 addition & 91 deletions rvc/lib/algorithm/commons.py
@@ -1,6 +1,5 @@
import math
import torch
-from typing import List, Optional
+from typing import Optional


def init_weights(m, mean=0.0, std=0.01):
@@ -40,23 +39,6 @@ def convert_pad_shape(pad_shape):
return pad_shape


-def kl_divergence(m_p, logs_p, m_q, logs_q):
-    """
-    Calculate the KL divergence between two distributions.
-
-    Args:
-        m_p: The mean of the first distribution.
-        logs_p: The log of the standard deviation of the first distribution.
-        m_q: The mean of the second distribution.
-        logs_q: The log of the standard deviation of the second distribution.
-    """
-    kl = (logs_q - logs_p) - 0.5
-    kl += (
-        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-    )
-    return kl

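# Illustrative sketch (not part of the diff): the removed kl_divergence computed
# the closed-form KL divergence between diagonal Gaussians parameterised by
# means and log standard deviations. Example values assumed below.
import torch

m_p, logs_p = torch.zeros(3), torch.zeros(3)  # P = N(0, 1)
m_q, logs_q = torch.ones(3), torch.zeros(3)   # Q = N(1, 1)
kl = (logs_q - logs_p) - 0.5 + 0.5 * (
    torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2
) * torch.exp(-2.0 * logs_q)
print(kl)  # tensor([0.5000, 0.5000, 0.5000])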

def slice_segments(
x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2
):
@@ -103,42 +85,6 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
return ret, ids_str

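# Usage sketch for rand_slice_segments, which this commit keeps. Shapes and
# semantics are assumed from the signature above: a random segment_size-frame
# window is sliced per batch item and the start indices are returned as well.
import torch
from rvc.lib.algorithm.commons import rand_slice_segments  # assumes the repo is on the path

x = torch.randn(2, 192, 100)         # (batch, channels, frames), assumed layout
x_lengths = torch.tensor([100, 80])
ret, ids_str = rand_slice_segments(x, x_lengths, segment_size=32)
print(ret.shape)                     # expected: torch.Size([2, 192, 32])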

-def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-    """
-    Generate a 1D timing signal.
-
-    Args:
-        length: The length of the signal.
-        channels: The number of channels of the signal.
-        min_timescale: The minimum timescale.
-        max_timescale: The maximum timescale.
-    """
-    position = torch.arange(length, dtype=torch.float)
-    num_timescales = channels // 2
-    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-        num_timescales - 1
-    )
-    inv_timescales = min_timescale * torch.exp(
-        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-    )
-    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-    signal = torch.nn.functional.pad(signal, [0, 0, 0, channels % 2])
-    signal = signal.view(1, channels, length)
-    return signal

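# Shape sketch for the removed get_timing_signal_1d with example sizes assumed
# (length=100, channels=192): sin and cos timescale banks are stacked along
# the channel axis, giving a Transformer-style positional signal of shape (1, C, T).
import math
import torch

length, channels = 100, 192
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
inv_timescales = torch.exp(
    torch.arange(num_timescales, dtype=torch.float)
    * -(math.log(1.0e4) / (num_timescales - 1))
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
print(signal.view(1, channels, length).shape)  # torch.Size([1, 192, 100])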

-def subsequent_mask(length):
-    """
-    Generate a subsequent mask.
-
-    Args:
-        length: The length of the sequence.
-    """
-    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-    return mask

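# Worked example for the removed subsequent_mask: a lower-triangular causal
# mask of shape (1, 1, length, length); entry [i, j] is 1 when position i may
# attend to position j <= i. Length 4 assumed for illustration.
import torch

mask = torch.tril(torch.ones(4, 4)).unsqueeze(0).unsqueeze(0)
print(mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])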

@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
"""
@@ -157,16 +103,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
return acts

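# Gated-activation sketch for fused_add_tanh_sigmoid_multiply (kept by this
# commit; its body is collapsed above). The WaveNet-style pattern it fuses is
# assumed here: tanh over the first n channels times sigmoid over the rest.
import torch

n = 192
in_act = torch.randn(2, 2 * n, 50) + torch.randn(2, 2 * n, 50)  # input_a + input_b
acts = torch.tanh(in_act[:, :n, :]) * torch.sigmoid(in_act[:, n:, :])
print(acts.shape)  # torch.Size([2, 192, 50])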

-def convert_pad_shape(pad_shape: List[List[int]]):
-    """
-    Convert the pad shape to a list of integers.
-
-    Args:
-        pad_shape: The pad shape.
-    """
-    return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()

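# Worked example for the removed (typed) convert_pad_shape: a per-dimension
# [[left, right], ...] pad spec is flipped and flattened into the
# last-dimension-first list that torch.nn.functional.pad expects.
import torch

pad_shape = [[0, 0], [0, 0], [2, 3]]  # pad only the last dimension
flat = torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
print(flat)  # [2, 3, 0, 0, 0, 0]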

def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
"""
Generate a sequence mask.
@@ -179,29 +115,3 @@ def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)

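# Worked example for sequence_mask (kept by this commit): a boolean mask that
# is True for valid positions and False for padding, given per-item lengths.
import torch

length = torch.tensor([3, 1])
x = torch.arange(int(length.max()), dtype=length.dtype)
print(x.unsqueeze(0) < length.unsqueeze(1))
# tensor([[ True,  True,  True],
#         [ True, False, False]])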

-def clip_grad_value(parameters, clip_value, norm_type=2):
-    """
-    Clip the gradients of a list of parameters.
-
-    Args:
-        parameters: The list of parameters to clip.
-        clip_value: The maximum value of the gradients.
-        norm_type: The type of norm to use for clipping.
-    """
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    if clip_value is not None:
-        clip_value = float(clip_value)
-
-    total_norm = 0
-    for p in parameters:
-        param_norm = p.grad.data.norm(norm_type)
-        total_norm += param_norm.item() ** norm_type
-        if clip_value is not None:
-            p.grad.data.clamp_(min=-clip_value, max=clip_value)
-    total_norm = total_norm ** (1.0 / norm_type)
-    return total_norm

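# The removed clip_grad_value clamped each gradient to [-clip_value, clip_value]
# and returned the accumulated p-norm. A rough equivalent with PyTorch
# built-ins is sketched below (not necessarily what the project now uses).
import torch

p = torch.nn.Parameter(torch.randn(10))
p.sum().backward()
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=float("inf"))  # just measures the norm
torch.nn.utils.clip_grad_value_([p], clip_value=1.0)                      # clamp to [-1, 1]
print(float(total_norm))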