Optimize bert-vits2 parsing
leng-yue committed Dec 20, 2023
1 parent 42e1442 commit 4f02d63
Showing 4 changed files with 88 additions and 69 deletions.
50 changes: 50 additions & 0 deletions fish_speech/utils/file.py
@@ -2,6 +2,8 @@
from pathlib import Path
from typing import Union

from loguru import logger

AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
@@ -72,3 +74,51 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
        return None

    return ckpts[-1]


def load_filelist(path: Path | str) -> list[tuple[Path, str, list[str], str]]:
    """
    Load a Bert-VITS2 style filelist (one `filename|speaker|language|text` record per line).
    """

    files = set()
    results = []
    count_duplicated, count_not_found = 0, 0

    LANGUAGE_TO_LANGUAGES = {
        "zh": ["zh", "en"],
        "jp": ["jp", "en"],
        "en": ["en"],
    }

    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            filename, speaker, language, text = line.strip().split("|")
            file = Path(filename)
            language = language.strip().lower()

            if language == "ja":
                language = "jp"

            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
            languages = LANGUAGE_TO_LANGUAGES[language]

            if file in files:
                logger.warning(f"Duplicated file: {file}")
                count_duplicated += 1
                continue

            if not file.exists():
                logger.warning(f"File not found: {file}")
                count_not_found += 1
                continue

            files.add(file)
            results.append((file, speaker, languages, text))

    if count_duplicated > 0:
        logger.warning(f"Total duplicated files: {count_duplicated}")

    if count_not_found > 0:
        logger.warning(f"Total files not found: {count_not_found}")

    return results
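
The helper above expects one `filename|speaker|language|text` record per line. A minimal usage sketch (the filelist path and its contents are hypothetical examples, not part of this commit):

```python
from fish_speech.utils.file import load_filelist

# An example Bert-VITS2 style filelist might look like:
#   data/speaker_a/0001.wav|speaker_a|ZH|你好，世界。
#   data/speaker_a/0002.wav|speaker_a|JA|こんにちは。
#   data/speaker_b/0001.wav|speaker_b|EN|Hello world.
# Languages are lower-cased, "ja" is normalized to "jp", and each entry is
# expanded to a language list, e.g. "zh" -> ["zh", "en"].

entries = load_filelist("data/filelist.txt")  # hypothetical path
for file, speaker, languages, text in entries:
    print(file, speaker, languages, text)
```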
61 changes: 32 additions & 29 deletions tools/llama/build_dataset.py
@@ -13,10 +13,10 @@
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.text import g2p
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist


def task_generator(config, filelist):
def task_generator_yaml(config):
    with open(config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

@@ -30,26 +30,7 @@ def task_generator(config, filelist):
    )

    # Load the files
    if filelist:
        with open(filelist, "r", encoding="utf-8") as f:
            # files = [Path(line..strip().split("|")[0]) for line in f]
            files = set()
            countSame = 0
            countNotFound = 0
            for line in f.readlines():
                file = Path(line.strip().split("|")[0])
                if file in files:
                    print(f"Duplicated audio/text entry: {line}")
                    countSame += 1
                    continue
                if not os.path.isfile(file):
                    # Filter out dataset errors: the corresponding audio does not exist
                    print(f"Audio file not found: {file}")
                    countNotFound += 1
                    continue
                files.add(file)
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
    files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

    grouped_files = defaultdict(list)
    for file in files:
@@ -63,22 +44,44 @@

logger.info(f"Found {len(grouped_files)} groups in {root}")
for name, subset in grouped_files.items():
yield name, subset, source, languages, extension
yield name, subset, source, languages, extension, None


def task_generator_filelist(filelist):
    grouped_files = defaultdict(list)
    for filename, speaker, languages, text in load_filelist(filelist):
        if speaker in grouped_files:
            assert (
                languages == grouped_files[speaker][0][2]
            ), f"Speaker {speaker} has different languages"

        grouped_files[speaker].append((Path(filename), text, languages))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, entries in grouped_files.items():
        for filename, txt, languages in entries:
            yield speaker, [filename], "filelist", languages, None, txt


def run_task(task):
    name, subset, source, languages, extension = task
    name, subset, source, languages, extension, text = task

    # Parse the files
    sentences = []
    for file in subset:
        np_file = file.with_suffix(".npy")
        txt_file = file.with_suffix(extension)
        if np_file.exists() is False or txt_file.exists() is False:
            logger.warning(f"Can't find {np_file} or {txt_file}")
        if np_file.exists() is False:
            logger.warning(f"Can't find {np_file}")
            continue
        with open(txt_file, "r") as f:
            text = f.read().strip()

        if text is None:
            txt_file = file.with_suffix(extension)

            if txt_file.exists() is False:
                logger.warning(f"Can't find {txt_file}")
                continue

            with open(txt_file, "r") as f:
                text = f.read().strip()

        # Simple cleaning: replace { xxx } and < xxx > with space
        text = re.sub(r"\{.*?\}", " ", text)
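
Both generators now yield a six-element task tuple; the trailing element carries the transcript for filelist entries and stays `None` on the YAML path, where `run_task` falls back to reading a sidecar text file with the configured extension. A rough sketch of how a driver could choose between them, assuming a `--filelist` CLI option (the actual wiring in `main` is outside the shown hunks):

```python
# Illustrative dispatch only -- the real CLI wiring is not shown in this diff.
def iter_tasks(config, filelist):
    if filelist is not None:
        # Tasks arrive pre-transcribed: the text travels inside the tuple.
        yield from task_generator_filelist(filelist)
    else:
        # Tasks point at audio files; run_task reads the text from a sidecar
        # file with the extension configured in the YAML.
        yield from task_generator_yaml(config)

# Each yielded task unpacks as:
#   name, subset, source, languages, extension, text = task
```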
23 changes: 3 additions & 20 deletions tools/vqgan/create_train_split.py
@@ -6,7 +6,7 @@
import click
from tqdm import tqdm

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist


@click.command()
@@ -16,28 +16,11 @@
@click.option("--filelist", default=None, type=Path)
def main(root, val_ratio, val_count, filelist):
    if filelist:
        with open(filelist, "r", encoding="utf-8") as f:
            # files = [Path(line..strip().split("|")[0]) for line in f]
            files = set()
            countSame = 0
            countNotFound = 0
            for line in f.readlines():
                file = Path(line.strip().split("|")[0])
                if file in files:
                    print(f"Duplicated audio/text entry: {line}")
                    countSame += 1
                    continue
                if not os.path.isfile(file):
                    # Filter out dataset errors: the corresponding audio does not exist
                    print(f"Audio file not found: {file}")
                    countNotFound += 1
                    continue
                files.add(file)
            files = list(files)
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
    print(f"Found {len(files)} files")

    print(f"Found {len(files)} files")
    files = [str(file.relative_to(root)) for file in tqdm(files)]

    Random(42).shuffle(files)
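
With a filelist, the split now reuses `load_filelist`, so the duplicate and missing-file checks run before the shuffle. The split itself is outside the shown hunks; a hedged sketch of a ratio/count-based split consistent with the `--val-ratio` and `--val-count` options (option names reused from the CLI, logic assumed):

```python
from random import Random

def split_train_val(files, val_ratio=None, val_count=None):
    """Deterministically split a file list into train/val subsets.

    Assumed implementation for illustration; the actual logic in
    tools/vqgan/create_train_split.py is outside the shown hunks.
    """
    files = list(files)
    Random(42).shuffle(files)

    if val_count is not None:
        n_val = min(val_count, len(files))
    else:
        n_val = int(len(files) * (val_ratio or 0.0))

    return files[n_val:], files[:n_val]
```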
23 changes: 3 additions & 20 deletions tools/vqgan/extract_vq.py
@@ -19,7 +19,7 @@
from omegaconf import OmegaConf

from fish_speech.models.vqgan.utils import sequence_mask
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist

# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
@@ -188,27 +188,10 @@ def main(
    # This is a worker
    logger.info(f"Starting worker")
    if filelist:
        with open(filelist, "r", encoding="utf-8") as f:
            # files = [Path(line..strip().split("|")[0]) for line in f]
            files = set()
            countSame = 0
            countNotFound = 0
            for line in f.readlines():
                file = Path(line.strip().split("|")[0])
                if file in files:
                    print(f"Duplicated audio/text entry: {line}")
                    countSame += 1
                    continue
                if not os.path.isfile(file):
                    # Filter out dataset errors: the corresponding audio does not exist
                    print(f"Audio file not found: {file}")
                    countNotFound += 1
                    continue
                files.add(file)
            files = list(files)
            print(f"Total duplicated audio files: {countSame}, total audio files not found: {countNotFound}")
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)

    Random(42).shuffle(files)

    total_files = len(files)
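
`build_dataset.py` later looks for a `.npy` file next to each audio file (`file.with_suffix(".npy")`); presumably this extraction step produces those files. A toy sketch of that contract, assuming the VQ codes are written as NumPy arrays alongside the audio (the real model inference in `extract_vq.py` is outside the shown hunks):

```python
import numpy as np
from pathlib import Path

def save_vq_codes(audio_path: Path, codes: np.ndarray) -> Path:
    """Write extracted VQ indices next to the audio file.

    Illustrative only: the actual extraction uses the VQ-GAN model loaded in
    tools/vqgan/extract_vq.py, which is not shown in this diff.
    """
    out_path = audio_path.with_suffix(".npy")
    np.save(out_path, codes)
    return out_path
```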
