diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py index 63ef6c09..97559486 100644 --- a/tools/llama/build_dataset.py +++ b/tools/llama/build_dataset.py @@ -1,6 +1,8 @@ +import os import re from collections import defaultdict from multiprocessing import Pool +from pathlib import Path import click import numpy as np @@ -14,7 +16,7 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files -def task_generator(config): +def task_generator(config, filelist): with open(config, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -28,7 +30,26 @@ def task_generator(config): ) # Load the files - files = list_files(root, AUDIO_EXTENSIONS, recursive=True) + if filelist: + with open(filelist, "r", encoding="utf-8") as f: + # files = [Path(line..strip().split("|")[0]) for line in f] + files = set() + countSame = 0 + countNotFound = 0 + for line in f.readlines(): + file = Path(line.strip().split("|")[0]) + if file in files: + print(f"重复音频文本:{line}") + countSame += 1 + continue + if not os.path.isfile(file): + # 过滤数据集错误:不存在对应音频 + print(f"没有找到对应的音频:{file}") + countNotFound += 1 + continue + files.add(file) + else: + files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) grouped_files = defaultdict(list) for file in files: @@ -38,7 +59,6 @@ def task_generator(config): p = file.parent.parent.name else: raise ValueError(f"Invalid parent level {parent_level}") - grouped_files[p].append(file) logger.info(f"Found {len(grouped_files)} groups in {root}") @@ -57,7 +77,6 @@ def run_task(task): if np_file.exists() is False or txt_file.exists() is False: logger.warning(f"Can't find {np_file} or {txt_file}") continue - with open(txt_file, "r") as f: text = f.read().strip() @@ -100,10 +119,14 @@ def run_task(task): "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml" ) @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos") -def main(config, output): +@click.option("--filelist", type=click.Path(), default=None) +@click.option("--num_worker", type=int, default=16) +def main(config, output, filelist, num_worker): dataset_fp = open(output, "wb") - with Pool(16) as p: - for result in tqdm(p.imap_unordered(run_task, task_generator(config))): + with Pool(num_worker) as p: + for result in tqdm( + p.imap_unordered(run_task, task_generator(config, filelist)) + ): dataset_fp.write(result) dataset_fp.close() diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py index dce63f0f..ac2f5235 100644 --- a/tools/vqgan/create_train_split.py +++ b/tools/vqgan/create_train_split.py @@ -1,4 +1,5 @@ import math +import os from pathlib import Path from random import Random @@ -12,8 +13,29 @@ @click.argument("root", type=click.Path(exists=True, path_type=Path)) @click.option("--val-ratio", type=float, default=0.2) @click.option("--val-count", type=int, default=None) -def main(root, val_ratio, val_count): - files = list_files(root, AUDIO_EXTENSIONS, recursive=True) +@click.option("--filelist", default=None, type=Path) +def main(root, val_ratio, val_count, filelist): + if filelist: + with open(filelist, "r", encoding="utf-8") as f: + # files = [Path(line..strip().split("|")[0]) for line in f] + files = set() + countSame = 0 + countNotFound = 0 + for line in f.readlines(): + file = Path(line.strip().split("|")[0]) + if file in files: + print(f"重复音频文本:{line}") + countSame += 1 + continue + if not os.path.isfile(file): + # 过滤数据集错误:不存在对应音频 + print(f"没有找到对应的音频:{file}") + countNotFound += 1 + continue + files.add(file) + files = list(files) + else: + files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) print(f"Found {len(files)} files") files = [str(file.relative_to(root)) for file in tqdm(files)] diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py index bdc055a7..0f0242c0 100644 --- a/tools/vqgan/extract_vq.py +++ b/tools/vqgan/extract_vq.py @@ -145,12 +145,14 @@ def process_batch(files: list[Path], model) -> float: default="checkpoints/vqgan-v1.pth", ) @click.option("--batch-size", default=64) +@click.option("--filelist", default=None, type=Path) def main( folder: str, num_workers: int, config_name: str, checkpoint_path: str, batch_size: int, + filelist: Path, ): if num_workers > 1 and WORLD_SIZE != num_workers: assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both" @@ -185,7 +187,28 @@ def main( # This is a worker logger.info(f"Starting worker") - files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True) + if filelist: + with open(filelist, "r", encoding="utf-8") as f: + # files = [Path(line..strip().split("|")[0]) for line in f] + files = set() + countSame = 0 + countNotFound = 0 + for line in f.readlines(): + file = Path(line.strip().split("|")[0]) + if file in files: + print(f"重复音频文本:{line}") + countSame += 1 + continue + if not os.path.isfile(file): + # 过滤数据集错误:不存在对应音频 + print(f"没有找到对应的音频:{file}") + countNotFound += 1 + continue + files.add(file) + files = list(files) + print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}") + else: + files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True) Random(42).shuffle(files) total_files = len(files)