Skip to content
This repository has been archived by the owner on Feb 15, 2025. It is now read-only.

Commit

Permalink
chore(whisper): Pass through variables down to whisper (#840)
Browse files Browse the repository at this point in the history
Ensures that all variables are passed down to the actual whisper model (language, temperature, prompt)
  • Loading branch information
CollectiveUnicorn authored Jul 30, 2024
1 parent c993ad5 commit 4e8092a
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
26 changes: 24 additions & 2 deletions packages/whisper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,32 @@ def make_transcribe_request(filename, task, language, temperature, prompt):
device = "cuda" if GPU_ENABLED else "cpu"
model = WhisperModel(model_path, device=device, compute_type="float32")

segments, info = model.transcribe(filename, task=task, beam_size=5)
# Prepare kwargs with non-None values
kwargs = {}
if task:
if task in ["transcribe", "translate"]:
kwargs["task"] = task
else:
logger.error(f"Task {task} is not supported")
return {"text": ""}
if language:
if language in model.supported_languages:
kwargs["language"] = language
else:
logger.error(f"Language {language} is not supported")
if temperature:
kwargs["temperature"] = temperature
if prompt:
kwargs["initial_prompt"] = prompt

try:
# Call transcribe with only non-None parameters
segments, info = model.transcribe(filename, beam_size=5, **kwargs)
except Exception as e:
logger.error(f"Error transcribing audio: {e}")
return {"text": ""}

output = ""

for segment in segments:
output += segment.text

Expand Down
2 changes: 1 addition & 1 deletion packages/whisper/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ version = "0.9.2"
# x-release-please-end

dependencies = [
"faster-whisper == 0.10.0",
"faster-whisper == 1.0.3",
"leapfrogai-sdk",
]
requires-python = "~=3.11"
Expand Down
Empty file added tests/e2e/__init__.py
Empty file.
10 changes: 5 additions & 5 deletions tests/e2e/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_transcriptions():
timestamp_granularities=["word", "segment"],
)

assert len(transcription.text) > 0 # The transcription should not be empty
assert len(transcription.text) < 500 # The transcription should not be too long
assert len(transcription.text) > 0, "The transcription should not be empty"
assert len(transcription.text) < 500, "The transcription should not be too long"


def test_translations():
Expand All @@ -65,8 +65,8 @@ def test_translations():
temperature=0.3,
)

assert len(translation.text) > 0 # The translation should not be empty
assert len(translation.text) < 500 # The translation should not be too long
assert len(translation.text) > 0, "The translation should not be empty"
assert len(translation.text) < 500, "The translation should not be too long"

def is_english_or_punctuation(c):
if c in string.punctuation or c.isspace():
Expand All @@ -78,4 +78,4 @@ def is_english_or_punctuation(c):

english_chars = [is_english_or_punctuation(c) for c in translation.text]

assert all(english_chars) # Check that only English characters are returned
assert all(english_chars), "Non-English characters have been returned"

0 comments on commit 4e8092a

Please sign in to comment.