diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
index dc650d11..1c20a13c 100644
--- a/fish_speech/utils/file.py
+++ b/fish_speech/utils/file.py
@@ -2,6 +2,8 @@
 from pathlib import Path
 from typing import Union
 
+from loguru import logger
+
 AUDIO_EXTENSIONS = {
     ".mp3",
     ".wav",
@@ -72,3 +74,54 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
         return None
 
     return ckpts[-1]
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, list[str], str]]:
+    """
+    Load a Bert-VITS2 style filelist.
+
+    Returns a list of (audio_path, speaker, languages, text) tuples.
+    """
+
+    files = set()
+    results = []
+    count_duplicated, count_not_found = 0, 0
+
+    LANGUAGE_TO_LANGUAGES = {
+        "zh": ["zh", "en"],
+        "jp": ["jp", "en"],
+        "en": ["en"],
+    }
+
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f.readlines():
+            filename, speaker, language, text = line.strip().split("|")
+            file = Path(filename)
+            language = language.strip().lower()
+
+            if language == "ja":
+                language = "jp"
+
+            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+            languages = LANGUAGE_TO_LANGUAGES[language]
+
+            if file in files:
+                logger.warning(f"Duplicated file: {file}")
+                count_duplicated += 1
+                continue
+
+            if not file.exists():
+                logger.warning(f"File not found: {file}")
+                count_not_found += 1
+                continue
+
+            # Record the file so the duplicate check above can actually fire.
+            files.add(file)
+            results.append((file, speaker, languages, text))
+
+    if count_duplicated > 0:
+        logger.warning(f"Total duplicated files: {count_duplicated}")
+
+    if count_not_found > 0:
+        logger.warning(f"Total files not found: {count_not_found}")
+
+    return results
diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py
index 97559486..1ba0c4fb 100644
--- a/tools/llama/build_dataset.py
+++ b/tools/llama/build_dataset.py
@@ -13,10 +13,10 @@
 from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
 from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
 from fish_speech.text import g2p
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from 
fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
-def task_generator(config, filelist):
+def task_generator_yaml(config):
     with open(config, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
@@ -30,26 +30,7 @@ def task_generator(config, filelist):
     )
 
     # Load the files
-    if filelist:
-        with open(filelist, "r", encoding="utf-8") as f:
-            # files = [Path(line..strip().split("|")[0]) for line in f]
-            files = set()
-            countSame = 0
-            countNotFound = 0
-            for line in f.readlines():
-                file = Path(line.strip().split("|")[0])
-                if file in files:
-                    print(f"重复音频文本:{line}")
-                    countSame += 1
-                    continue
-                if not os.path.isfile(file):
-                    # 过滤数据集错误:不存在对应音频
-                    print(f"没有找到对应的音频:{file}")
-                    countNotFound += 1
-                    continue
-                files.add(file)
-    else:
-        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+    files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
     grouped_files = defaultdict(list)
     for file in files:
@@ -63,22 +44,45 @@
 
     logger.info(f"Found {len(grouped_files)} groups in {root}")
     for name, subset in grouped_files.items():
-        yield name, subset, source, languages, extension
+        yield name, subset, source, languages, extension, None
+
+
+def task_generator_filelist(filelist):
+    grouped_files = defaultdict(list)
+    for filename, speaker, languages, text in load_filelist(filelist):
+        if speaker in grouped_files:
+            assert (
+                languages == grouped_files[speaker][0][2]
+            ), f"Speaker {speaker} has different languages"
+
+        grouped_files[speaker].append((Path(filename), text, languages))
+
+    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+    # One task per file: run_task expects a single transcript per task,
+    # so each (file, text) pair becomes its own single-file subset.
+    for speaker, values in grouped_files.items():
+        for filename, txt, languages in values:
+            yield speaker, [filename], "filelist", languages, None, txt
 
 
 def run_task(task):
-    name, subset, source, languages, extension = task
+    name, subset, source, languages, extension, text = task
 
     # Parse the files
     sentences = []
     for file in subset:
         np_file = file.with_suffix(".npy")
-        
txt_file = file.with_suffix(extension)
-        if np_file.exists() is False or txt_file.exists() is False:
-            logger.warning(f"Can't find {np_file} or {txt_file}")
+        if np_file.exists() is False:
+            logger.warning(f"Can't find {np_file}")
             continue
-        with open(txt_file, "r") as f:
-            text = f.read().strip()
+
+        # Key on `extension` (never mutated) instead of `text`: `text` is
+        # overwritten below, so `text is None` would be False from the second
+        # file onward and every later file would reuse file #1's transcript.
+        if extension is not None:
+            txt_file = file.with_suffix(extension)
+
+            if txt_file.exists() is False:
+                logger.warning(f"Can't find {txt_file}")
+                continue
+
+            with open(txt_file, "r") as f:
+                text = f.read().strip()
 
         # Simple cleaning: replace { xxx } and < xxx > with space
         text = re.sub(r"\{.*?\}", " ", text)
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
index ac2f5235..cb800da6 100644
--- a/tools/vqgan/create_train_split.py
+++ b/tools/vqgan/create_train_split.py
@@ -6,7 +6,7 @@
 
 import click
 from tqdm import tqdm
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
@@ -16,28 +16,11 @@
 @click.option("--filelist", default=None, type=Path)
 def main(root, val_ratio, val_count, filelist):
     if filelist:
-        with open(filelist, "r", encoding="utf-8") as f:
-            # files = [Path(line..strip().split("|")[0]) for line in f]
-            files = set()
-            countSame = 0
-            countNotFound = 0
-            for line in f.readlines():
-                file = Path(line.strip().split("|")[0])
-                if file in files:
-                    print(f"重复音频文本:{line}")
-                    countSame += 1
-                    continue
-                if not os.path.isfile(file):
-                    # 过滤数据集错误:不存在对应音频
-                    print(f"没有找到对应的音频:{file}")
-                    countNotFound += 1
-                    continue
-                files.add(file)
-        files = list(files)
+        files = [i[0] for i in load_filelist(filelist)]
     else:
         files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
-    print(f"Found {len(files)} files")
+    print(f"Found {len(files)} files")
 
     files = [str(file.relative_to(root)) for file in tqdm(files)]
 
     Random(42).shuffle(files)
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
index 0f0242c0..0d758bc1 100644
--- a/tools/vqgan/extract_vq.py
+++ b/tools/vqgan/extract_vq.py
@@ -19,7 +19,8 @@
 from omegaconf import OmegaConf
 
 from fish_speech.models.vqgan.utils import sequence_mask
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from random import Random
 
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
@@ -188,27 +188,10 @@ def main(
     # This is a worker
     logger.info(f"Starting worker")
     if filelist:
-        with open(filelist, "r", encoding="utf-8") as f:
-            # files = [Path(line..strip().split("|")[0]) for line in f]
-            files = set()
-            countSame = 0
-            countNotFound = 0
-            for line in f.readlines():
-                file = Path(line.strip().split("|")[0])
-                if file in files:
-                    print(f"重复音频文本:{line}")
-                    countSame += 1
-                    continue
-                if not os.path.isfile(file):
-                    # 过滤数据集错误:不存在对应音频
-                    print(f"没有找到对应的音频:{file}")
-                    countNotFound += 1
-                    continue
-                files.add(file)
-        files = list(files)
-        print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}")
+        files = [i[0] for i in load_filelist(filelist)]
     else:
         files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+    Random(42).shuffle(files)
 
     total_files = len(files)