From 4f409cc393ee550d5ea5b7923e79c7fa82cfcee8 Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 19:12:37 +0900 Subject: [PATCH 1/3] added download function --- parallel_wavegan/__init__.py | 2 +- parallel_wavegan/utils/utils.py | 50 +++++++++++++++++++++++++++++++++ setup.py | 3 +- 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/parallel_wavegan/__init__.py b/parallel_wavegan/__init__.py index cfb0319b..9d7b8d24 100644 --- a/parallel_wavegan/__init__.py +++ b/parallel_wavegan/__init__.py @@ -1,3 +1,3 @@ # -*- coding: utf-8 -*- -__version__ = "0.4.4" +__version__ = "0.4.5" diff --git a/parallel_wavegan/utils/utils.py b/parallel_wavegan/utils/utils.py index 9892b9de..98ab6367 100644 --- a/parallel_wavegan/utils/utils.py +++ b/parallel_wavegan/utils/utils.py @@ -9,6 +9,7 @@ import logging import os import sys +import tarfile from distutils.version import LooseVersion @@ -22,6 +23,28 @@ from parallel_wavegan.layers import PQMF +PRETRAINED_MODEL_LIST = { + "ljspeech_parallel_wavegan.v1": "1PdZv37JhAQH6AwNh31QlqruqrvjTBq7U", + "ljspeech_parallel_wavegan.v1.long": "1A9TsrD9fHxFviJVFjCk5W6lkzWXwhftv", + "ljspeech_parallel_wavegan.v1.no_limit": "1CdWKSiKoFNPZyF1lo7Dsj6cPKmfLJe72", + "ljspeech_parallel_wavegan.v3": "1-oZpwpWZMMolDYsCqeL12dFkXSBD9VBq", + "ljspeech_full_band_melgan.v2": "1Kb7q5zBeQ30Wsnma0X23G08zvgDG5oen", + "ljspeech_multi_band_melgan.v2": "1b70pJefKI8DhGYz4SxbEHpxm92tj1_qC", + "jsut_parallel_wavegan.v1": "1qok91A6wuubuz4be-P9R2zKhNmQXG0VQ", + "jsut_multi_band_melgan.v2": "1chTt-76q2p69WPpZ1t1tt8szcM96IKad", + "csmsc_parallel_wavegan.v1": "1QTOAokhD5dtRnqlMPTXTW91-CG7jf74e", + "csmsc_multi_band_melgan.v2": "1G6trTmt0Szq-jWv2QDhqglMdWqQxiXQT", + "arctic_slt_parallel_wavegan.v1": "1_MXePg40-7DTjD0CDVzyduwQuW_O9aA1", + "jnas_parallel_wavegan.v1": "1D2TgvO206ixdLI90IqG787V6ySoXLsV_", + "vctk_parallel_wavegan.v1": "1bqEFLgAroDcgUy5ZFP4g2O2MwcwWLEca", + "vctk_parallel_wavegan.v1.long": "1tO4-mFrZ3aVYotgg7M519oobYkD4O_0-", + "vctk_multi_band_melgan.v2": "10PRQpHMFPE7RjF-MHYqvupK9S0xwBlJ_", + "libritts_parallel_wavegan.v1": "1zHQl8kUYEuZ_i1qEFU6g2MEu99k3sHmR", + "libritts_parallel_wavegan.v1.long": "1b9zyBYGCCaJu0TIus5GXoMF8M3YEbqOw", + "libritts_multi_band_melgan.v2": "1kIDSBjrQvAsRewHPiFwBZ3FDelTWMp64", +} + + def find_files(root_dir, query="*.wav", include_root_dir=True): """Find files recursively. @@ -290,3 +313,30 @@ def load_model(checkpoint, config=None): ) return model + + +def download_pretrained_model(tag, download_dir=None): + """Download pretrained model form google drive. + + Args: + tag (str): Pretrained model tag. + download_dir (str): Directory to save downloaded files. + + Returns: + str: Path of downloaded model checkpoint. + + """ + assert tag in PRETRAINED_MODEL_LIST, f"{tag} does not exists." + id_ = PRETRAINED_MODEL_LIST[tag] + if download_dir is None: + download_dir = os.path.expanduser("~/.cache/parallel_wavegan") + output_path = f"{download_dir}/{tag}.tar.gz" + os.makedirs(f"{download_dir}", exist_ok=True) + if not os.path.exists(output_path): + import gdown + gdown.download(f"https://drive.google.com/uc?id={id_}", output_path) + with tarfile.open(output_path, 'r:*') as tar: + tar.extractall(f"{download_dir}/{tag}") + checkpoint_path = find_files(f"{download_dir}/{tag}", "*.pkl") + + return checkpoint_path[0] diff --git a/setup.py b/setup.py index 2c32b6b8..06931d2a 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ "yq>=2.10.0", # Fix No module named "numba.decorators" "numba<=0.48", + "gdown", ], "setup": [ "numpy", @@ -65,7 +66,7 @@ dirname = os.path.dirname(__file__) setup(name="parallel_wavegan", - version="0.4.4", + version="0.4.5", url="http://github.com/kan-bayashi/ParallelWaveGAN", author="Tomoki Hayashi", author_email="hayashi.tomoki@g.sp.m.is.nagoya-u.ac.jp", From 93dfc812388cc692da66a3280531ecd28da3c02b Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 19:23:59 +0900 Subject: [PATCH 2/3] updated --- parallel_wavegan/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parallel_wavegan/utils/utils.py b/parallel_wavegan/utils/utils.py index 98ab6367..bc21af03 100644 --- a/parallel_wavegan/utils/utils.py +++ b/parallel_wavegan/utils/utils.py @@ -334,7 +334,7 @@ def download_pretrained_model(tag, download_dir=None): os.makedirs(f"{download_dir}", exist_ok=True) if not os.path.exists(output_path): import gdown - gdown.download(f"https://drive.google.com/uc?id={id_}", output_path) + gdown.download(f"https://drive.google.com/uc?id={id_}", output_path, quiet=False) with tarfile.open(output_path, 'r:*') as tar: tar.extractall(f"{download_dir}/{tag}") checkpoint_path = find_files(f"{download_dir}/{tag}", "*.pkl") From 1ca8781bb846448df48b41feac07d85f4bd88edf Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 18 Aug 2020 19:55:37 +0900 Subject: [PATCH 3/3] fixed --- parallel_wavegan/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parallel_wavegan/utils/utils.py b/parallel_wavegan/utils/utils.py index bc21af03..65314679 100644 --- a/parallel_wavegan/utils/utils.py +++ b/parallel_wavegan/utils/utils.py @@ -337,6 +337,6 @@ def download_pretrained_model(tag, download_dir=None): gdown.download(f"https://drive.google.com/uc?id={id_}", output_path, quiet=False) with tarfile.open(output_path, 'r:*') as tar: tar.extractall(f"{download_dir}/{tag}") - checkpoint_path = find_files(f"{download_dir}/{tag}", "*.pkl") + checkpoint_path = find_files(f"{download_dir}/{tag}", "checkpoint*.pkl") return checkpoint_path[0]