Skip to content

Commit

Permalink
Use a ratio to split train vs val set in tools/vqgan/create_train_spl…
Browse files Browse the repository at this point in the history
…it.py (#11)

* feat(tools): vqgan/create_train_split: use a ratio to split train vs val set

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(tools): vqgan/create_train_split: support using --val_ratio and --val_count to adjust split

* feat(tools): vqgan/create_train_split: support using --val_ratio and --val_count to adjust split

* fix(tools): vqgan/create_train_split: fix wrong option name

* Clean some code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Leng Yue <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2023
1 parent cf69582 commit 8dac6ec
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions tools/vqgan/create_train_split.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from pathlib import Path
from random import Random

Expand All @@ -9,19 +10,28 @@

@click.command()
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option(
    "--val-ratio",
    type=float,
    default=0.2,
    help="Fraction of files written to the validation list (between 0 and 1).",
)
@click.option(
    "--val-count",
    type=int,
    default=None,
    help="Exact number of validation files; takes precedence over --val-ratio.",
)
def main(root, val_ratio, val_count):
    """Write shuffled train/validation file lists into ROOT.

    Creates vq_train_filelist.txt and vq_val_filelist.txt, one relative path
    per line. Raises ValueError when --val-count or --val-ratio is out of range.
    """
    files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
    print(f"Found {len(files)} files")

    # Paths are stored relative to root, so the lists do not embed the
    # absolute location of the dataset directory.
    files = [str(file.relative_to(root)) for file in tqdm(files)]

    # Fixed seed: the same input order always produces the same shuffle.
    Random(42).shuffle(files)

    if val_count is not None:
        # An explicit count wins over the ratio.
        if val_count < 1 or val_count > len(files):
            raise ValueError("val_count must be between 1 and number of files")
        val_size = val_count
    else:
        # Reject ratios outside [0, 1]: a negative ratio would yield a negative
        # val_size and silently invert the split through negative slicing, and
        # a ratio above 1 would silently leave the train list empty.
        # NaN also fails this comparison and is rejected here.
        if not 0 <= val_ratio <= 1:
            raise ValueError("val_ratio must be between 0 and 1")
        # Round up so any non-zero ratio on a non-empty set keeps >= 1 val file.
        val_size = math.ceil(len(files) * val_ratio)

    # The first val_size shuffled entries form the validation set; the rest
    # form the training set.
    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(files[val_size:]))

    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(files[:val_size]))

    print("Done")

Expand Down

0 comments on commit 8dac6ec

Please sign in to comment.