Skip to content

Commit

Permalink
Merge pull request #922 from IAHispano/formatter/main
Browse files Browse the repository at this point in the history
chore(format): run black on main
  • Loading branch information
blaisewf authored Dec 22, 2024
2 parents c306f1c + b5eef92 commit 72e7503
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 50 deletions.
20 changes: 11 additions & 9 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def run_extract_script(
sample_rate,
embedder_model,
embedder_model_custom,
include_mutes
include_mutes,
],
),
]
Expand Down Expand Up @@ -525,7 +525,9 @@ def run_train_script(
from rvc.lib.tools.pretrained_selector import pretrained_selector

if custom_pretrained == False:
pg, pd = pretrained_selector(str(rvc_version), str(vocoder), True, int(sample_rate))
pg, pd = pretrained_selector(
str(rvc_version), str(vocoder), True, int(sample_rate)
)
else:
if g_pretrained_path is None or d_pretrained_path is None:
raise ValueError(
Expand Down Expand Up @@ -558,7 +560,7 @@ def run_train_script(
overtraining_threshold,
cleanup,
vocoder,
checkpointing
checkpointing,
],
),
]
Expand Down Expand Up @@ -1858,9 +1860,9 @@ def parse_arguments():
preprocess_parser.add_argument(
"--cut_preprocess",
type=str,
choices=['Skip', 'Simple', 'Automatic'],
choices=["Skip", "Simple", "Automatic"],
help="Cut the dataset into smaller segments for faster preprocessing.",
default='Automatic',
default="Automatic",
required=True,
)
preprocess_parser.add_argument(
Expand Down Expand Up @@ -1902,7 +1904,7 @@ def parse_arguments():
choices=[0.0, 0.1, 0.2, 0.3, 0.4],
default=0.3,
required=False,
)
)

# Parser for 'extract' mode
extract_parser = subparsers.add_parser(
Expand Down Expand Up @@ -1981,8 +1983,8 @@ def parse_arguments():
help="Number of silent files to include.",
choices=range(0, 11),
default=2,
required=True
)
required=True,
)

# Parser for 'train' mode
train_parser = subparsers.add_parser("train", help="Train an RVC model.")
Expand Down Expand Up @@ -2010,7 +2012,7 @@ def parse_arguments():
help="Enables memory-efficient training.",
default=False,
required=False,
)
)
train_parser.add_argument(
"--save_every_epoch",
type=int,
Expand Down
2 changes: 1 addition & 1 deletion rvc/infer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def setup_network(self):
use_f0=self.use_f0,
text_enc_hidden_dim=self.text_enc_hidden_dim,
is_half=False,
vocoder=self.vocoder
vocoder=self.vocoder,
)
del self.net_g.enc_q
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
Expand Down
31 changes: 23 additions & 8 deletions rvc/lib/algorithm/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,41 @@ class MultiPeriodDiscriminator(torch.nn.Module):
Defaults to False.
"""

def __init__(self, version: str, use_spectral_norm: bool = False, checkpointing: bool = False):
def __init__(
self, version: str, use_spectral_norm: bool = False, checkpointing: bool = False
):
super(MultiPeriodDiscriminator, self).__init__()
periods = (
[2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37]
)
self.checkpointing = checkpointing
self.discriminators = torch.nn.ModuleList(
[DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)]
+ [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods]
[
DiscriminatorS(
use_spectral_norm=use_spectral_norm, checkpointing=checkpointing
)
]
+ [
DiscriminatorP(
p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing
)
for p in periods
]
)

def forward(self, y, y_hat):
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
for d in self.discriminators:
if self.training and self.checkpointing:

def forward_discriminator(d, y, y_hat):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
return y_d_r, fmap_r, y_d_g, fmap_g
y_d_r, fmap_r, y_d_g, fmap_g = checkpoint.checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)

y_d_r, fmap_r, y_d_g, fmap_g = checkpoint.checkpoint(
forward_discriminator, d, y, y_hat, use_reentrant=False
)
else:
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
Expand Down Expand Up @@ -82,8 +97,8 @@ def forward(self, x):
fmap = []
for conv in self.convs:
if self.training and self.checkpointing:
x = checkpoint.checkpoint(conv, x, use_reentrant = False)
x = checkpoint.checkpoint(self.lrelu, x, use_reentrant = False)
x = checkpoint.checkpoint(conv, x, use_reentrant=False)
x = checkpoint.checkpoint(self.lrelu, x, use_reentrant=False)
else:
x = self.lrelu(conv(x))
fmap.append(x)
Expand Down Expand Up @@ -153,8 +168,8 @@ def forward(self, x):

for conv in self.convs:
if self.training and self.checkpointing:
x = checkpoint.checkpoint(conv, x, use_reentrant = False)
x = checkpoint.checkpoint(self.lrelu, x, use_reentrant = False)
x = checkpoint.checkpoint(conv, x, use_reentrant=False)
x = checkpoint.checkpoint(self.lrelu, x, use_reentrant=False)
else:
x = self.lrelu(conv(x))
fmap.append(x)
Expand Down
4 changes: 3 additions & 1 deletion rvc/lib/algorithm/generators/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock
from rvc.lib.algorithm.commons import init_weights


class HiFiGANGenerator(torch.nn.Module):
"""
HiFi-GAN Generator module for audio synthesis.
Expand Down Expand Up @@ -107,6 +108,7 @@ def remove_weight_norm(self):
for l in self.resblocks:
l.remove_weight_norm()


