Commit be80330

Filter out bad stitched samples by checking the speech to text transcription (#72)

* sphinx keyword detector

* support audio sample verification with keyword detection

* fix incorrect name
ljj7975 authored Apr 9, 2021
1 parent 333497c commit be80330
Showing 28 changed files with 193 additions and 60 deletions.
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
@@ -14,10 +14,11 @@ repos:
      - id: trailing-whitespace
      - id: check-json
      - id: name-tests-test
-  - repo: https://github.com/timothycrosley/isort
-    rev: 5.2.0
+  - repo: https://github.com/pycqa/isort
+    rev: 5.5.4
    hooks:
      - id: isort
+        args: ["--profile", "black"]
  - repo: https://github.com/psf/black
    rev: 19.10b0
    hooks:
@@ -27,6 +28,7 @@ repos:
    rev: 3.8.3
    hooks:
      - id: flake8
+        args: [--max-line-length=120]
  - repo: https://github.com/pre-commit/mirrors-pylint
    rev: v2.6.0
    hooks:
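With this change, isort runs under the black profile and flake8 accepts lines up to 120 characters. To apply the hooks across the whole tree locally, the usual invocation (assuming the pre-commit tool is installed) is `pre-commit run --all-files`.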
108 changes: 80 additions & 28 deletions howl/data/stitcher.py
@@ -1,18 +1,24 @@
import itertools
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import soundfile
import torch
-from howl.data.dataset import (AudioClipDataset, AudioClipExample,
-                               AudioClipMetadata, AudioDataset, DatasetType)
+from tqdm import tqdm
+
+from howl.data.dataset import (
+    AudioClipDataset,
+    AudioClipExample,
+    AudioClipMetadata,
+    AudioDataset,
+    DatasetType,
+)
from howl.data.tokenize import Vocab
from howl.settings import SETTINGS
-from tqdm import tqdm
+from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector

-__all__ = ['WordStitcher']
+__all__ = ["WordStitcher"]


@dataclass
@@ -24,18 +30,27 @@ class FrameLabelledSample:


class Stitcher:
-    def __init__(self,
-                 vocab: Vocab):
+    def __init__(self, vocab: Vocab, detect_keyword: bool = True):
+        """Base Stitcher class
+        Args:
+            vocab (Vocab): vocab containing wakeword
+            detect_keyword (bool, optional): drop invalid stitched samples through secondary keyword detection step
+        """
        self.sequence = SETTINGS.inference_engine.inference_sequence
        self.sr = SETTINGS.audio.sample_rate
        self.vocab = vocab
-        self.wakeword = ' '.join(self.vocab[x]
-                                 for x in self.sequence)
+        self.wakeword = " ".join(self.vocab[x] for x in self.sequence)
+
+        self.detect_keyword = detect_keyword
+        self.keyword_detector = []
+        if self.detect_keyword:
+            for x in self.sequence:
+                self.keyword_detector.append(SphinxKeywordDetector(self.vocab[x]))


class WordStitcher(Stitcher):
-    def __init__(self,
-                 **kwargs):
+    def __init__(self, **kwargs):
        super().__init__(**kwargs)

def concatenate_end_timestamps(self, end_timestamps_list: List[List[float]]) -> List[float]:
@@ -60,7 +75,7 @@ def concatenate_end_timestamps(end_timestamps_list: List[List[float]]) ->

return concatnated_timestamps[:-1] # discard last space timestamp

-    def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, * datasets: AudioDataset):
+    def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, *datasets: AudioDataset):
"""collect vocab samples from datasets and generate stitched wakeword samples
Args:
@@ -85,8 +100,14 @@ def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, *datasets: AudioDataset):
for char_idx in char_indices:
adjusted_end_timestamps.append(sample.metadata.end_timestamps[char_idx] - start_timestamp)

-            sample_set[label].append(FrameLabelledSample(
-                sample.audio_data[audio_start_idx:audio_end_idx], end_timestamp-start_timestamp, adjusted_end_timestamps, label))
+            sample_set[label].append(
+                FrameLabelledSample(
+                    sample.audio_data[audio_start_idx:audio_end_idx],
+                    end_timestamp - start_timestamp,
+                    adjusted_end_timestamps,
+                    label,
+                )
+            )

audio_dir = stitched_dataset_dir / "audio"
audio_dir.mkdir(exist_ok=True)
@@ -100,34 +121,63 @@ def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, *datasets: AudioDataset):

# generate AudioClipExample for each vocab sample
self.stitched_samples = []
-        for sample_idx in tqdm(range(num_stitched_samples), desc="Generating stitched samples"):
+        pbar = tqdm(total=num_stitched_samples, desc="Generating stitched samples")
+        sample_idx = 0
+        num_skipped_samples = 0
+        while True:
+            if sample_idx == num_stitched_samples:
+                break

sample_set = []
for sample_list in sample_lists:
sample_set.append(random.choice(sample_list))

+            audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set])

+            if self.detect_keyword:
+                temp_audio_file_path = "/tmp/temp.wav"
+                soundfile.write(temp_audio_file_path, audio_data.numpy(), self.sr)
+
+                keyword_exists = True
+                for detector in self.keyword_detector:
+                    # sphinx keyword detection may not be sufficient for audio with repeated words
+                    if len(detector.detect(temp_audio_file_path)) == 0:
+                        keyword_exists = False
+                        break
+
+                # drop this candidate when any keyword goes undetected
+                if not keyword_exists:
+                    num_skipped_samples += 1
+                    continue

            metatdata = AudioClipMetadata(
                path=Path(audio_dir / f"{sample_idx}").with_suffix(".wav"),
                transcription=self.wakeword,
                end_timestamps=self.concatenate_end_timestamps(
-                    [labelled_data.end_timestamps for labelled_data in sample_set])
+                    [labelled_data.end_timestamps for labelled_data in sample_set]
+                ),
            )

# TODO:: dataset writer load the samples upon write and does not use data in memory
# writer classes need to be refactored to use audio data if exist
-            audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set])
soundfile.write(metatdata.path, audio_data.numpy(), self.sr)

-            stitched_sample = AudioClipExample(
-                metadata=metatdata,
-                audio_data=audio_data,
-                sample_rate=self.sr)
+            stitched_sample = AudioClipExample(metadata=metatdata, audio_data=audio_data, sample_rate=self.sr)

self.stitched_samples.append(stitched_sample)

-    def load_splits(self,
-                    train_pct: float,
-                    dev_pct: float,
-                    test_pct: float) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]:
+            sample_idx += 1
+            pbar.update()
+
+        if self.detect_keyword:
+            print(
+                f"While generating {num_stitched_samples} stitched samples, "
+                f"{num_skipped_samples} were filtered out by keyword detection"
+            )
+
+    def load_splits(
+        self, train_pct: float, dev_pct: float, test_pct: float
+    ) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]:
"""split the generated stitched samples based on the given pct
first train_pct samples are used to generate train set
next dev_pct samples are used to generate dev set
@@ -161,6 +211,8 @@ def load_splits(self,
test_split.append(sample.metadata)

ds_kwargs = dict(sr=self.sr, mono=SETTINGS.audio.use_mono)
-        return (AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs),
-                AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs),
-                AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs))
+        return (
+            AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs),
+            AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs),
+            AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs),
+        )
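Taken together, the new flow stitches random vocab samples, writes each candidate to a temporary wav, and keeps it only if Sphinx detects every keyword in the inference sequence. For orientation, a minimal sketch of driving the updated stitcher end to end, modeled directly on the repository test shown further below (the vocab, dataset path, and sample count come from that test, not from anything this commit prescribes):

from pathlib import Path

from howl.data.dataset import WakeWordDatasetLoader, WordFrameLabeler
from howl.data.stitcher import WordStitcher
from howl.data.tokenize import Vocab
from howl.settings import SETTINGS

# Configure a three-word wakeword ("hey fire fox"), as in stitcher_test.py.
SETTINGS.training.token_type = "word"
SETTINGS.inference_engine.inference_sequence = [0, 1, 2]
vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr="<OOV>")

# Load a wakeword dataset whose samples will be cut up and stitched.
loader = WakeWordDatasetLoader()
ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=WordFrameLabeler(vocab))
dataset, _, _ = loader.load_splits(Path("test/test_data/stitcher"), **ds_kwargs)

# detect_keyword=True enables the new Sphinx-based filtering of bad samples.
stitched_dir = Path("test/test_data/stitcher/stitched")
stitched_dir.mkdir(exist_ok=True)
stitcher = WordStitcher(vocab=vocab, detect_keyword=True)
stitcher.stitch(20, stitched_dir, dataset)

# Split the surviving stitched samples 50/25/25.
train_ds, dev_ds, test_ds = stitcher.load_splits(0.5, 0.25, 0.25)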
35 changes: 35 additions & 0 deletions howl/utils/sphinx_keyword_detector.py
@@ -0,0 +1,35 @@
+import os
+
+from pocketsphinx import AudioFile
+
+
+class SphinxKeywordDetector():
+    def __init__(self, target_transcription, threshold=1e-20, verbose=False):
+        self.target_transcription = target_transcription
+        self.verbose = verbose
+        self.kws_config = {
+            'verbose': self.verbose,
+            'keyphrase': self.target_transcription,
+            'kws_threshold': threshold,
+            'lm': False,
+        }
+
+    def detect(self, file_name):
+
+        kws_results = []
+
+        self.kws_config['audio_file'] = file_name
+        audio = AudioFile(**self.kws_config)
+
+        for phrase in audio:
+            result = phrase.segments(detailed=True)
+
+            # TODO:: confirm that when multiple keywords are detected, every detection is valid
+            if len(result) == 1:
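+                # segment boundaries appear to be frame indices (pocketsphinx
+                # uses 100 frames/s), so multiplying by 10 yields milliseconds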
+                start_time = result[0][2] * 10
+                end_time = result[0][3] * 10
+                if self.verbose:
+                    print('%4sms ~ %4sms' % (start_time, end_time))
+                kws_results.append((start_time, end_time))
+
+        return kws_results
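A small sketch of exercising the detector on its own, mirroring the new unit test at the bottom of this diff (the wav fixture is one added by this commit; pocketsphinx==0.1.15 from requirements_training.txt must be installed):

from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector

# One detector per keyphrase; detect() returns a list of (start_ms, end_ms)
# tuples, so an empty list means the keyword was not heard in the clip.
fire_detector = SphinxKeywordDetector("fire")
hits = fire_detector.detect("test/test_data/sphinx_keyword_detector/hey_fire_fox.wav")
assert len(hits) > 0  # "fire" should be detected in "hey fire fox"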
5 changes: 3 additions & 2 deletions requirements_training.txt
@@ -1,5 +1,6 @@
openpyxl
+pocketsphinx==0.1.15
praat-textgrids==1.3.1
-webrtcvad==2.0.10
-pytest
pre-commit
+pytest
+webrtcvad==2.0.10
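The only substantive change here is the new pocketsphinx pin, which backs SphinxKeywordDetector above; the remaining lines are simply re-sorted alphabetically. A plain `pip install -r requirements_training.txt` picks the new dependency up.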
11 changes: 5 additions & 6 deletions test/data/stitcher_test.py
@@ -2,15 +2,13 @@
import unittest
from pathlib import Path

-import torch
from howl.data.dataset import WakeWordDatasetLoader, WordFrameLabeler
from howl.data.stitcher import WordStitcher
from howl.data.tokenize import Vocab
from howl.settings import SETTINGS


class TestStitcher(unittest.TestCase):
-
def test_compute_statistics(self):
random.seed(1)

@@ -20,25 +18,26 @@ def test_compute_statistics(self):
SETTINGS.training.token_type = "word"
SETTINGS.inference_engine.inference_sequence = [0, 1, 2]

-        vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr='<OOV>')
+        vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr="<OOV>")
labeler = WordFrameLabeler(vocab)

loader = WakeWordDatasetLoader()
ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=labeler)

-        test_dataset_path = Path("test/test_data")
+        test_dataset_path = Path("test/test_data/stitcher")
stitched_dataset_path = test_dataset_path / "stitched"
stitched_dataset_path.mkdir(exist_ok=True)

test_ds, _, _ = loader.load_splits(test_dataset_path, **ds_kwargs)
-        stitcher = WordStitcher(vocab=vocab)
+        stitcher = WordStitcher(vocab=vocab, detect_keyword=True)
stitcher.stitch(20, stitched_dataset_path, test_ds)

stitched_train_ds, stitched_dev_ds, stitched_test_ds = stitcher.load_splits(0.5, 0.25, 0.25)

self.assertEqual(len(stitched_train_ds), 10)
self.assertEqual(len(stitched_dev_ds), 5)
self.assertEqual(len(stitched_test_ds), 5)


-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
Binary file not shown.
Binary file not shown.
@@ -1 +1,2 @@
stitched
+stitched_dataset
File renamed without changes.
File renamed without changes.
28 changes: 28 additions & 0 deletions test/utils/sphinx_keyword_detector_test.py
@@ -0,0 +1,28 @@
+import unittest
+
+from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector
+
+
+class TestSphinxKeywordDetector(unittest.TestCase):
+
+    def test_detect(self):
+        """test word detection from an audio file
+        """
+
+        hello_world_file = "test/test_data/sphinx_keyword_detector/hello_world.wav"
+        hello_extractor = SphinxKeywordDetector("hello")
+        self.assertTrue(len(hello_extractor.detect(hello_world_file)) > 0)
+        world_extractor = SphinxKeywordDetector("world")
+        self.assertTrue(len(world_extractor.detect(hello_world_file)) > 0)
+
+        hey_fire_fox_file = "test/test_data/sphinx_keyword_detector/hey_fire_fox.wav"
+        hey_extractor = SphinxKeywordDetector("hey")
+        self.assertTrue(len(hey_extractor.detect(hey_fire_fox_file)) > 0)
+        fire_extractor = SphinxKeywordDetector("fire")
+        self.assertTrue(len(fire_extractor.detect(hey_fire_fox_file)) > 0)
+        fox_extractor = SphinxKeywordDetector("fox")
+        self.assertTrue(len(fox_extractor.detect(hey_fire_fox_file)) > 0)
+
+
+if __name__ == '__main__':
+    unittest.main()
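Both test modules are plain unittest cases, so they can be run directly or through pytest (already listed in requirements_training.txt), e.g. `pytest test/utils/sphinx_keyword_detector_test.py`; note the keyword-detector tests need pocketsphinx and the wav fixtures added in this commit.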