Skip to content

Commit

Permalink
Merge pull request #739 from ftnext/test-whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
ftnext authored Mar 30, 2024
2 parents 4924857 + 8da6e42 commit 9194a47
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install 'pocketsphinx<5'
python -m pip install git+https://github.com/openai/whisper.git soundfile
python -m pip install openai-whisper soundfile
python -m pip install openai
python -m pip install .
- name: Test with unittest
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Whisper (for Whisper users)
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Whisper is **required if and only if you want to use whisper** (``recognizer_instance.recognize_whisper``).

You can install it with ``python3 -m pip install git+https://github.com/openai/whisper.git soundfile``.
You can install it with ``python3 -m pip install openai-whisper soundfile``.

Whisper API (for Whisper API users)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
16 changes: 0 additions & 16 deletions tests/test_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def setUp(self):
self.AUDIO_FILE_EN = os.path.join(os.path.dirname(os.path.realpath(__file__)), "english.wav")
self.AUDIO_FILE_FR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "french.aiff")
self.AUDIO_FILE_ZH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinese.flac")
self.WHISPER_CONFIG = {"temperature": 0}

def test_recognizer_attributes(self):
r = sr.Recognizer()
Expand Down Expand Up @@ -81,21 +80,6 @@ def test_ibm_chinese(self):
with sr.AudioFile(self.AUDIO_FILE_ZH) as source: audio = r.record(source)
self.assertEqual(r.recognize_ibm(audio, username=os.environ["IBM_USERNAME"], password=os.environ["IBM_PASSWORD"], language="zh-CN"), u"砸 自己 的 脚 ")

def test_whisper_english(self):
    # Record the English sample file, run it through recognize_whisper with
    # language="english" plus the shared WHISPER_CONFIG keyword arguments,
    # and compare the returned text with the expected transcript.
    recognizer = sr.Recognizer()
    with sr.AudioFile(self.AUDIO_FILE_EN) as audio_file:
        recorded = recognizer.record(audio_file)
    transcript = recognizer.recognize_whisper(
        recorded, language="english", **self.WHISPER_CONFIG
    )
    self.assertEqual(transcript, " 1, 2, 3")

def test_whisper_french(self):
    # Record the French sample file, run it through recognize_whisper with
    # language="french" plus the shared WHISPER_CONFIG keyword arguments,
    # and compare the returned text with the expected transcript.
    recognizer = sr.Recognizer()
    with sr.AudioFile(self.AUDIO_FILE_FR) as audio_file:
        recorded = recognizer.record(audio_file)
    transcript = recognizer.recognize_whisper(
        recorded, language="french", **self.WHISPER_CONFIG
    )
    self.assertEqual(transcript, " et c'est la dictée numéro 1.")

def test_whisper_chinese(self):
    # Record the Chinese sample file and run it through recognize_whisper
    # using the "small" model with language="chinese" plus the shared
    # WHISPER_CONFIG keyword arguments, then compare with the expected text.
    recognizer = sr.Recognizer()
    with sr.AudioFile(self.AUDIO_FILE_ZH) as audio_file:
        recorded = recognizer.record(audio_file)
    transcript = recognizer.recognize_whisper(
        recorded, model="small", language="chinese", **self.WHISPER_CONFIG
    )
    self.assertEqual(transcript, u"砸自己的腳")


# Run this module's tests via unittest when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
78 changes: 78 additions & 0 deletions tests/test_whisper_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from unittest import TestCase
from unittest.mock import MagicMock, patch

import numpy as np

from speech_recognition import AudioData, Recognizer


@patch("speech_recognition.io.BytesIO")
@patch("soundfile.read")
@patch("torch.cuda.is_available")
@patch("whisper.load_model")
class RecognizeWhisperTestCase(TestCase):
    """Tests for ``Recognizer.recognize_whisper`` with its collaborators mocked.

    The class-level patches replace ``whisper.load_model``,
    ``torch.cuda.is_available``, ``soundfile.read`` and
    ``speech_recognition.io.BytesIO``. Every test method receives the
    resulting mocks in that order (the decorator closest to the class
    supplies the first argument).
    """

    @staticmethod
    def _stub_decoded_audio(sf_read):
        # Make the mocked soundfile.read hand back an (array, sampling rate)
        # pair and return the array mock so tests can assert on it.
        # The name deliberately does not start with "test", so the
        # class-level patch decorators leave this helper unwrapped.
        samples = MagicMock()
        sampling_rate = 99_999
        sf_read.return_value = (samples, sampling_rate)
        return samples

    def test_default_parameters(
        self, load_model, is_available, sf_read, BytesIO
    ):
        # With no optional arguments, the "base" model is loaded and the
        # "text" entry of the transcription result is returned.
        samples = self._stub_decoded_audio(sf_read)
        model = load_model.return_value
        result = model.transcribe.return_value
        audio = MagicMock(spec=AudioData)

        actual = Recognizer().recognize_whisper(audio)

        self.assertEqual(actual, result.__getitem__.return_value)
        load_model.assert_called_once_with("base")
        # The audio is fetched as WAV data at 16000 Hz, wrapped in BytesIO,
        # decoded by soundfile.read and cast to float32 before transcription.
        audio.get_wav_data.assert_called_once_with(convert_rate=16000)
        BytesIO.assert_called_once_with(audio.get_wav_data.return_value)
        sf_read.assert_called_once_with(BytesIO.return_value)
        samples.astype.assert_called_once_with(np.float32)
        model.transcribe.assert_called_once_with(
            samples.astype.return_value,
            language=None,
            task=None,
            fp16=is_available.return_value,
        )
        result.__getitem__.assert_called_once_with("text")

    def test_return_as_dict(self, load_model, is_available, sf_read, BytesIO):
        # show_dict=True returns the transcription result object itself
        # instead of only its "text" entry.
        self._stub_decoded_audio(sf_read)
        model = load_model.return_value
        audio = MagicMock(spec=AudioData)

        actual = Recognizer().recognize_whisper(audio, show_dict=True)

        self.assertEqual(actual, model.transcribe.return_value)

    def test_pass_parameters(self, load_model, is_available, sf_read, BytesIO):
        # Explicit arguments are forwarded: model selects what load_model
        # receives, translate=True becomes task="translate", and extra
        # keyword arguments (temperature) reach transcribe unchanged.
        samples = self._stub_decoded_audio(sf_read)
        model = load_model.return_value
        result = model.transcribe.return_value
        audio = MagicMock(spec=AudioData)
        options = {
            "model": "small",
            "language": "english",
            "translate": True,
            "temperature": 0,
        }

        actual = Recognizer().recognize_whisper(audio, **options)

        self.assertEqual(actual, result.__getitem__.return_value)
        load_model.assert_called_once_with("small")
        model.transcribe.assert_called_once_with(
            samples.astype.return_value,
            language="english",
            task="translate",
            fp16=is_available.return_value,
            temperature=0,
        )

0 comments on commit 9194a47

Please sign in to comment.