class SineGenerator(torch.nn.Module):
"""
Sine wave generator with optional harmonic overtones and noise.
Expand Down Expand Up @@ -220,4 +222,4 @@ def forward(self, f0: torch.Tensor, upsampling_factor: int):
# Combine sine waves and noise
sine_waveforms = sine_waves * voiced_mask + noise

return sine_waveforms, voiced_mask, noise
return sine_waveforms, voiced_mask, noise
16 changes: 10 additions & 6 deletions rvc/lib/algorithm/generators/hifigan_nsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock
from rvc.lib.algorithm.commons import init_weights


class SourceModuleHnNSF(torch.nn.Module):
"""
Source Module for generating harmonic and noise components for audio synthesis.
Expand Down Expand Up @@ -49,6 +50,7 @@ def forward(self, x: torch.Tensor, upsample_factor: int = 1):
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
return sine_merge, None, None


class HiFiGANNSFGenerator(torch.nn.Module):
"""
Generator module based on the Neural Source Filter (NSF) architecture.
Expand Down Expand Up @@ -87,9 +89,7 @@ def __init__(
self.num_upsamples = len(upsample_rates)
self.checkpointing = checkpointing
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
self.m_source = SourceModuleHnNSF(
sample_rate=sr, harmonic_num=0
)
self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)

self.conv_pre = torch.nn.Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
Expand Down Expand Up @@ -169,7 +169,9 @@ def __init__(
self.upp = math.prod(upsample_rates)
self.lrelu_slope = LRELU_SLOPE

def forward(self, x: torch.Tensor, f0: torch.Tensor, g: Optional[torch.Tensor] = None):
def forward(
self, x: torch.Tensor, f0: torch.Tensor, g: Optional[torch.Tensor] = None
):
har_source, _, _ = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)

Expand Down Expand Up @@ -198,7 +200,9 @@ def resblock_forward(x, blocks):

# Checkpoint or regular computation for ResBlocks
if self.training and self.checkpointing:
x = checkpoint.checkpoint(resblock_forward, x, blocks, use_reentrant=False)
x = checkpoint.checkpoint(
resblock_forward, x, blocks, use_reentrant=False
)
else:
x = resblock_forward(x, blocks)

Expand Down Expand Up @@ -228,4 +232,4 @@ def __prepare_scriptable__(self):
and hook.__class__.__name__ == "WeightNorm"
):
remove_weight_norm(l)
return self
return self
12 changes: 11 additions & 1 deletion rvc/lib/algorithm/generators/refinegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from rvc.lib.algorithm.commons import get_padding


class ResBlock(torch.nn.Module):
"""
Residual block with multiple dilated convolutions.
Expand Down Expand Up @@ -94,6 +95,7 @@ def init_weights(self, m):
m.weight.data.normal_(0, 0.01)
m.bias.data.fill_(0.0)


