diff --git a/README.md b/README.md index 3cb84ec4..fe7434d6 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ client.start().join() Assuming MFA is installed using `download_mfa.sh` and [Common Voice dataset](https://commonvoice.mozilla.org/) is downloaded already, one can easily generate a dataset for a custom wakeword using the `generate_dataset.sh` script. ```bash -./generate_dataset.sh <common voice dataset path> <underscore separated wakeword (e.g. hey_fire_fox)> <inference sequence (e.g. [0,1,2])> +./generate_dataset.sh <common voice dataset path> <underscore separated wakeword (e.g. hey_fire_fox)> <inference sequence (e.g. [0,1,2])> <(Optional) "true" to skip negative dataset generation> ``` In the example that follows, we describe the process of generating a dataset for the word, "fire." @@ -102,10 +102,16 @@ mfa_align data/fire-positive/audio eng.dict pretrained_models/english.zip output DATASET_PATH=data/fire-positive python -m training.run.attach_alignment --align-type mfa -i output-folder ``` +8. (Optional) Stitch vocab samples of the aligned dataset to generate wakeword samples + +```bash +VOCAB='["fire"]' INFERENCE_SEQUENCE=[0] python -m training.run.stitch_vocab_samples --aligned-dataset "data/fire-positive" --stitched-dataset "data/fire-stitched" +``` + ### Training and Running a Model 1. Source the relevant environment variables for training the `res8` model: `source envs/res8.env`. -2. Train the model: `python -m training.run.train -i data/fire-positive data/fire-negative --model res8 --workspace workspaces/fire-res8`. +2. Train the model: `python -m training.run.train -i data/fire-positive data/fire-negative data/fire-stitched --model res8 --workspace workspaces/fire-res8`. 3. For the CLI demo, run `python -m training.run.demo --model res8 --workspace workspaces/fire-res8`. `train_model.sh` is also available, which encapsulates the individual commands into a single bash script diff --git a/generate_dataset.sh b/generate_dataset.sh index e80d2f7d..8c321dfc 100755 --- a/generate_dataset.sh +++ b/generate_dataset.sh @@ -1,14 +1,18 @@ #bin/bash -# TODO:: enable this flag after fixing segfault issue of create_new_dataset -# set -e +set -e COMMON_VOICE_DATASET_PATH=${1} # common voice dataset path DATASET_NAME=${2} # underscore separated wakeword (e.g. hey_fire_fox) INFERENCE_SEQUENCE=${3} # inference sequence (e.g.
[0,1,2]) +#${4} pass true to skip generating negative dataset if [ $# -lt 3 ]; then - echo 1>&2 "invalid arguments: ./generate_dataset.sh <common voice dataset path> <underscore separated wakeword> <inference sequence>" - exit 2 + echo 1>&2 "invalid arguments: ./generate_dataset.sh <common voice dataset path> <underscore separated wakeword> <inference sequence> <(optional) true to skip negative dataset generation>" + exit 2 +elif [ $# -eq 4 ]; then + SKIP_NEG_DATASET=${4} +else + SKIP_NEG_DATASET="false" fi echo "COMMON_VOICE_DATASET_PATH: ${COMMON_VOICE_DATASET_PATH}" @@ -27,13 +31,15 @@ DATASET_FOLDER="data/${DATASET_NAME}" echo ">>> generating datasets for ${VOCAB} at ${DATASET_FOLDER}" mkdir -p "${DATASET_FOLDER}" -NEG_DATASET_PATH="${DATASET_FOLDER}/negative" -echo ">>> generating negative dataset: ${NEG_DATASET_PATH}" -mkdir -p "${NEG_DATASET_PATH}" -time VOCAB=${VOCAB} INFERENCE_SEQUENCE=${INFERENCE_SEQUENCE} DATASET_PATH=${NEG_DATASET_PATH} python -m training.run.create_raw_dataset -i ${COMMON_VOICE_DATASET_PATH} --positive-pct 0 --negative-pct 5 +if [ ${SKIP_NEG_DATASET} != "true" ]; then + NEG_DATASET_PATH="${DATASET_FOLDER}/negative" + echo ">>> generating negative dataset: ${NEG_DATASET_PATH}" + mkdir -p "${NEG_DATASET_PATH}" + time VOCAB=${VOCAB} INFERENCE_SEQUENCE=${INFERENCE_SEQUENCE} DATASET_PATH=${NEG_DATASET_PATH} python -m training.run.create_raw_dataset -i ${COMMON_VOICE_DATASET_PATH} --positive-pct 0 --negative-pct 5 -echo ">>> generating mock alignment for the negative set" -time DATASET_PATH=${NEG_DATASET_PATH} python -m training.run.attach_alignment --align-type stub + echo ">>> generating mock alignment for the negative set" + time DATASET_PATH=${NEG_DATASET_PATH} python -m training.run.attach_alignment --align-type stub +fi POS_DATASET_PATH="${DATASET_FOLDER}/positive" echo ">>> generating positive dataset: ${POS_DATASET_PATH}" @@ -53,6 +59,10 @@ time yes n | ./bin/mfa_align --verbose --clean --num_jobs 12 "../${POS_DATASET_P popd echo ">>> attaching the MFA alignment to the positive dataset" -DATASET_PATH=${POS_DATASET_PATH} python -m training.run.attach_alignment --align-type mfa -i "${POS_DATASET_ALIGNMENT}" +time DATASET_PATH=${POS_DATASET_PATH} python -m training.run.attach_alignment --align-type mfa -i "${POS_DATASET_ALIGNMENT}" + +STITCHED_DATASET="${DATASET_FOLDER}/stitched" +echo ">>> stitching vocab samples to generate a dataset made up of stitched wakeword samples: ${STITCHED_DATASET}" +time VOCAB=${VOCAB} INFERENCE_SEQUENCE=${INFERENCE_SEQUENCE} python -m training.run.stitch_vocab_samples --aligned-dataset "${POS_DATASET_PATH}" --stitched-dataset "${STITCHED_DATASET}" echo ">>> Dataset is ready for ${VOCAB}" diff --git a/howl/context.py b/howl/context.py index 773de802..c24cd365 100644 --- a/howl/context.py +++ b/howl/context.py @@ -48,19 +48,20 @@ def __init__(self, elif token_type == 'word': self.add_vocab(vocab) + # initialize vocab set for the system + self.negative_label = len(self.adjusted_vocab) + self.vocab = Vocab({word: idx for idx, word in enumerate( + self.adjusted_vocab)}, oov_token_id=self.negative_label) + # initialize labeler; make sure this is located before adding other labels if token_type == 'phone': phone_phrases = [PhonePhrase.from_string( x) for x in self.adjusted_vocab] self.labeler = PhoneticFrameLabeler(phone_phrases) elif token_type == 'word': - print('labeler vocab: ', self.adjusted_vocab) - self.labeler = WordFrameLabeler(self.adjusted_vocab) + self.labeler = WordFrameLabeler(self.vocab) - # initialize vocab set for the system and add negative label - self.negative_label = len(self.adjusted_vocab) - self.vocab = Vocab({word: idx for idx, word in enumerate( - self.adjusted_vocab)}, oov_token_id=self.negative_label) + # add negative
label self.add_vocab(['[OOV]']) # initialize TranscriptSearcher with the processed targets diff --git a/howl/data/dataset/base.py b/howl/data/dataset/base.py index ef9a21b0..53ec4ae4 100644 --- a/howl/data/dataset/base.py +++ b/howl/data/dataset/base.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass from pathlib import Path -from typing import Generic, List, Mapping, Optional, TypeVar +from typing import Generic, List, Mapping, Optional, Tuple, TypeVar import torch from pydantic import BaseModel @@ -31,6 +31,8 @@ @dataclass class FrameLabelData: timestamp_label_map: Mapping[float, int] + start_timestamp: List[Tuple[int, float]] + char_indices: List[Tuple[int, List[int]]] @dataclass @@ -158,7 +160,10 @@ def emplaced_audio_data(self, new: bool = False) -> 'WakeWordClipExample': ex = super().emplaced_audio_data(audio_data, scale, bias, new) label_data = {} if new else {scale * k + bias: v for k, v in self.label_data.timestamp_label_map.items()} - return WakeWordClipExample(FrameLabelData(label_data), ex.metadata, audio_data, self.sample_rate) + return WakeWordClipExample(FrameLabelData(label_data, self.label_data.start_timestamp, self.label_data.char_indices), + ex.metadata, + audio_data, + self.sample_rate) @dataclass diff --git a/howl/data/dataset/labeller.py b/howl/data/dataset/labeller.py index f2b714b7..97544eba 100644 --- a/howl/data/dataset/labeller.py +++ b/howl/data/dataset/labeller.py @@ -3,6 +3,7 @@ from typing import List from howl.data.dataset.phone import PhonePhrase +from howl.data.tokenize import Vocab from .base import AudioClipMetadata, FrameLabelData @@ -38,22 +39,28 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData: class WordFrameLabeler(FrameLabeler): - def __init__(self, words: List[str], ceil_word_boundary: bool = False): - self.words = words + def __init__(self, vocab: Vocab, ceil_word_boundary: bool = False): + self.vocab = vocab self.ceil_word_boundary = ceil_word_boundary def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData: frame_labels = dict() - t = f' {metadata.transcription} ' - start = 0 - for idx, word in enumerate(self.words): - while True: - try: - start = t.index(word, start) - except ValueError: - break - while self.ceil_word_boundary and start + len(word) < len(t) - 1 and t[start + len(word)] != ' ': - start += 1 - frame_labels[metadata.end_timestamps[start + len(word.rstrip()) - 2]] = idx - start += 1 - return FrameLabelData(frame_labels) + char_indices = [] + start_timestamp = [] + + char_idx = 0 + for word in metadata.transcription.split(): + vocab_found, remaining_transcript = self.vocab.trie.max_split(word) + word_size = len(word.rstrip()) + + # if the current word is in vocab, store necessary informations + if vocab_found and remaining_transcript == "": + label = self.vocab[word] + end_timestamp = metadata.end_timestamps[char_idx + word_size - 1] + frame_labels[end_timestamp] = label + char_indices.append((label, list(range(char_idx, char_idx + word_size)))) + start_timestamp.append((label, metadata.end_timestamps[char_idx-1] if char_idx > 0 else 0.0)) + + char_idx += word_size + 1 # space + + return FrameLabelData(frame_labels, start_timestamp, char_indices) diff --git a/howl/data/dataset/serialize.py b/howl/data/dataset/serialize.py index 65b54f69..98235f51 100644 --- a/howl/data/dataset/serialize.py +++ b/howl/data/dataset/serialize.py @@ -1,24 +1,23 @@ +import json +import logging from collections import defaultdict from copy import deepcopy from functools 
import partial -from typing import Tuple, TypeVar, List -from pathlib import Path from multiprocessing import Pool -import json -import logging +from pathlib import Path +from typing import List, Tuple, TypeVar -from tqdm import tqdm import pandas as pd import soundfile - -from .base import DatasetType, AudioClipMetadata, UNKNOWN_TRANSCRIPTION -from .dataset import AudioClipDataset, WakeWordDataset, AudioClassificationDataset, AudioDataset, \ - HonkSpeechCommandsDataset from howl.registered import RegisteredObjectBase from howl.utils.audio import silent_load from howl.utils.hash import sha256_int from howl.utils.transcribe import SpeechToText +from tqdm import tqdm +from .base import UNKNOWN_TRANSCRIPTION, AudioClipMetadata, DatasetType +from .dataset import (AudioClassificationDataset, AudioClipDataset, + AudioDataset, HonkSpeechCommandsDataset, WakeWordDataset) __all__ = ['AudioDatasetWriter', 'AudioClipDatasetLoader', @@ -58,14 +57,15 @@ def __exit__(self, *args): class AudioDatasetWriter: - def __init__(self, dataset: AudioClipDataset, mode: str = 'w', print_progress: bool = True): + def __init__(self, dataset: AudioClipDataset, prefix: str = '', mode: str = 'w', print_progress: bool = True): self.dataset = dataset self.print_progress = print_progress self.mode = mode + self.prefix = prefix def write(self, folder: Path): def process(metadata: AudioClipMetadata): - new_path = audio_folder / metadata.path.with_suffix('.wav').name + new_path = (audio_folder / metadata.audio_id).with_suffix('.wav') if not new_path.exists(): audio_data = silent_load(str(metadata.path), self.dataset.sr, self.dataset.mono) soundfile.write(str(new_path), audio_data, self.dataset.sr) @@ -75,7 +75,7 @@ def process(metadata: AudioClipMetadata): folder.mkdir(exist_ok=True) audio_folder = folder / 'audio' audio_folder.mkdir(exist_ok=True) - with AudioDatasetMetadataWriter(folder, self.dataset.set_type, mode=self.mode) as writer: + with AudioDatasetMetadataWriter(folder, self.dataset.set_type, prefix=self.prefix, mode=self.mode) as writer: for metadata in tqdm(self.dataset.metadata_list, disable=not self.print_progress, desc='Writing files'): try: process(metadata) @@ -133,15 +133,17 @@ class WakeWordDatasetLoader(MetadataLoaderMixin, PathDatasetLoader): dataset_class = WakeWordDataset metadata_class = AudioClipMetadata + def transcribe_hey_snips_audio(path, metadata): stt = SpeechToText() path = (path / metadata['audio_file_path']).absolute() transcription = 'hey snips' - if metadata['is_hotword'] == 0: # negative sample + if metadata['is_hotword'] == 0: # negative sample transcription = stt.transcribe(path) return path, transcription + class HeySnipsWakeWordLoader(RegisteredPathDatasetLoader, name='hey-snips'): def __init__(self, num_processes=8): self.stt = SpeechToText() @@ -192,7 +194,7 @@ def load(filename, set_type): return (load('train.json', DatasetType.TRAINING), load('dev.json', DatasetType.DEV), load('test.json', DatasetType.TEST)) - + class GoogleSpeechCommandsDatasetLoader(RegisteredPathDatasetLoader, name='gsc'): def __init__(self, vocab: List[str] = None, use_bg_noise: bool = False): @@ -237,7 +239,8 @@ def load(filename, set_type): df = pd.read_csv(str(path / filename), sep='\t', quoting=3, na_filter=False) metadata_list = [] for tup in df.itertuples(): - metadata_list.append(AudioClipMetadata(path=(path / 'clips' / tup.path).absolute(), transcription=tup.sentence)) + metadata_list.append(AudioClipMetadata( + path=(path / 'clips' / tup.path).absolute(), transcription=tup.sentence)) return 
AudioClipDataset(metadata_list=metadata_list, set_type=set_type, **dataset_kwargs) assert path.exists(), 'dataset path doesn\'t exist' diff --git a/howl/data/stitcher.py b/howl/data/stitcher.py new file mode 100644 index 00000000..875ead01 --- /dev/null +++ b/howl/data/stitcher.py @@ -0,0 +1,166 @@ +import itertools +import random +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple + +import soundfile +import torch +from howl.data.dataset import (AudioClipDataset, AudioClipExample, + AudioClipMetadata, AudioDataset, DatasetType) +from howl.data.tokenize import Vocab +from howl.settings import SETTINGS +from tqdm import tqdm + +__all__ = ['WordStitcher'] + + +@dataclass +class FrameLabelledSample: + audio_data: torch.Tensor + audio_length_ms: float + end_timestamps: Optional[List[float]] + label: int + + +class Stitcher: + def __init__(self, + vocab: Vocab): + self.sequence = SETTINGS.inference_engine.inference_sequence + self.sr = SETTINGS.audio.sample_rate + self.vocab = vocab + self.wakeword = ' '.join(self.vocab[x] + for x in self.sequence) + + +class WordStitcher(Stitcher): + def __init__(self, + **kwargs): + super().__init__(**kwargs) + + def concatenate_end_timestamps(self, end_timestamps_list: List[List[float]]) -> List[float]: + """concatenate the given list of end timestamps for a single audio sample + + Args: + end_timestamps_list (List[List[float]]): list of timestamps to concatenate + + Returns: + List[float]: concatenated end timestamps + """ + concatenated_timestamps = [] + last_timestamp = 0 + for end_timestamps in end_timestamps_list: + for timestamp in end_timestamps: + concatenated_timestamps.append(timestamp + last_timestamp) + + # when stitching, a space will be added between the vocab words + # therefore the last timestamp is repeated once to make up for the added space + concatenated_timestamps.append(concatenated_timestamps[-1]) + last_timestamp = concatenated_timestamps[-1] + + return concatenated_timestamps[:-1] # discard last space timestamp + + def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, *datasets: AudioDataset): + """collect vocab samples from datasets and generate stitched wakeword samples + + Args: + num_stitched_samples (int): number of stitched wakeword samples to generate + stitched_dataset_dir (Path): folder for the stitched dataset where the audio samples will be saved + datasets (AudioDataset): list of datasets to collect vocab samples from + """ + sample_set = [[] for _ in range(len(self.vocab))] + + for dataset in datasets: + # for each audio sample, collect vocab audio samples based on the alignment + for sample in dataset: + for (label, char_indices) in sample.label_data.char_indices: + vocab_start_idx = char_indices[0] - 1 if char_indices[0] > 0 else 0 + start_timestamp = sample.metadata.end_timestamps[vocab_start_idx] + end_timestamp = sample.metadata.end_timestamps[char_indices[-1]] + + audio_start_idx = int(start_timestamp * self.sr / 1000) + audio_end_idx = int(end_timestamp * self.sr / 1000) + + adjusted_end_timestamps = [] + for char_idx in char_indices: + adjusted_end_timestamps.append(sample.metadata.end_timestamps[char_idx] - start_timestamp) + + sample_set[label].append(FrameLabelledSample( + sample.audio_data[audio_start_idx:audio_end_idx], end_timestamp-start_timestamp, adjusted_end_timestamps, label)) + + audio_dir = stitched_dataset_dir / "audio" + audio_dir.mkdir(exist_ok=True) + + # reorganize and make sure there are enough samples for each vocab + sample_lists = [] + for element in
self.sequence: + print(f"number of samples for vocab {self.vocab[element]}: {len(sample_set[element])}") + assert len(sample_set[element]) > 0, "There must be at least one sample for each vocab" + sample_lists.append(sample_set[element]) + + # generate AudioClipExample for each vocab sample + self.stitched_samples = [] + for sample_idx in tqdm(range(num_stitched_samples), desc="Generating stitched samples"): + sample_set = [] + for sample_list in sample_lists: + sample_set.append(random.choice(sample_list)) + + metadata = AudioClipMetadata( + path=Path(audio_dir / f"{sample_idx}").with_suffix(".wav"), + transcription=self.wakeword, + end_timestamps=self.concatenate_end_timestamps( + [labelled_data.end_timestamps for labelled_data in sample_set]) + ) + + # TODO:: dataset writer loads the samples upon write and does not use the data in memory + # writer classes need to be refactored to use the audio data if it exists + audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set]) + soundfile.write(metadata.path, audio_data.numpy(), self.sr) + + stitched_sample = AudioClipExample( + metadata=metadata, + audio_data=audio_data, + sample_rate=self.sr) + + self.stitched_samples.append(stitched_sample) + + def load_splits(self, + train_pct: float, + dev_pct: float, + test_pct: float) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]: + """split the generated stitched samples based on the given percentages + the first train_pct samples are used to generate the train set + the next dev_pct samples are used to generate the dev set + the next test_pct samples are used to generate the test set + + Args: + train_pct (float): train set percentage (0, 1) + dev_pct (float): dev set percentage (0, 1) + test_pct (float): test set percentage (0, 1) + + Returns: + Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]: train/dev/test datasets + """ + + num_samples = len(self.stitched_samples) + train_bucket = int(train_pct * num_samples) + dev_bucket = int((train_pct + dev_pct) * num_samples) + test_bucket = int((train_pct + dev_pct + test_pct) * num_samples) + + random.shuffle(self.stitched_samples) + train_split = [] + dev_split = [] + test_split = [] + + for idx, sample in enumerate(self.stitched_samples): + if idx < train_bucket: + train_split.append(sample.metadata) + elif idx < dev_bucket: + dev_split.append(sample.metadata) + elif idx < test_bucket: + test_split.append(sample.metadata) + + ds_kwargs = dict(sr=self.sr, mono=SETTINGS.audio.use_mono) + return (AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs), + AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs), + AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs)) diff --git a/howl/data/tokenize.py b/howl/data/tokenize.py index 0f0ca44b..cf2c43ba 100644 --- a/howl/data/tokenize.py +++ b/howl/data/tokenize.py @@ -60,9 +60,11 @@ def max_split(self, tokens: str) -> Tuple[str, str]: class Vocab: def __init__(self, - word2idx: Mapping[str, int], + word2idx: Union[Mapping[str, int], List[str]], oov_token_id: int = None, oov_word_repr: str = '[OOV]'): + if isinstance(word2idx, List): + word2idx = {word: idx for idx, word in enumerate(word2idx)} self.word2idx = {k.lower(): v for k, v in word2idx.items()} self.idx2word = {v: k for k, v in word2idx.items()} self.oov_token_id = oov_token_id @@ -99,10 +101,10 @@ def encode(self, transcript: str) -> List[int]: encoded_output = [] for word in transcript.lower().split(): - vocab_found, transcript =
self.vocab.trie.max_split(word) + vocab_found, remaining_transcript = self.vocab.trie.max_split(word) # append corresponding label - if vocab_found and transcript == "": + if vocab_found and remaining_transcript == "": # word exists in the vocab encoded_output.append(self.vocab[word]) elif not self.ignore_oov: diff --git a/howl/data/transform/base.py b/howl/data/transform/base.py index bbb74b71..ad4e9e29 100644 --- a/howl/data/transform/base.py +++ b/howl/data/transform/base.py @@ -1,15 +1,14 @@ -from typing import Sequence, Iterable, List import random +from typing import Iterable, List, Sequence import librosa.effects as effects import numpy as np import torch import torch.nn as nn - -from howl.data.dataset import WakeWordClipExample, ClassificationBatch, EmplacableExample, SequenceBatch +from howl.data.dataset import (ClassificationBatch, EmplacableExample, + SequenceBatch, WakeWordClipExample) from howl.data.tokenize import TranscriptTokenizer - __all__ = ['Composition', 'compose', 'ZmuvTransform', @@ -92,10 +91,13 @@ def tensorize_audio_data(audio_data_lst: List[torch.Tensor], max_length = max(audio_data.size(-1) for audio_data in audio_data_lst) audio_tensor = [] for audio_data in audio_data_lst: + squeezed_data = audio_data.squeeze() + if squeezed_data.dim() == 0: + squeezed_data = squeezed_data.unsqueeze(0) if rand_append and random.random() < 0.5: - x = (torch.zeros(max_length - audio_data.size(-1)), audio_data.squeeze()) + x = (torch.zeros(max_length - audio_data.size(-1)), squeezed_data) else: - x = (audio_data.squeeze(), torch.zeros(max_length - audio_data.size(-1))) + x = (squeezed_data, torch.zeros(max_length - audio_data.size(-1))) audio_tensor.append(torch.cat(x, -1)) return torch.stack(audio_tensor), extra_data_lists diff --git a/howl/settings.py b/howl/settings.py index 2128aa90..11b9676a 100644 --- a/howl/settings.py +++ b/howl/settings.py @@ -2,7 +2,6 @@ from pydantic import BaseSettings - __all__ = ['AudioSettings', 'RawDatasetSettings', 'DatasetSettings', 'SETTINGS'] @@ -34,6 +33,7 @@ class InferenceEngineSettings(BaseSettings): class TrainingSettings(BaseSettings): seed: int = 0 + # TODO:: vocab should not belong to training vocab: List[str] = ['fire'] num_epochs: int = 10 num_labels: int = 2 diff --git a/test/data/dataset.py b/test/data/dataset.py index a01e45e0..2c496582 100644 --- a/test/data/dataset.py +++ b/test/data/dataset.py @@ -45,7 +45,6 @@ def test_compute_statistics(self): SETTINGS.training.vocab = ["hello", "world"] SETTINGS.training.token_type = "word" SETTINGS.inference_engine.inference_sequence = [0, 1] - SETTINGS.inference_engine.inference_sequence = [0, 1] vocab = Vocab({"Hello": 0, "World": 1}, oov_token_id=2, oov_word_repr='') diff --git a/test/data/stitcher.py b/test/data/stitcher.py new file mode 100644 index 00000000..2bf80272 --- /dev/null +++ b/test/data/stitcher.py @@ -0,0 +1,44 @@ +import random +import unittest +from pathlib import Path + +import torch +from howl.data.dataset import WakeWordDatasetLoader, WordFrameLabeler +from howl.data.stitcher import WordStitcher +from howl.data.tokenize import Vocab +from howl.settings import SETTINGS + + +class TestStitcher(unittest.TestCase): + + def test_compute_statistics(self): + random.seed(1) + + """test compute statistics + """ + SETTINGS.training.vocab = ["hey", "fire", "fox"] + SETTINGS.training.token_type = "word" + SETTINGS.inference_engine.inference_sequence = [0, 1, 2] + + vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr='') + labeler = WordFrameLabeler(vocab) 
+ + loader = WakeWordDatasetLoader() + ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=labeler) + + test_dataset_path = Path("test/test_data") + stitched_dataset_path = test_dataset_path / "stitched" + stitched_dataset_path.mkdir(exist_ok=True) + + test_ds, _, _ = loader.load_splits(test_dataset_path, **ds_kwargs) + stitcher = WordStitcher(vocab=vocab) + stitcher.stitch(20, stitched_dataset_path, test_ds) + + stitched_train_ds, stitched_dev_ds, stitched_test_ds = stitcher.load_splits(0.5, 0.25, 0.25) + self.assertEqual(len(stitched_train_ds), 10) + self.assertEqual(len(stitched_dev_ds), 5) + self.assertEqual(len(stitched_test_ds), 5) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_data/aligned-metadata-dev.jsonl b/test/test_data/aligned-metadata-dev.jsonl new file mode 100644 index 00000000..cb69658a --- /dev/null +++ b/test/test_data/aligned-metadata-dev.jsonl @@ -0,0 +1,2 @@ +{"path": "3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav", "transcription": "hey fire fox", "end_timestamps": [530.0, 685.0, 840.0, 840.0, 870.0, 1000.0, 1130.0, 1260.0, 1390.0, 1520.0, 1650.0, 1780.0]} +{"path": "35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav", "transcription": "hey fire fox", "end_timestamps": [770.0, 900.0, 1030.0, 1030.0, 1030.0, 1155.7142857142858, 1281.4285714285713, 1407.142857142857, 1532.857142857143, 1658.5714285714284, 1784.2857142857142, 1910.0]} diff --git a/test/test_data/aligned-metadata-test.jsonl b/test/test_data/aligned-metadata-test.jsonl new file mode 100644 index 00000000..40ee3a3b --- /dev/null +++ b/test/test_data/aligned-metadata-test.jsonl @@ -0,0 +1 @@ +{"path": "36U2A8VAG38A23EL3EXR25D3EWVKYT.wav", "transcription": "hey fire fox", "end_timestamps": [700.0, 795.0, 890.0, 890.0, 890.0, 1031.4285714285713, 1172.857142857143, 1314.2857142857142, 1455.7142857142858, 1597.142857142857, 1738.5714285714284, 1880.0]} diff --git a/test/test_data/aligned-metadata-training.jsonl b/test/test_data/aligned-metadata-training.jsonl new file mode 100644 index 00000000..0724a8be --- /dev/null +++ b/test/test_data/aligned-metadata-training.jsonl @@ -0,0 +1,12 @@ +{"path": "common_voice_en_18673330.wav", "transcription": "this fire rises through the grate and ignites the charcoal", "end_timestamps": [1230.0, 1320.0, 1410.0, 1500.0, 1500.0, 1500.0, 1663.3333333333333, 1826.6666666666667, 1990.0, 1990.0, 2020.0, 2187.5, 2355.0, 2522.5, 2690.0, 2690.0, 2690.0, 2745.0, 2800.0, 2855.0, 2910.0, 2965.0, 3020.0, 3020.0, 3020.0, 3125.0, 3230.0, 3230.0, 3260.0, 3402.5, 3545.0, 3687.5, 3830.0, 3830.0, 3900.0, 4045.0, 4190.0, 4190.0, 4190.0, 4301.666666666667, 4413.333333333333, 4525.0, 4636.666666666667, 4748.333333333333, 4860.0, 4860.0, 4890.0, 4960.0, 5030.0, 5030.0, 5030.0, 5137.142857142857, 5244.285714285715, 5351.428571428572, 5458.571428571428, 5565.714285714285, 5672.857142857143, 5780.0]} +{"path": "common_voice_en_19687522.wav", "transcription": "mark king is fire chief and gary carlson serves as fire marshal", "end_timestamps": [850.0, 976.6666666666666, 1103.3333333333333, 1230.0, 1230.0, 1230.0, 1373.3333333333333, 1516.6666666666667, 1660.0, 1660.0, 1730.0, 2040.0, 2040.0, 2040.0, 2140.0, 2240.0, 2340.0, 2340.0, 2340.0, 2470.0, 2600.0, 2730.0, 2860.0, 2860.0, 3000.0, 3125.0, 3250.0, 3250.0, 3250.0, 3350.0, 3450.0, 3550.0, 3550.0, 3550.0, 3648.3333333333335, 3746.6666666666665, 3845.0, 3943.3333333333335, 4041.6666666666665, 4140.0, 4140.0, 4140.0, 4224.0, 4308.0, 4392.0, 4476.0, 4560.0, 4560.0, 4560.0, 4740.0, 4740.0, 4740.0, 4853.333333333333, 
4966.666666666667, 5080.0, 5080.0, 5080.0, 5171.666666666667, 5263.333333333333, 5355.0, 5446.666666666667, 5538.333333333333, 5630.0]} +{"path": "common_voice_en_19649226.wav", "transcription": "the song hey ", "end_timestamps": [1200.0, 1260.0, 1320.0, 1320.0, 1320.0, 1523.3333333333333, 1726.6666666666667, 1930.0, 1930.0, 2080.0, 2285.0, 2490.0, 2490.0, 2490.0, 2630.0, 2770.0, 2910.0, 3050.0]} +{"path": "common_voice_en_19964009.wav", "transcription": "this contributed to their decision to make follow up hey venus", "end_timestamps": [690.0, 760.0, 830.0, 900.0, 900.0, 900.0, 966.0, 1032.0, 1098.0, 1164.0, 1230.0, 1296.0, 1362.0, 1428.0, 1494.0, 1560.0, 1560.0, 1560.0, 1660.0, 1660.0, 1660.0, 1700.0, 1740.0, 1780.0, 1820.0, 1820.0, 1820.0, 1882.857142857143, 1945.7142857142858, 2008.5714285714284, 2071.4285714285716, 2134.285714285714, 2197.142857142857, 2260.0, 2260.0, 2260.0, 2370.0, 2370.0, 2370.0, 2446.6666666666665, 2523.3333333333335, 2600.0, 2600.0, 2600.0, 2664.0, 2728.0, 2792.0, 2856.0, 2920.0, 2920.0, 2920.0, 3220.0, 3220.0, 3220.0, 3320.0, 3420.0, 3420.0, 3420.0, 3565.0, 3710.0, 3855.0, 4000.0]} +{"path": "common_voice_en_19725016.wav", "transcription": "this prompted paul heyman to leave the announce table and enter the ring", "end_timestamps": [520.0, 623.3333333333334, 726.6666666666666, 830.0, 830.0, 830.0, 908.5714285714286, 987.1428571428571, 1065.7142857142858, 1144.2857142857142, 1222.857142857143, 1301.4285714285716, 1380.0, 1380.0, 1610.0, 1716.6666666666667, 1823.3333333333333, 1930.0, 1930.0, 1930.0, 2038.0, 2146.0, 2254.0, 2362.0, 2470.0, 2470.0, 2570.0, 2860.0, 2860.0, 2860.0, 2952.5, 3045.0, 3137.5, 3230.0, 3230.0, 3230.0, 3290.0, 3350.0, 3350.0, 3350.0, 3435.714285714286, 3521.4285714285716, 3607.142857142857, 3692.8571428571427, 3778.5714285714284, 3864.285714285714, 3950.0, 3950.0, 3950.0, 4082.5, 4215.0, 4347.5, 4480.0, 4480.0, 4480.0, 4570.0, 4660.0, 4660.0, 4660.0, 4790.0, 4920.0, 5050.0, 5180.0, 5180.0, 5180.0, 5285.0, 5390.0, 5390.0, 5460.0, 5560.0, 5660.0, 5760.0]} +{"path": "common_voice_en_19963788.wav", "transcription": "juju performed the song at hey", "end_timestamps": [880.0, 1126.6666666666667, 1373.3333333333333, 1620.0, 1620.0, 1620.0, 1683.75, 1747.5, 1811.25, 1875.0, 1938.75, 2002.5, 2066.25, 2130.0, 2130.0, 2130.0, 2165.0, 2200.0, 2200.0, 2200.0, 2340.0, 2480.0, 2620.0, 2620.0, 2620.0, 2920.0, 2920.0, 2920.0, 3220.0, 3520.0]} +{"path": "common_voice_en_19946691.wav", "transcription": "although hey ya", "end_timestamps": [820.0, 910.0, 1000.0, 1090.0, 1180.0, 1270.0, 1360.0, 1450.0, 1450.0, 1630.0, 1765.0, 1900.0, 1900.0, 1930.0, 2380.0]} +{"path": "common_voice_en_19632093.wav", "transcription": "the school is a member of the illinois high school association fox valley conference", "end_timestamps": [930.0, 990.0, 1050.0, 1050.0, 1050.0, 1116.0, 1182.0, 1248.0, 1314.0, 1380.0, 1380.0, 1380.0, 1450.0, 1450.0, 1450.0, 1500.0, 1500.0, 1586.0, 1672.0, 1758.0, 1844.0, 1930.0, 1930.0, 1930.0, 2070.0, 2070.0, 2070.0, 2135.0, 2200.0, 2200.0, 2200.0, 2252.8571428571427, 2305.714285714286, 2358.5714285714284, 2411.4285714285716, 2464.285714285714, 2517.142857142857, 2570.0, 2570.0, 2570.0, 2666.6666666666665, 2763.3333333333335, 2860.0, 2860.0, 2860.0, 2906.0, 2952.0, 2998.0, 3044.0, 3090.0, 3090.0, 3090.0, 3185.0, 3280.0, 3375.0, 3470.0, 3565.0, 3660.0, 3755.0, 3850.0, 3945.0, 4040.0, 4040.0, 4040.0, 4205.0, 4370.0, 4370.0, 4370.0, 4434.0, 4498.0, 4562.0, 4626.0, 4690.0, 4690.0, 4690.0, 4768.888888888889, 4847.777777777777, 4926.666666666667, 
5005.555555555556, 5084.444444444444, 5163.333333333333, 5242.222222222223, 5321.111111111111, 5400.0]} +{"path": "common_voice_en_19644798.wav", "transcription": "the book incorporates hours of interviews with fox and many former tropicana employees", "end_timestamps": [470.0, 635.0, 800.0, 800.0, 800.0, 896.6666666666666, 993.3333333333334, 1090.0, 1090.0, 1090.0, 1178.1818181818182, 1266.3636363636365, 1354.5454545454545, 1442.7272727272727, 1530.909090909091, 1619.090909090909, 1707.2727272727273, 1795.4545454545455, 1883.6363636363637, 1971.818181818182, 2060.0, 2060.0, 2090.0, 2207.5, 2325.0, 2442.5, 2560.0, 2560.0, 2560.0, 2700.0, 2700.0, 2700.0, 2787.777777777778, 2875.5555555555557, 2963.3333333333335, 3051.1111111111113, 3138.8888888888887, 3226.6666666666665, 3314.4444444444443, 3402.222222222222, 3490.0, 3490.0, 3490.0, 3566.6666666666665, 3643.3333333333335, 3720.0, 3720.0, 3720.0, 3980.0, 4240.0, 4240.0, 4240.0, 4370.0, 4500.0, 4500.0, 4500.0, 4616.666666666667, 4733.333333333333, 4850.0, 4850.0, 4850.0, 4946.0, 5042.0, 5138.0, 5234.0, 5330.0, 5330.0, 5330.0, 5426.25, 5522.5, 5618.75, 5715.0, 5811.25, 5907.5, 6003.75, 6100.0, 6100.0, 6100.0, 6212.5, 6325.0, 6437.5, 6550.0, 6662.5, 6775.0, 6887.5, 7000.0]} +{"path": "common_voice_en_19768928.wav", "transcription": "it also continued to broadcast fox kids programming on weekend mornings", "end_timestamps": [720.0, 900.0, 900.0, 900.0, 986.6666666666666, 1073.3333333333333, 1160.0, 1160.0, 1160.0, 1227.5, 1295.0, 1362.5, 1430.0, 1497.5, 1565.0, 1632.5, 1700.0, 1700.0, 1700.0, 1850.0, 1850.0, 1850.0, 1930.0, 2010.0, 2090.0, 2170.0, 2250.0, 2330.0, 2410.0, 2490.0, 2490.0, 2490.0, 2690.0, 2890.0, 2890.0, 2890.0, 3010.0, 3130.0, 3250.0, 3250.0, 3250.0, 3324.0, 3398.0, 3472.0, 3546.0, 3620.0, 3694.0, 3768.0, 3842.0, 3916.0, 3990.0, 3990.0, 4290.0, 4480.0, 4480.0, 4480.0, 4541.666666666667, 4603.333333333333, 4665.0, 4726.666666666667, 4788.333333333333, 4850.0, 4850.0, 4850.0, 4942.857142857143, 5035.714285714285, 5128.571428571428, 5221.428571428572, 5314.285714285715, 5407.142857142857, 5500.0]} +{"path": "common_voice_en_19895325.wav", "transcription": "metra service and track ownership ends at fox lake", "end_timestamps": [310.0, 377.5, 445.0, 512.5, 580.0, 580.0, 580.0, 653.3333333333334, 726.6666666666666, 800.0, 873.3333333333333, 946.6666666666666, 1020.0, 1020.0, 1020.0, 1085.0, 1150.0, 1150.0, 1150.0, 1240.0, 1330.0, 1420.0, 1510.0, 1510.0, 1540.0, 1592.5, 1645.0, 1697.5, 1750.0, 1802.5, 1855.0, 1907.5, 1960.0, 1960.0, 2029.9999999999998, 2136.6666666666665, 2243.333333333333, 2350.0, 2350.0, 2350.0, 2460.0, 2460.0, 2520.0, 2670.0, 2820.0, 2820.0, 2820.0, 2916.6666666666665, 3013.3333333333335, 3110.0]} +{"path": "common_voice_en_19745716.wav", "transcription": "it had the highest ratings of any fox network television special to that date", "end_timestamps": [1100.0, 1250.0, 1250.0, 1250.0, 1390.0, 1530.0, 1530.0, 1530.0, 1570.0, 1610.0, 1610.0, 1610.0, 1691.6666666666667, 1773.3333333333333, 1855.0, 1936.6666666666667, 2018.3333333333335, 2100.0, 2100.0, 2100.0, 2175.0, 2250.0, 2325.0, 2400.0, 2475.0, 2550.0, 2550.0, 2550.0, 2700.0, 2700.0, 2700.0, 2855.0, 3010.0, 3010.0, 3040.0, 3215.0, 3390.0, 3390.0, 3390.0, 3465.0, 3540.0, 3615.0, 3690.0, 3765.0, 3840.0, 3840.0, 3840.0, 3913.3333333333335, 3986.6666666666665, 4060.0, 4133.333333333333, 4206.666666666667, 4280.0, 4353.333333333333, 4426.666666666667, 4500.0, 4500.0, 4500.0, 4586.666666666667, 4673.333333333333, 4760.0, 4846.666666666667, 4933.333333333333, 5020.0, 
5020.0, 5020.0, 5220.0, 5220.0, 5220.0, 5310.0, 5400.0, 5490.0, 5490.0, 5540.0, 5676.666666666667, 5813.333333333333, 5950.0]} \ No newline at end of file diff --git a/test/test_data/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav b/test/test_data/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav new file mode 100644 index 00000000..3b96ec3d Binary files /dev/null and b/test/test_data/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav differ diff --git a/test/test_data/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav b/test/test_data/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav new file mode 100644 index 00000000..b4164a68 Binary files /dev/null and b/test/test_data/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav differ diff --git a/test/test_data/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav b/test/test_data/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav new file mode 100644 index 00000000..46425c78 Binary files /dev/null and b/test/test_data/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav differ diff --git a/test/test_data/audio/common_voice_en_18673330.wav b/test/test_data/audio/common_voice_en_18673330.wav new file mode 100644 index 00000000..1b88c5ee Binary files /dev/null and b/test/test_data/audio/common_voice_en_18673330.wav differ diff --git a/test/test_data/audio/common_voice_en_19632093.wav b/test/test_data/audio/common_voice_en_19632093.wav new file mode 100644 index 00000000..ca9ef3cb Binary files /dev/null and b/test/test_data/audio/common_voice_en_19632093.wav differ diff --git a/test/test_data/audio/common_voice_en_19644798.wav b/test/test_data/audio/common_voice_en_19644798.wav new file mode 100644 index 00000000..b13c7684 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19644798.wav differ diff --git a/test/test_data/audio/common_voice_en_19649226.wav b/test/test_data/audio/common_voice_en_19649226.wav new file mode 100644 index 00000000..afcb27be Binary files /dev/null and b/test/test_data/audio/common_voice_en_19649226.wav differ diff --git a/test/test_data/audio/common_voice_en_19687522.wav b/test/test_data/audio/common_voice_en_19687522.wav new file mode 100644 index 00000000..0e4c50b8 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19687522.wav differ diff --git a/test/test_data/audio/common_voice_en_19725016.wav b/test/test_data/audio/common_voice_en_19725016.wav new file mode 100644 index 00000000..6edf5069 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19725016.wav differ diff --git a/test/test_data/audio/common_voice_en_19745716.wav b/test/test_data/audio/common_voice_en_19745716.wav new file mode 100644 index 00000000..a5085e80 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19745716.wav differ diff --git a/test/test_data/audio/common_voice_en_19768928.wav b/test/test_data/audio/common_voice_en_19768928.wav new file mode 100644 index 00000000..e94a7b7c Binary files /dev/null and b/test/test_data/audio/common_voice_en_19768928.wav differ diff --git a/test/test_data/audio/common_voice_en_19895325.wav b/test/test_data/audio/common_voice_en_19895325.wav new file mode 100644 index 00000000..054e7b29 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19895325.wav differ diff --git a/test/test_data/audio/common_voice_en_19946691.wav b/test/test_data/audio/common_voice_en_19946691.wav new file mode 100644 index 00000000..be86bf3f Binary files /dev/null and b/test/test_data/audio/common_voice_en_19946691.wav differ diff --git a/test/test_data/audio/common_voice_en_19963788.wav b/test/test_data/audio/common_voice_en_19963788.wav new file mode 
100644 index 00000000..8bead452 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19963788.wav differ diff --git a/test/test_data/audio/common_voice_en_19964009.wav b/test/test_data/audio/common_voice_en_19964009.wav new file mode 100644 index 00000000..4b059ee0 Binary files /dev/null and b/test/test_data/audio/common_voice_en_19964009.wav differ diff --git a/train_model.sh b/train_model.sh index adcbc278..3e6482ca 100755 --- a/train_model.sh +++ b/train_model.sh @@ -15,15 +15,12 @@ echo "MODEL_TYPE: ${MODEL_TYPE}" echo "WORKSPACE_PATH: ${WORKSPACE_PATH}" echo "DATASET_PATHS: ${@:4}" -DATASET_ARGUMENT="" +DATASET_ARGUMENT="--dataset-paths" for DATASET_PATH in ${@:4}; do - DATASET_ARGUMENT+=" --dataset-paths ${DATASET_PATH}" + DATASET_ARGUMENT+=" ${DATASET_PATH}" done source ${ENV_FILE_PATH} echo ">>> training a model for ${VOCAB}; model will be stored at ${WORKSPACE_PATH}" time python -m training.run.train --model ${MODEL_TYPE} --workspace "${WORKSPACE_PATH}" ${DATASET_ARGUMENT} - -echo ">>> evaluating the trained model" -time python -m training.run.train --eval --model ${MODEL_TYPE} --workspace "${WORKSPACE_PATH}" ${DATASET_ARGUMENT} diff --git a/training/run/demo.py b/training/run/demo.py index 8e48e910..502084ae 100644 --- a/training/run/demo.py +++ b/training/run/demo.py @@ -4,7 +4,6 @@ import numpy as np import pyaudio import torch - from howl.context import InferenceContext from howl.data.transform import ZmuvTransform from howl.model import RegisteredModel, Workspace @@ -42,6 +41,7 @@ def __init__(self, self._audio_buf = [] self.device = device self.stream = stream + print(f"ready to detect {' '.join(self.words[x] for x in self.engine.sequence)}") stream.start_stream() def join(self): @@ -52,7 +52,7 @@ def _on_audio(self, in_data, frame_count, time_info, status): data_ok = (in_data, pyaudio.paContinue) self.last_data = in_data self._audio_buf.append(in_data) - if len(self._audio_buf) != (self.engine.sample_rate / self.chunk_size): # 1 sec window + if len(self._audio_buf) != (self.engine.sample_rate / self.chunk_size): # 1 sec window return data_ok audio_data = b''.join(self._audio_buf) self._audio_buf = self._audio_buf[2:] @@ -101,5 +101,6 @@ def main(): client = InferenceClient(engine, device, SETTINGS.training.vocab) client.join() + if __name__ == '__main__': main() diff --git a/training/run/stitch_vocab_samples.py b/training/run/stitch_vocab_samples.py new file mode 100644 index 00000000..ed28c9d0 --- /dev/null +++ b/training/run/stitch_vocab_samples.py @@ -0,0 +1,68 @@ +import argparse +from itertools import chain +from pathlib import Path + +from howl.data.dataset import (AudioClipDatasetLoader, + AudioDatasetMetadataWriter, AudioDatasetWriter, + WakeWordDatasetLoader, WordFrameLabeler) +from howl.data.dataset.base import AudioClipMetadata +from howl.data.searcher import WordTranscriptSearcher +from howl.data.stitcher import WordStitcher +from howl.data.tokenize import Vocab +from howl.settings import SETTINGS +from textgrids import TextGrid +from tqdm import tqdm +from training.align import MfaTextGridConverter, StubAligner + +"""Using aligned dataset, generate wakeword samples by stitching vocab samples + +VOCAB="[vocab1,vocab2,vocab3]" INFERENCE_SEQUENCE=[1,2,3] \ + python -m training.run.stitch_vocab_samples \ + --aligned-dataset "aligned-dataset" \ + --stitched-dataset "stitched-dataset" \ + --num-stitched-samples 500 \ + --stitched-dataset-pct [0.5, 0.25, 0.25] +""" + + +def main(): + parser = argparse.ArgumentParser() + 
parser.add_argument('--stitched-dataset', type=str, default='', + help='if provided, stitched wakeword samples are saved to the given path. (default: dataset-path/stitched)') + parser.add_argument('--aligned-dataset', type=str, + help='aligned dataset with frame-labelled samples from which stitched wakeword samples are generated') + parser.add_argument('--num-stitched-samples', type=int, default=10000, + help='number of stitched wakeword samples to generate (default: 10000)') + parser.add_argument('--stitched-dataset-pct', type=float, nargs=3, default=[0.5, 0.25, 0.25], + help='train/dev/test pct for stitched dataset (default: [0.5, 0.25, 0.25])') + + args = parser.parse_args() + aligned_ds_path = Path(args.aligned_dataset) + stitched_ds_path = aligned_ds_path / 'stitched' if args.stitched_dataset == '' else Path(args.stitched_dataset) + stitched_ds_path.mkdir(exist_ok=True) + + vocab = Vocab(SETTINGS.training.vocab) + labeler = WordFrameLabeler(vocab) + ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=labeler) + + # load aligned datasets + train_ds, dev_ds, test_ds = WakeWordDatasetLoader().load_splits(aligned_ds_path, **ds_kwargs) + + # stitch vocab samples + stitcher = WordStitcher(vocab=vocab) + stitcher.stitch(args.num_stitched_samples, stitched_ds_path, train_ds, dev_ds, test_ds) + + # split the stitched samples + stitched_train_ds, stitched_dev_ds, stitched_test_ds = stitcher.load_splits(*args.stitched_dataset_pct) + + # save metadata + for ds in stitched_train_ds, stitched_dev_ds, stitched_test_ds: + try: + AudioDatasetWriter(ds, prefix='aligned-').write(stitched_ds_path) + except KeyboardInterrupt: + print('Skipping...') + pass + + +if __name__ == '__main__': + main() diff --git a/training/run/train.py b/training/run/train.py index 3c9a4c08..ce7037d4 100644 --- a/training/run/train.py +++ b/training/run/train.py @@ -128,12 +128,14 @@ def do_evaluate(): print_stats(f'Wake word dataset', ctx, ww_train_ds, ww_dev_ds, ww_test_ds) ww_dev_pos_ds = ww_dev_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True) + print_stats('Dev pos dataset', ctx, ww_dev_pos_ds) ww_dev_neg_ds = ww_dev_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True) + print_stats('Dev neg dataset', ctx, ww_dev_neg_ds) ww_test_pos_ds = ww_test_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True) + print_stats('Test pos dataset', ctx, ww_test_pos_ds) ww_test_neg_ds = ww_test_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True) + print_stats('Test neg dataset', ctx, ww_test_neg_ds) - print_stats('Dev dataset', ctx, ww_dev_pos_ds, ww_dev_neg_ds) - print_stats('Test dataset', ctx, ww_test_pos_ds, ww_test_neg_ds) device = torch.device(SETTINGS.training.device) std_transform = StandardAudioTransform().to(device).eval() zmuv_transform = ZmuvTransform().to(device)
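
Usage sketch (not part of the patch): the stitching flow added above can also be driven directly from Python, mirroring `training/run/stitch_vocab_samples.py`. This is a minimal sketch assuming `howl` is importable, that `VOCAB` and `INFERENCE_SEQUENCE` are exported so `SETTINGS` picks them up, and that `data/fire-positive` is a hypothetical MFA-aligned dataset produced by the earlier README steps.

```python
# Minimal sketch of the stitching workflow introduced in this diff.
# Assumed inputs: an MFA-aligned dataset at data/fire-positive and env vars
#   VOCAB='["fire"]' INFERENCE_SEQUENCE=[0]
from pathlib import Path

from howl.data.dataset import (AudioDatasetWriter, WakeWordDatasetLoader,
                               WordFrameLabeler)
from howl.data.stitcher import WordStitcher
from howl.data.tokenize import Vocab
from howl.settings import SETTINGS

aligned_path = Path("data/fire-positive")      # hypothetical aligned dataset path
stitched_path = Path("data/fire-stitched")     # hypothetical output path
stitched_path.mkdir(exist_ok=True)

# Vocab now accepts a plain word list (howl/data/tokenize.py hunk), and
# WordFrameLabeler takes the Vocab itself (howl/data/dataset/labeller.py hunk).
vocab = Vocab(SETTINGS.training.vocab)
labeler = WordFrameLabeler(vocab)
ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=labeler)

# Load the aligned splits; their FrameLabelData.char_indices drive the stitching.
train_ds, dev_ds, test_ds = WakeWordDatasetLoader().load_splits(aligned_path, **ds_kwargs)

# Cut out per-word clips and stitch them into wakeword samples.
stitcher = WordStitcher(vocab=vocab)
stitcher.stitch(100, stitched_path, train_ds, dev_ds, test_ds)

# Split 50/25/25 and write audio plus metadata, as the new training/run script does.
for ds in stitcher.load_splits(0.5, 0.25, 0.25):
    AudioDatasetWriter(ds, prefix='aligned-').write(stitched_path)
```

The resulting folder can then be passed to training alongside the positive and negative sets, e.g. `data/fire-stitched` in the updated README step 2.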