diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 289589b6..8f631409 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,10 +14,11 @@ repos: - id: trailing-whitespace - id: check-json - id: name-tests-test -- repo: https://github.com/timothycrosley/isort - rev: 5.2.0 +- repo: https://github.com/pycqa/isort + rev: 5.5.4 hooks: - id: isort + args: ["--profile", "black"] - repo: https://github.com/psf/black rev: 19.10b0 hooks: @@ -27,6 +28,7 @@ repos: rev: 3.8.3 hooks: - id: flake8 + args: [--max-line-length=120] - repo: https://github.com/pre-commit/mirrors-pylint rev: v2.6.0 hooks: diff --git a/howl/data/stitcher.py b/howl/data/stitcher.py index 875ead01..1df0e431 100644 --- a/howl/data/stitcher.py +++ b/howl/data/stitcher.py @@ -1,4 +1,3 @@ -import itertools import random from dataclasses import dataclass from pathlib import Path @@ -6,13 +5,20 @@ import soundfile import torch -from howl.data.dataset import (AudioClipDataset, AudioClipExample, - AudioClipMetadata, AudioDataset, DatasetType) +from tqdm import tqdm + +from howl.data.dataset import ( + AudioClipDataset, + AudioClipExample, + AudioClipMetadata, + AudioDataset, + DatasetType, +) from howl.data.tokenize import Vocab from howl.settings import SETTINGS -from tqdm import tqdm +from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector -__all__ = ['WordStitcher'] +__all__ = ["WordStitcher"] @dataclass @@ -24,18 +30,27 @@ class FrameLabelledSample: class Stitcher: - def __init__(self, - vocab: Vocab): + def __init__(self, vocab: Vocab, detect_keyword: bool = True): + """Base Stitcher class + + Args: + vocab (Vocab): vocab containing the wakeword + detect_keyword (bool, optional): if True, drop invalid stitched samples via a secondary keyword detection step + """ self.sequence = SETTINGS.inference_engine.inference_sequence self.sr = SETTINGS.audio.sample_rate self.vocab = vocab - self.wakeword = ' '.join(self.vocab[x] - for x in self.sequence) + self.wakeword = " ".join(self.vocab[x] for x in self.sequence) + + self.detect_keyword = detect_keyword + self.keyword_detector = [] + if self.detect_keyword: + for x in self.sequence: + self.keyword_detector.append(SphinxKeywordDetector(self.vocab[x])) class WordStitcher(Stitcher): - def __init__(self, - **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) def concatenate_end_timestamps(self, end_timestamps_list: List[List[float]]) -> List[float]: @@ -60,7 +75,7 @@ def concatenate_end_timestamps(self, end_timestamps_list: List[List[float]]) -> return concatnated_timestamps[:-1] # discard last space timestamp - def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, * datasets: AudioDataset): + def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, *datasets: AudioDataset): """collect vocab samples from datasets and generate stitched wakeword samples Args: @@ -85,8 +100,14 @@ def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, * datase for char_idx in char_indices: adjusted_end_timestamps.append(sample.metadata.end_timestamps[char_idx] - start_timestamp) - sample_set[label].append(FrameLabelledSample( - sample.audio_data[audio_start_idx:audio_end_idx], end_timestamp-start_timestamp, adjusted_end_timestamps, label)) + sample_set[label].append( + FrameLabelledSample( + sample.audio_data[audio_start_idx:audio_end_idx], + end_timestamp - start_timestamp, + adjusted_end_timestamps, + label, + ) + ) audio_dir = stitched_dataset_dir / "audio" audio_dir.mkdir(exist_ok=True) @@ -100,34 +121,63 
@@ def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, * datase # generate AudioClipExample for each vocab sample self.stitched_samples = [] - for sample_idx in tqdm(range(num_stitched_samples), desc="Generating stitched samples"): + + pbar = tqdm(total=num_stitched_samples, desc="Generating stitched samples") + sample_idx = 0 + num_skipped_samples = 0 + while True: + if sample_idx == num_stitched_samples: + break + sample_set = [] for sample_list in sample_lists: sample_set.append(random.choice(sample_list)) + audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set]) + + if self.detect_keyword: + temp_audio_file_path = "/tmp/temp.wav" + soundfile.write(temp_audio_file_path, audio_data.numpy(), self.sr) + + keyword_exists = True + for detector in self.keyword_detector: + # sphinx keyword detection may not be sufficient for audio with repeated words + if len(detector.detect(temp_audio_file_path)) == 0: + keyword_exists = False + break + + if not keyword_exists: + num_skipped_samples += 1 + continue + metatdata = AudioClipMetadata( path=Path(audio_dir / f"{sample_idx}").with_suffix(".wav"), transcription=self.wakeword, end_timestamps=self.concatenate_end_timestamps( - [labelled_data.end_timestamps for labelled_data in sample_set]) + [labelled_data.end_timestamps for labelled_data in sample_set] + ), ) # TODO:: dataset writer load the samples upon write and does not use data in memory # writer classes need to be refactored to use audio data if exist - audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set]) soundfile.write(metatdata.path, audio_data.numpy(), self.sr) - stitched_sample = AudioClipExample( - metadata=metatdata, - audio_data=audio_data, - sample_rate=self.sr) + stitched_sample = AudioClipExample(metadata=metatdata, audio_data=audio_data, sample_rate=self.sr) self.stitched_samples.append(stitched_sample) - def load_splits(self, - train_pct: float, - dev_pct: float, - test_pct: float) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]: + sample_idx += 1 + pbar.update() + + if self.detect_keyword: + print( + f"While generating {num_stitched_samples} stitched samples, " + f"{num_skipped_samples} were filtered out by keyword detection" + ) + + def load_splits( + self, train_pct: float, dev_pct: float, test_pct: float + ) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]: """split the generated stitched samples based on the given pct first train_pct samples are used to generate train set next dev_pct samples are used to generate dev set @@ -161,6 +211,8 @@ def load_splits(self, test_split.append(sample.metadata) ds_kwargs = dict(sr=self.sr, mono=SETTINGS.audio.use_mono) - return (AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs), - AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs), - AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs)) + return ( + AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs), + AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs), + AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs), + ) diff --git a/howl/utils/sphinx_keyword_detector.py b/howl/utils/sphinx_keyword_detector.py new file mode 100644 index 00000000..36fb505d --- /dev/null +++ b/howl/utils/sphinx_keyword_detector.py @@ -0,0 +1,35 @@ +import os + +from pocketsphinx import AudioFile + + +class SphinxKeywordDetector(): + 
def __init__(self, target_transcription, threshold=1e-20, verbose=False): + self.target_transcription = target_transcription + self.verbose = verbose + self.kws_config = { + 'verbose': self.verbose, + 'keyphrase': self.target_transcription, + 'kws_threshold': threshold, + 'lm': False, + } + + def detect(self, file_name): + + kws_results = [] + + self.kws_config['audio_file'] = file_name + audio = AudioFile(**self.kws_config) + + for phrase in audio: + result = phrase.segments(detailed=True) + + # TODO:: confirm that when multiple keywords are detected, every detection is valid + if len(result) == 1: + start_time = result[0][2] * 10 + end_time = result[0][3] * 10 + if self.verbose: + print('%4sms ~ %4sms' % (start_time, end_time)) + kws_results.append((start_time, end_time)) + + return kws_results diff --git a/requirements_training.txt b/requirements_training.txt index b79ce2e3..a15c2430 100644 --- a/requirements_training.txt +++ b/requirements_training.txt @@ -1,5 +1,6 @@ openpyxl +pocketsphinx==0.1.15 praat-textgrids==1.3.1 -webrtcvad==2.0.10 -pytest pre-commit +pytest +webrtcvad==2.0.10 diff --git a/test/data/stitcher_test.py b/test/data/stitcher_test.py index 2bf80272..c81e2fd1 100644 --- a/test/data/stitcher_test.py +++ b/test/data/stitcher_test.py @@ -2,7 +2,6 @@ import unittest from pathlib import Path -import torch from howl.data.dataset import WakeWordDatasetLoader, WordFrameLabeler from howl.data.stitcher import WordStitcher from howl.data.tokenize import Vocab @@ -10,7 +9,6 @@ class TestStitcher(unittest.TestCase): - def test_compute_statistics(self): random.seed(1) @@ -20,25 +18,26 @@ def test_compute_statistics(self): SETTINGS.training.token_type = "word" SETTINGS.inference_engine.inference_sequence = [0, 1, 2] - vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr='') + vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr="") labeler = WordFrameLabeler(vocab) loader = WakeWordDatasetLoader() ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=labeler) - test_dataset_path = Path("test/test_data") + test_dataset_path = Path("test/test_data/stitcher") stitched_dataset_path = test_dataset_path / "stitched" stitched_dataset_path.mkdir(exist_ok=True) test_ds, _, _ = loader.load_splits(test_dataset_path, **ds_kwargs) - stitcher = WordStitcher(vocab=vocab) + stitcher = WordStitcher(vocab=vocab, detect_keyword=True) stitcher.stitch(20, stitched_dataset_path, test_ds) stitched_train_ds, stitched_dev_ds, stitched_test_ds = stitcher.load_splits(0.5, 0.25, 0.25) + self.assertEqual(len(stitched_train_ds), 10) self.assertEqual(len(stitched_dev_ds), 5) self.assertEqual(len(stitched_test_ds), 5) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_data/sphinx_keyword_detector/hello_world.wav b/test/test_data/sphinx_keyword_detector/hello_world.wav new file mode 100644 index 00000000..8c771e18 Binary files /dev/null and b/test/test_data/sphinx_keyword_detector/hello_world.wav differ diff --git a/test/test_data/sphinx_keyword_detector/hey_fire_fox.wav b/test/test_data/sphinx_keyword_detector/hey_fire_fox.wav new file mode 100644 index 00000000..7707c584 Binary files /dev/null and b/test/test_data/sphinx_keyword_detector/hey_fire_fox.wav differ diff --git a/test/test_data/.gitignore b/test/test_data/stitcher/.gitignore similarity index 65% rename from test/test_data/.gitignore rename to test/test_data/stitcher/.gitignore index d40b5a73..698bd0a4 100644 --- 
a/test/test_data/.gitignore +++ b/test/test_data/stitcher/.gitignore @@ -1 +1,2 @@ +stitched stitched_dataset diff --git a/test/test_data/aligned-metadata-dev.jsonl b/test/test_data/stitcher/aligned-metadata-dev.jsonl similarity index 100% rename from test/test_data/aligned-metadata-dev.jsonl rename to test/test_data/stitcher/aligned-metadata-dev.jsonl diff --git a/test/test_data/aligned-metadata-test.jsonl b/test/test_data/stitcher/aligned-metadata-test.jsonl similarity index 100% rename from test/test_data/aligned-metadata-test.jsonl rename to test/test_data/stitcher/aligned-metadata-test.jsonl diff --git a/test/test_data/aligned-metadata-training.jsonl b/test/test_data/stitcher/aligned-metadata-training.jsonl similarity index 100% rename from test/test_data/aligned-metadata-training.jsonl rename to test/test_data/stitcher/aligned-metadata-training.jsonl diff --git a/test/test_data/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav b/test/test_data/stitcher/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav similarity index 100% rename from test/test_data/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav rename to test/test_data/stitcher/audio/35BLDD71I86ON4I2Y4DHILXM0O5VZG.wav diff --git a/test/test_data/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav b/test/test_data/stitcher/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav similarity index 100% rename from test/test_data/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav rename to test/test_data/stitcher/audio/36U2A8VAG38A23EL3EXR25D3EWVKYT.wav diff --git a/test/test_data/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav b/test/test_data/stitcher/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav similarity index 100% rename from test/test_data/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav rename to test/test_data/stitcher/audio/3MH9DQ757YLOJ8YX3MFN1SFXBYWGUV.wav diff --git a/test/test_data/audio/common_voice_en_18673330.wav b/test/test_data/stitcher/audio/common_voice_en_18673330.wav similarity index 100% rename from test/test_data/audio/common_voice_en_18673330.wav rename to test/test_data/stitcher/audio/common_voice_en_18673330.wav diff --git a/test/test_data/audio/common_voice_en_19632093.wav b/test/test_data/stitcher/audio/common_voice_en_19632093.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19632093.wav rename to test/test_data/stitcher/audio/common_voice_en_19632093.wav diff --git a/test/test_data/audio/common_voice_en_19644798.wav b/test/test_data/stitcher/audio/common_voice_en_19644798.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19644798.wav rename to test/test_data/stitcher/audio/common_voice_en_19644798.wav diff --git a/test/test_data/audio/common_voice_en_19649226.wav b/test/test_data/stitcher/audio/common_voice_en_19649226.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19649226.wav rename to test/test_data/stitcher/audio/common_voice_en_19649226.wav diff --git a/test/test_data/audio/common_voice_en_19687522.wav b/test/test_data/stitcher/audio/common_voice_en_19687522.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19687522.wav rename to test/test_data/stitcher/audio/common_voice_en_19687522.wav diff --git a/test/test_data/audio/common_voice_en_19725016.wav b/test/test_data/stitcher/audio/common_voice_en_19725016.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19725016.wav rename to test/test_data/stitcher/audio/common_voice_en_19725016.wav diff --git a/test/test_data/audio/common_voice_en_19745716.wav 
b/test/test_data/stitcher/audio/common_voice_en_19745716.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19745716.wav rename to test/test_data/stitcher/audio/common_voice_en_19745716.wav diff --git a/test/test_data/audio/common_voice_en_19768928.wav b/test/test_data/stitcher/audio/common_voice_en_19768928.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19768928.wav rename to test/test_data/stitcher/audio/common_voice_en_19768928.wav diff --git a/test/test_data/audio/common_voice_en_19895325.wav b/test/test_data/stitcher/audio/common_voice_en_19895325.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19895325.wav rename to test/test_data/stitcher/audio/common_voice_en_19895325.wav diff --git a/test/test_data/audio/common_voice_en_19946691.wav b/test/test_data/stitcher/audio/common_voice_en_19946691.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19946691.wav rename to test/test_data/stitcher/audio/common_voice_en_19946691.wav diff --git a/test/test_data/audio/common_voice_en_19963788.wav b/test/test_data/stitcher/audio/common_voice_en_19963788.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19963788.wav rename to test/test_data/stitcher/audio/common_voice_en_19963788.wav diff --git a/test/test_data/audio/common_voice_en_19964009.wav b/test/test_data/stitcher/audio/common_voice_en_19964009.wav similarity index 100% rename from test/test_data/audio/common_voice_en_19964009.wav rename to test/test_data/stitcher/audio/common_voice_en_19964009.wav diff --git a/test/utils/sphinx_keyword_detector_test.py b/test/utils/sphinx_keyword_detector_test.py new file mode 100644 index 00000000..8b3ce383 --- /dev/null +++ b/test/utils/sphinx_keyword_detector_test.py @@ -0,0 +1,28 @@ +import unittest + +from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector + + +class TestSphinxKeywordDetector(unittest.TestCase): + + def test_detect(self): + """test word detection from an audio file + """ + + hello_world_file = "test/test_data/sphinx_keyword_detector/hello_world.wav" + hello_extractor = SphinxKeywordDetector("hello") + self.assertTrue(len(hello_extractor.detect(hello_world_file)) > 0) + world_extractor = SphinxKeywordDetector("world") + self.assertTrue(len(world_extractor.detect(hello_world_file)) > 0) + + hey_fire_fox_file = "test/test_data/sphinx_keyword_detector/hey_fire_fox.wav" + hey_extractor = SphinxKeywordDetector("hey") + self.assertTrue(len(hey_extractor.detect(hey_fire_fox_file)) > 0) + fire_extractor = SphinxKeywordDetector("fire") + self.assertTrue(len(fire_extractor.detect(hey_fire_fox_file)) > 0) + fox_extractor = SphinxKeywordDetector("fox") + self.assertTrue(len(fox_extractor.detect(hey_fire_fox_file)) > 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/training/run/stitch_vocab_samples.py b/training/run/stitch_vocab_samples.py index ed28c9d0..d64ea1db 100644 --- a/training/run/stitch_vocab_samples.py +++ b/training/run/stitch_vocab_samples.py @@ -1,18 +1,14 @@ import argparse -from itertools import chain from pathlib import Path -from howl.data.dataset import (AudioClipDatasetLoader, - AudioDatasetMetadataWriter, AudioDatasetWriter, - WakeWordDatasetLoader, WordFrameLabeler) -from howl.data.dataset.base import AudioClipMetadata -from howl.data.searcher import WordTranscriptSearcher +from howl.data.dataset import ( + AudioDatasetWriter, + WakeWordDatasetLoader, + WordFrameLabeler, +) from howl.data.stitcher import WordStitcher from 
howl.data.tokenize import Vocab from howl.settings import SETTINGS -from textgrids import TextGrid -from tqdm import tqdm -from training.align import MfaTextGridConverter, StubAligner """Using aligned dataset, generate wakeword samples by stitching vocab samples @@ -27,18 +23,37 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument('--stitched-dataset', type=str, default='', - help='if provided, stitched wakeword samples are saved to the given path. (default: dataset-path/stitched)') - parser.add_argument('--aligned-dataset', type=str, - help='dataset with frame labelled samples which stitched wakeword samples are generated from') - parser.add_argument('--num-stitched-samples', type=int, default=10000, - help='number of stitched wakeword samples to geneate (default: 10000)') - parser.add_argument('--stitched-dataset-pct', type=int, nargs=3, default=[0.5, 0.25, 0.25], - help='train/dev/test pct for stitched dataset (default: [0.5, 0.25, 0.25])') + parser.add_argument( + "--stitched-dataset", + type=str, + default="", + help="if provided, stitched wakeword samples are saved to the given path. (default: dataset-path/stitched)", + ) + parser.add_argument( + "--aligned-dataset", + type=str, + help="dataset with frame-labelled samples from which stitched wakeword samples are generated", + ) + parser.add_argument( + "--num-stitched-samples", + type=int, + default=10000, + help="number of stitched wakeword samples to generate (default: 10000)", + ) + parser.add_argument( + "--stitched-dataset-pct", + type=float, + nargs=3, + default=[0.5, 0.25, 0.25], + help="train/dev/test pct for stitched dataset (default: [0.5, 0.25, 0.25])", + ) + parser.add_argument( + "--disable-detect-keyword", action="store_false", help="disable keyword detection based verification" + ) args = parser.parse_args() aligned_ds_path = Path(args.aligned_dataset) - stitched_ds_path = aligned_ds_path / 'stitched' if args.stitched_dataset == '' else Path(args.stitched_dataset) + stitched_ds_path = aligned_ds_path / "stitched" if args.stitched_dataset == "" else Path(args.stitched_dataset) stitched_ds_path.mkdir(exist_ok=True) vocab = Vocab(SETTINGS.training.vocab) @@ -49,7 +64,7 @@ def main(): train_ds, dev_ds, test_ds = WakeWordDatasetLoader().load_splits(aligned_ds_path, **ds_kwargs) # stitch vocab samples - stitcher = WordStitcher(vocab=vocab) + stitcher = WordStitcher(vocab=vocab, detect_keyword=args.disable_detect_keyword) stitcher.stitch(args.num_stitched_samples, stitched_ds_path, train_ds, dev_ds, test_ds) # split the stitched samples @@ -58,11 +73,11 @@ def main(): # save metadata for ds in stitched_train_ds, stitched_dev_ds, stitched_test_ds: try: - AudioDatasetWriter(ds, prefix='aligned-').write(stitched_ds_path) + AudioDatasetWriter(ds, prefix="aligned-").write(stitched_ds_path) except KeyboardInterrupt: - print('Skipping...') + print("Skipping...") pass -if __name__ == '__main__': +if __name__ == "__main__": main()
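
A minimal usage sketch of the new howl/utils/sphinx_keyword_detector.py module, mirroring the verification loop this patch adds to WordStitcher.stitch. The vocab words and the wav path are taken from the test fixtures added in this diff, and pocketsphinx==0.1.15 (added to requirements_training.txt) is assumed to be installed:

from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector

# One detector per vocab word in the inference sequence, as in Stitcher.__init__.
detectors = [SphinxKeywordDetector(word) for word in ["hey", "fire", "fox"]]

# detect() returns a list of (start_ms, end_ms) tuples; an empty list means the
# word was not found, so a stitched clip like this should be dropped, not saved.
clip_path = "test/test_data/sphinx_keyword_detector/hey_fire_fox.wav"
keyword_exists = all(len(detector.detect(clip_path)) > 0 for detector in detectors)
print("keep" if keyword_exists else "drop", clip_path)

In training/run/stitch_vocab_samples.py the verification step is on by default; passing --disable-detect-keyword (an argparse store_false flag, so the stored value becomes False) turns it off.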