class AdaIN(torch.nn.Module):
"""
Adaptive Instance Normalization layer.
Expand All @@ -104,6 +106,7 @@ class AdaIN(torch.nn.Module):
channels (int): Number of input channels.
leaky_relu_slope (float, optional): Slope for the Leaky ReLU activation applied after scaling. Defaults to 0.2.
"""

def __init__(
self,
*,
Expand All @@ -120,6 +123,7 @@ def forward(self, x: torch.Tensor):

return self.activation(x + gaussian)


class ParallelResBlock(torch.nn.Module):
"""
Parallel residual block that applies multiple residual blocks with different kernel sizes in parallel.
Expand All @@ -131,6 +135,7 @@ class ParallelResBlock(torch.nn.Module):
dilation (tuple[int], optional): Tuple of dilation rates for the convolutional layers within the residual blocks. Defaults to (1, 3, 5).
leaky_relu_slope (float, optional): Slope for the Leaky ReLU activation. Defaults to 0.2.
"""

def __init__(
self,
*,
Expand Down Expand Up @@ -181,6 +186,7 @@ def remove_parametrizations(self):
for block in self.blocks:
block[1].remove_parametrizations()


class SineGenerator(torch.nn.Module):
"""
Definition of sine generator
Expand Down Expand Up @@ -261,6 +267,7 @@ def forward(self, f0):
sine_waves = sine_waves * uv + noise * (1 - uv)
return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
"""
Source Module for generating harmonic and noise signals.
Expand All @@ -274,6 +281,7 @@ class SourceModuleHnNSF(torch.nn.Module):
add_noise_std (float, optional): Standard deviation of the additive noise. Defaults to 0.003.
voiced_threshold (int, optional): F0 threshold for voiced/unvoiced classification. Defaults to 0.
"""

def __init__(
self,
sampling_rate,
Expand Down Expand Up @@ -303,6 +311,7 @@ def forward(self, x: torch.Tensor):

return sine_merge, None, None


class RefineGANGenerator(torch.nn.Module):
"""
RefineGAN generator for audio synthesis.
Expand All @@ -321,6 +330,7 @@ class RefineGANGenerator(torch.nn.Module):
gin_channels (int, optional): Number of channels for the global conditioning input. Defaults to 256.
checkpointing (bool, optional): Whether to use checkpointing for memory efficiency. Defaults to False.
"""

def __init__(
self,
*,
Expand Down Expand Up @@ -477,4 +487,4 @@ def remove_parametrizations(self):
block[1].remove_parametrizations()

for block in self.upsample_conv_blocks:
block.remove_parametrizations()
block.remove_parametrizations()
2 changes: 1 addition & 1 deletion rvc/lib/algorithm/synthesizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
checkpointing=checkpointing
checkpointing=checkpointing,
)
self.enc_q = PosteriorEncoder(
spec_channels,
Expand Down
3 changes: 2 additions & 1 deletion rvc/lib/tools/pretrained_selector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os


def pretrained_selector(version, vocoder, pitch_guidance, sample_rate):
base_path = os.path.join("rvc", "models", "pretraineds", f"pretrained_{version}")
f0 = "f0" if pitch_guidance else ""
Expand All @@ -19,4 +20,4 @@ def pretrained_selector(version, vocoder, pitch_guidance, sample_rate):
if os.path.exists(path_g) and os.path.exists(path_d):
return path_g, path_d
else:
return "", ""
return "", ""
Loading

0 comments on commit 72e7503

Please sign in to comment.