fix: only enable load with weights_only in pytorch>=2.4
Allows moving the minimum Pytorch version back to 2.1
eginhard committed Oct 25, 2024
1 parent b66c782 commit 965a121
Showing 17 changed files with 89 additions and 47 deletions.
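In short, every `torch.load(..., weights_only=True)` call in the library becomes conditional on the installed PyTorch version. A minimal sketch of the pattern, assuming the helper added in TTS/utils/generic_utils.py below (the checkpoint path here is hypothetical):

```python
import torch

from TTS.utils.generic_utils import is_pytorch_at_least_2_4

# Request safe (weights-only) loading only where PyTorch supports the
# add_safe_globals() allowlist these checkpoints need, i.e. >= 2.4.
# On older versions (down to 2.1) this falls back to weights_only=False.
checkpoint = torch.load(
    "some_checkpoint.pth",  # hypothetical path, for illustration only
    map_location="cpu",
    weights_only=is_pytorch_at_least_2_4(),
)
```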
46 changes: 25 additions & 21 deletions TTS/__init__.py
@@ -1,29 +1,33 @@
-import _codecs
import importlib.metadata
-from collections import defaultdict
-
-import numpy as np
-import torch
-
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
-from TTS.utils.radam import RAdam
+
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

__version__ = importlib.metadata.version("coqui-tts")


-torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
-
-# Bark
-torch.serialization.add_safe_globals(
-    [
-        np.core.multiarray.scalar,
-        np.dtype,
-        np.dtypes.Float64DType,
-        _codecs.encode,  # TODO: safe by default from Pytorch 2.5
-    ]
-)
-
-# XTTS
-torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
+if is_pytorch_at_least_2_4():
+    import _codecs
+    from collections import defaultdict
+
+    import numpy as np
+    import torch
+
+    from TTS.config.shared_configs import BaseDatasetConfig
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+    from TTS.utils.radam import RAdam
+
+    torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
+
+    # Bark
+    torch.serialization.add_safe_globals(
+        [
+            np.core.multiarray.scalar,
+            np.dtype,
+            np.dtypes.Float64DType,
+            _codecs.encode,  # TODO: safe by default from Pytorch 2.5
+        ]
+    )
+
+    # XTTS
+    torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
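For reference, `torch.serialization.add_safe_globals()` only exists from PyTorch 2.4, which is why the registrations above are now wrapped in the version check. A minimal sketch of the guarded registration (RAdam stands in for any custom class pickled inside a checkpoint):

```python
import torch

from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.utils.radam import RAdam

if is_pytorch_at_least_2_4():
    # add_safe_globals() allowlists these types for unpickling when
    # torch.load() is called with weights_only=True; it is unavailable
    # before PyTorch 2.4, so the registration must be skipped there.
    torch.serialization.add_safe_globals([dict, RAdam])
```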
3 changes: 2 additions & 1 deletion TTS/tts/layers/bark/load_model.py
@@ -10,6 +10,7 @@

from TTS.tts.layers.bark.model import GPT, GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

if (
torch.cuda.is_available()
@@ -118,7 +119,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
logger.info(f"{model_type} model not found, downloading...")
_download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

-checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
+checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4())
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
3 changes: 2 additions & 1 deletion TTS/tts/layers/tortoise/arch_utils.py
@@ -9,6 +9,7 @@
from transformers import LogitsWarper

from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def zero_module(module):
@@ -332,7 +333,7 @@ def __init__(
self.mel_norm_file = mel_norm_file
if self.mel_norm_file is not None:
with fsspec.open(self.mel_norm_file) as f:
-self.mel_norms = torch.load(f, weights_only=True)
+self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4())
else:
self.mel_norms = None

3 changes: 2 additions & 1 deletion TTS/tts/layers/tortoise/audio_utils.py
@@ -10,6 +10,7 @@
from scipy.io.wavfile import read

from TTS.utils.audio.torch_transforms import TorchSTFT
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -124,7 +125,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
voices = get_voices(extra_voice_dirs)
paths = voices[voice]
if len(paths) == 1 and paths[0].endswith(".pth"):
-return None, torch.load(paths[0], weights_only=True)
+return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4())
else:
conds = []
for cond_path in paths:
4 changes: 3 additions & 1 deletion TTS/tts/layers/xtts/dvae.py
@@ -9,6 +9,8 @@
import torchaudio
from einops import rearrange

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)


@@ -46,7 +48,7 @@ def dvae_wav_to_mel(
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
-mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel

3 changes: 2 additions & 1 deletion TTS/tts/layers/xtts/hifigan_decoder.py
@@ -9,6 +9,7 @@
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vocoder.models.hifigan_generator import get_padding

logger = logging.getLogger(__name__)
@@ -328,7 +329,7 @@ def remove_weight_norm(self):
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
-state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
+state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
self.load_state_dict(state["model"])
if eval:
self.eval()
9 changes: 7 additions & 2 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -19,6 +19,7 @@
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -91,7 +92,9 @@ def __init__(self, config: Coqpit):

# load GPT if available
if self.args.gpt_checkpoint:
-gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+gpt_checkpoint = torch.load(
+    self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+)
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +187,9 @@ def __init__(self, config: Coqpit):

self.dvae.eval()
if self.args.dvae_checkpoint:
-dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+dvae_checkpoint = torch.load(
+    self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+)
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else:
4 changes: 3 additions & 1 deletion TTS/tts/layers/xtts/xtts_manager.py
@@ -1,9 +1,11 @@
import torch

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


class SpeakerManager:
def __init__(self, speaker_file_path=None):
-self.speakers = torch.load(speaker_file_path, weights_only=True)
+self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4())

@property
def name_to_id(self):
8 changes: 5 additions & 3 deletions TTS/tts/models/neuralhmm_tts.py
@@ -18,7 +18,7 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -107,7 +107,7 @@ def update_mean_std(self, statistics_dict: Dict):

def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
-statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
self.update_mean_std(statistics_dict)

mels = self.normalize(mels)
@@ -292,7 +292,9 @@ def on_init_start(self, trainer):
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
-statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+statistics = torch.load(
+    trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
8 changes: 5 additions & 3 deletions TTS/tts/models/overflow.py
@@ -19,7 +19,7 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -120,7 +120,7 @@ def update_mean_std(self, statistics_dict: Dict):

def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
-statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
self.update_mean_std(statistics_dict)

mels = self.normalize(mels)
@@ -308,7 +308,9 @@ def on_init_start(self, trainer):
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
-statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+statistics = torch.load(
+    trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
19 changes: 12 additions & 7 deletions TTS/tts/models/tortoise.py
@@ -23,6 +23,7 @@
from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -171,7 +172,11 @@ def classify_audio_clip(clip, model_dir):
distribute_zero_label=False,
)
classifier.load_state_dict(
-torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+torch.load(
+    os.path.join(model_dir, "classifier.pth"),
+    map_location=torch.device("cpu"),
+    weights_only=is_pytorch_at_least_2_4(),
+)
)
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
@@ -490,15 +495,15 @@ def get_random_conditioning_latents(self):
torch.load(
os.path.join(self.models_dir, "rlg_auto.pth"),
map_location=torch.device("cpu"),
-weights_only=True,
+weights_only=is_pytorch_at_least_2_4(),
)
)
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(
torch.load(
os.path.join(self.models_dir, "rlg_diffuser.pth"),
map_location=torch.device("cpu"),
-weights_only=True,
+weights_only=is_pytorch_at_least_2_4(),
)
)
with torch.no_grad():
@@ -885,25 +890,25 @@ def load_checkpoint(

if os.path.exists(ar_path):
# remove keys from the checkpoint that are not in the model
-checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
+checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())

# strict set False
# due to removed `bias` and `masked_bias` changes in Transformers
self.autoregressive.load_state_dict(checkpoint, strict=False)

if os.path.exists(diff_path):
-self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
+self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)

if os.path.exists(clvp_path):
-self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
+self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)

if os.path.exists(vocoder_checkpoint_path):
self.vocoder.load_state_dict(
config.model_args.vocoder.value.optionally_index(
torch.load(
vocoder_checkpoint_path,
map_location=torch.device("cpu"),
-weights_only=True,
+weights_only=is_pytorch_at_least_2_4(),
)
)
)
3 changes: 2 additions & 1 deletion TTS/tts/models/xtts.py
@@ -16,6 +16,7 @@
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -65,7 +66,7 @@ def wav_to_mel_cloning(
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
-mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel

4 changes: 3 additions & 1 deletion TTS/tts/utils/fairseq.py
@@ -1,8 +1,10 @@
import torch

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def rehash_fairseq_vits_checkpoint(checkpoint_file):
-chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
+chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"]
new_chk = {}
for k, v in chk.items():
if "enc_p." in k:
3 changes: 2 additions & 1 deletion TTS/tts/utils/managers.py
@@ -9,6 +9,7 @@
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def load_file(path: str):
@@ -17,7 +18,7 @@ def load_file(path: str):
return json.load(f)
elif path.endswith(".pth"):
with fsspec.open(path, "rb") as f:
-return torch.load(f, map_location="cpu", weights_only=True)
+return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
else:
raise ValueError("Unsupported file type")

8 changes: 8 additions & 0 deletions TTS/utils/generic_utils.py
@@ -6,6 +6,9 @@
from pathlib import Path
from typing import Dict, Optional

+import torch
+from packaging.version import Version

logger = logging.getLogger(__name__)


@@ -131,3 +134,8 @@ def setup_logger(
sh = logging.StreamHandler()
sh.setFormatter(formatter)
lg.addHandler(sh)


+def is_pytorch_at_least_2_4() -> bool:
+    """Check if the installed Pytorch version is 2.4 or higher."""
+    return Version(torch.__version__) >= Version("2.4")
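A quick illustration of how the new helper's version comparison behaves, using packaging.version directly (the version strings below are examples, not pinned to any environment):

```python
from packaging.version import Version

# Same comparison as is_pytorch_at_least_2_4(), shown with sample values.
print(Version("2.4.1") >= Version("2.4"))        # True  -> weights_only enabled
print(Version("2.1.2+cu121") >= Version("2.4"))  # False -> weights_only disabled
```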
3 changes: 2 additions & 1 deletion TTS/vc/modules/freevc/wavlm/__init__.py
@@ -5,6 +5,7 @@
import torch
from trainer.io import get_user_data_dir

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig

logger = logging.getLogger(__name__)
@@ -26,7 +27,7 @@ def get_wavlm(device="cpu"):
logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path)

-checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
+checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4())
cfg = WavLMConfig(checkpoint["cfg"])
wavlm = WavLM(cfg).to(device)
wavlm.load_state_dict(checkpoint["model"])