Skip to content

Commit

Permalink
Support VITS Filelist Input (#18)
Browse files Browse the repository at this point in the history
* add vits filelist support

* add vits filelist support

* Update create_train_split.py

* Update create_train_split.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* fix list not subscriptable

* fix list not subscriptable

* fix path lib

* Add files via upload

* Add files via upload

* fix parent

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add files via upload

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add files via upload

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add files via upload

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Stardust-minus and pre-commit-ci[bot] authored Dec 20, 2023
1 parent 6806674 commit 42e1442
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 10 deletions.
37 changes: 30 additions & 7 deletions tools/llama/build_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import re
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
Expand All @@ -14,7 +16,7 @@
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files


def task_generator(config):
def task_generator(config, filelist):
with open(config, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)

Expand All @@ -28,7 +30,26 @@ def task_generator(config):
)

# Load the files
files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
if filelist:
with open(filelist, "r", encoding="utf-8") as f:
# files = [Path(line.strip().split("|")[0]) for line in f]
files = set()
countSame = 0
countNotFound = 0
for line in f.readlines():
file = Path(line.strip().split("|")[0])
if file in files:
print(f"重复音频文本:{line}")
countSame += 1
continue
if not os.path.isfile(file):
# Skip invalid dataset entries: the referenced audio file does not exist
print(f"没有找到对应的音频:{file}")
countNotFound += 1
continue
files.add(file)
else:
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

grouped_files = defaultdict(list)
for file in files:
Expand All @@ -38,7 +59,6 @@ def task_generator(config):
p = file.parent.parent.name
else:
raise ValueError(f"Invalid parent level {parent_level}")

grouped_files[p].append(file)

logger.info(f"Found {len(grouped_files)} groups in {root}")
Expand All @@ -57,7 +77,6 @@ def run_task(task):
if np_file.exists() is False or txt_file.exists() is False:
logger.warning(f"Can't find {np_file} or {txt_file}")
continue

with open(txt_file, "r") as f:
text = f.read().strip()

Expand Down Expand Up @@ -100,10 +119,14 @@ def run_task(task):
"--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
)
@click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
def main(config, output):
@click.option("--filelist", type=click.Path(), default=None)
@click.option("--num_worker", type=int, default=16)
def main(config, output, filelist, num_worker):
dataset_fp = open(output, "wb")
with Pool(16) as p:
for result in tqdm(p.imap_unordered(run_task, task_generator(config))):
with Pool(num_worker) as p:
for result in tqdm(
p.imap_unordered(run_task, task_generator(config, filelist))
):
dataset_fp.write(result)

dataset_fp.close()
Expand Down
26 changes: 24 additions & 2 deletions tools/vqgan/create_train_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import os
from pathlib import Path
from random import Random

Expand All @@ -12,8 +13,29 @@
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option("--val-ratio", type=float, default=0.2)
@click.option("--val-count", type=int, default=None)
def main(root, val_ratio, val_count):
files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
@click.option("--filelist", default=None, type=Path)
def main(root, val_ratio, val_count, filelist):
if filelist:
with open(filelist, "r", encoding="utf-8") as f:
# files = [Path(line.strip().split("|")[0]) for line in f]
files = set()
countSame = 0
countNotFound = 0
for line in f.readlines():
file = Path(line.strip().split("|")[0])
if file in files:
print(f"重复音频文本:{line}")
countSame += 1
continue
if not os.path.isfile(file):
# Skip invalid dataset entries: the referenced audio file does not exist
print(f"没有找到对应的音频:{file}")
countNotFound += 1
continue
files.add(file)
files = list(files)
else:
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
print(f"Found {len(files)} files")

files = [str(file.relative_to(root)) for file in tqdm(files)]
Expand Down
25 changes: 24 additions & 1 deletion tools/vqgan/extract_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,14 @@ def process_batch(files: list[Path], model) -> float:
default="checkpoints/vqgan-v1.pth",
)
@click.option("--batch-size", default=64)
@click.option("--filelist", default=None, type=Path)
def main(
folder: str,
num_workers: int,
config_name: str,
checkpoint_path: str,
batch_size: int,
filelist: Path,
):
if num_workers > 1 and WORLD_SIZE != num_workers:
assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
Expand Down Expand Up @@ -185,7 +187,28 @@ def main(

# This is a worker
logger.info(f"Starting worker")
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
if filelist:
with open(filelist, "r", encoding="utf-8") as f:
# files = [Path(line.strip().split("|")[0]) for line in f]
files = set()
countSame = 0
countNotFound = 0
for line in f.readlines():
file = Path(line.strip().split("|")[0])
if file in files:
print(f"重复音频文本:{line}")
countSame += 1
continue
if not os.path.isfile(file):
# Skip invalid dataset entries: the referenced audio file does not exist
print(f"没有找到对应的音频:{file}")
countNotFound += 1
continue
files.add(file)
files = list(files)
print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}")
else:
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
Random(42).shuffle(files)

total_files = len(files)
Expand Down

0 comments on commit 42e1442

Please sign in to comment.