⚡️ Switch to faster-whisper
This also disables the alignment task, because faster-whisper can produce proper
word timestamps on its own. It also simplifies the code a bit, since faster-whisper
already contains the logic to re-combine words from the tokens.
pajowu committed Dec 2, 2023
1 parent e406b9f commit c81003d
Showing 11 changed files with 1,032 additions and 496 deletions.
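
The commit message above points to the key API difference: faster-whisper emits word-level timestamps directly during transcription, which is what makes the separate alignment task unnecessary. Below is a minimal sketch of that usage, assuming the faster-whisper>=0.10.0 dependency pinned in this commit; the model size and audio path are illustrative placeholders, not taken from the repository.

# Sketch only: how faster-whisper exposes word timestamps during transcription,
# removing the need for a separate alignment pass. Model size and audio path
# are placeholder values for illustration.
from faster_whisper import WhisperModel

model = WhisperModel("base", device="cpu", compute_type="int8")
segments, info = model.transcribe("example.wav", word_timestamps=True)

for segment in segments:
    for word in segment.words or []:
        # Each word already carries start/end times, so no post-hoc alignment step.
        print(f"{word.start:.2f}-{word.end:.2f}\t{word.word}")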
706 changes: 706 additions & 0 deletions backend/data/models.json

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion backend/transcribee_backend/models/task.py
@@ -179,7 +179,11 @@ def from_orm(cls, task: Task, update={}) -> Self:
            id=task.id,
            state=task.state,
            dependencies=[x.dependant_on_id for x in task.dependency_links],
-           current_attempt=None,
+           current_attempt=(
+               TaskAttemptResponse.from_orm(task.current_attempt)
+               if task.current_attempt is not None
+               else None
+           ),
            document_id=task.document_id,
            task_type=task.task_type,
            task_parameters=task.task_parameters,
10 changes: 1 addition & 9 deletions backend/transcribee_backend/routers/document.py
@@ -282,20 +282,12 @@ def create_default_tasks_for_document(
    )
    session.add(transcribe_task)

-   align_task = Task(
-       task_type=TaskType.ALIGN,
-       task_parameters={},
-       document_id=document.id,
-       dependencies=[transcribe_task],
-   )
-   session.add(align_task)
-
    if number_of_speakers != 0 and number_of_speakers != 1:
        speaker_identification_task = Task(
            task_type=TaskType.IDENTIFY_SPEAKERS,
            task_parameters={"number_of_speakers": number_of_speakers},
            document_id=document.id,
-           dependencies=[align_task],
+           dependencies=[transcribe_task],
        )
        session.add(speaker_identification_task)

4 changes: 3 additions & 1 deletion shell.nix
@@ -20,6 +20,8 @@ let
  });
  ld_packages = [
    pkgs.file
+   # for ctranslate2
+   pkgs.stdenv.cc.cc.lib
  ];

in
@@ -67,7 +69,7 @@ pkgs.mkShell {
    # Some libraries are not found if not added directly to LD_LIBRARY_PATH / DYLD_LIBRARY_PATH (on darwin)
    # However just adding them there is not enough, because macOS purges the DYLD_* variables in some conditions
    # This means we have to set them again in some script (e.g. ./start_backend.sh) -> we need a "safe" env var to pass them to the script
-   export TRANSCRIBEE_DYLD_LIBRARY_PATH=${builtins.concatStringsSep ":" (map (x: x + "/lib") ld_packages)}
+   export TRANSCRIBEE_DYLD_LIBRARY_PATH=${pkgs.lib.makeLibraryPath ld_packages}
    export LD_LIBRARY_PATH=$LD_SEARCH_PATH:$TRANSCRIBEE_DYLD_LIBRARY_PATH
  '' + pkgs.lib.optionalString pkgs.stdenv.isDarwin ''
    export CPPFLAGS="-I${pkgs.libcxx.dev}/include/c++/v1"
268 changes: 209 additions & 59 deletions worker/pdm.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion worker/pyproject.toml
@@ -17,13 +17,13 @@ dependencies = [
    "torch>=2.0.0",
    "automerge @ git+https://github.com/transcribee/automerge-py.git@057303bd087401f12b166e2adb7161f0fcb3a9dc",
    "websockets>=10.4",
-   "whispercppy>=0.0.4",
    "scikit-learn>=1.2.2",
    "watchfiles>=0.19.0",
    "speechbrain>=0.5.14",
    "ffmpeg-python>=0.2.0",
    "transcribee-proto @ file:///${PROJECT_ROOT}/../proto",
    "PyICU>=2.11",
+   "faster-whisper>=0.10.0",
]

requires-python = ">=3.10"
@@ -45,6 +45,8 @@ build-backend = "pdm.pep517.api"
soundfile = "0.11.0"


+
+
[[tool.pdm.source]]
type = "find_links"
url = "https://download.pytorch.org/whl/cpu/torch_stable.html"
36 changes: 5 additions & 31 deletions worker/scripts/generate_model_config.py
@@ -3,12 +3,9 @@
import re
from typing import List

+from faster_whisper.tokenizer import _LANGUAGE_CODES
+from faster_whisper.utils import _MODELS
from pydantic import BaseModel
-from transcribee_worker.torchaudio_align import (
-    DEFAULT_ALIGN_MODELS_HF,
-    DEFAULT_ALIGN_MODELS_TORCH,
-)
-from transcribee_worker.whisper_transcribe import get_context


def is_english_only(model_name):
@@ -26,37 +23,14 @@ class ModelConfig(BaseModel):
    parser.add_argument("out", type=argparse.FileType("w"))
    args = parser.parse_args()

-   models = [
-       "tiny.en",
-       "tiny",
-       "base.en",
-       "base",
-       "small.en",
-       "small",
-       "medium.en",
-       "medium",
-       "large-v1",
-       "large",
-   ]
-
-   alignable_languages = set(DEFAULT_ALIGN_MODELS_HF.keys()) | set(
-       DEFAULT_ALIGN_MODELS_TORCH.keys()
-   )
-
    model_configs = []
-
-   context = get_context("tiny")
-   multilingual_model_langs = set(
-       context.lang_id_to_str(i) for i in range(context.lang_max_id + 1)
-   )
+   multilingual_model_langs = list(sorted(_LANGUAGE_CODES))

-   for model in models:
+   for model in _MODELS:
        if is_english_only(model):
            languages = ["en"]
        else:
-           languages = ["auto"] + list(
-               sorted(multilingual_model_langs & alignable_languages)
-           )
+           languages = ["auto"] + multilingual_model_langs

        model_configs.append(
            ModelConfig(
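
The generator script now derives its model and language lists from faster-whisper itself instead of hard-coding them. A small sketch to inspect those values, assuming faster-whisper>=0.10.0; note that _MODELS and _LANGUAGE_CODES are private names, used here only because the script above already relies on them.

# Sketch: peek at what generate_model_config.py now iterates over.
# _MODELS and _LANGUAGE_CODES are private to faster-whisper; this mirrors the
# script's own usage and may break with future faster-whisper releases.
from faster_whisper.tokenizer import _LANGUAGE_CODES
from faster_whisper.utils import _MODELS

print(list(_MODELS))            # model names, e.g. "tiny", "tiny.en", "base", ...
print(sorted(_LANGUAGE_CODES))  # language codes supported by the multilingual models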
83 changes: 0 additions & 83 deletions worker/tests/data/test_combine_tokens_to_words-simple.json

This file was deleted.

56 changes: 10 additions & 46 deletions worker/tests/test_transcribe.py
@@ -1,4 +1,3 @@
-import asyncio
import glob
from pathlib import Path
from typing import List
@@ -7,24 +6,18 @@
from pydantic import BaseModel, parse_file_as
from transcribee_proto.document import Paragraph
from transcribee_worker.whisper_transcribe import (
-   combine_tokens_to_words,
    move_space_to_prev_token,
    strict_sentence_paragraphs,
)


-async def list_to_async_iter(list):
-    for item in list:
-        yield item
-
-
-def async_doc_chain_func_to_list(*funcs):
-    async def wrapper(input: List[Paragraph]):
+def doc_chain_func_to_list(*funcs):
+    def wrapper(input: List[Paragraph]):
        res = []
-       iter = list_to_async_iter(input)
+       iter_ = iter(input)
        for func in funcs:
-           iter = func(iter)
-       async for x in iter:
+           iter_ = func(iter_)
+       for x in iter_:
            res.append(x)
        return res

@@ -36,23 +29,6 @@ class SpecInput(BaseModel):
    expected: List[Paragraph]


-@pytest.mark.parametrize(
-    "data_file",
-    glob.glob(
-        str(Path(__file__).parent / "data" / "test_combine_tokens_to_words*.json"),
-    ),
-)
-def test_combine_tokens_to_words(data_file):
-    test_data = parse_file_as(SpecInput, data_file)
-
-    output = list(
-        asyncio.run(
-            async_doc_chain_func_to_list(combine_tokens_to_words)(test_data.input)
-        )
-    )
-    assert output == test_data.expected
-
-
@pytest.mark.parametrize(
    "data_file",
    glob.glob(
@@ -62,11 +38,7 @@ def test_combine_tokens_to_words(data_file):
def test_strict_sentence_paragraphs(data_file):
    test_data = parse_file_as(SpecInput, data_file)

-   output = list(
-       asyncio.run(
-           async_doc_chain_func_to_list(strict_sentence_paragraphs)(test_data.input)
-       )
-   )
+   output = list(doc_chain_func_to_list(strict_sentence_paragraphs)(test_data.input))
    assert [x.text() for x in output] == [x.text() for x in test_data.expected]
    assert output == test_data.expected

@@ -80,11 +52,7 @@ def test_strict_sentence_paragraphs(data_file):
def test_move_space_to_prev_token(data_file):
    test_data = parse_file_as(SpecInput, data_file)

-   output = list(
-       asyncio.run(
-           async_doc_chain_func_to_list(move_space_to_prev_token)(test_data.input)
-       )
-   )
+   output = doc_chain_func_to_list(move_space_to_prev_token)(test_data.input)
    assert output == test_data.expected

@@ -97,13 +65,9 @@ def test_move_space_to_prev_token(data_file):
def test_space_and_sentences(data_file):
    test_data = parse_file_as(SpecInput, data_file)

-   output = list(
-       asyncio.run(
-           async_doc_chain_func_to_list(
-               move_space_to_prev_token, strict_sentence_paragraphs
-           )(test_data.input)
-       )
-   )
+   output = doc_chain_func_to_list(
+       move_space_to_prev_token, strict_sentence_paragraphs
+   )(test_data.input)
    for p in output:
        print(p.json())
    assert output == test_data.expected
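
The new doc_chain_func_to_list helper replaces the async plumbing with plain generator chaining: each stage consumes an iterator of paragraphs and yields transformed ones. Here is a generic sketch of the same pattern; the stage functions drop_empty and upper are invented for illustration and are not part of the worker.

# Generic sketch of the generator-chaining pattern used by doc_chain_func_to_list.
# The stage functions below are hypothetical; in the tests the stages are
# move_space_to_prev_token and strict_sentence_paragraphs.
from typing import Callable, Iterable, Iterator, List


def drop_empty(items: Iterable[str]) -> Iterator[str]:
    # Stage 1: skip blank items.
    for item in items:
        if item.strip():
            yield item


def upper(items: Iterable[str]) -> Iterator[str]:
    # Stage 2: uppercase whatever the previous stage yielded.
    for item in items:
        yield item.upper()


def chain_to_list(*funcs: Callable[[Iterable[str]], Iterator[str]]):
    def wrapper(items: Iterable[str]) -> List[str]:
        it: Iterable[str] = iter(items)
        for func in funcs:
            it = func(it)  # each stage wraps the previous iterator
        return list(it)

    return wrapper


print(chain_to_list(drop_empty, upper)(["hello", "", "world"]))  # ['HELLO', 'WORLD']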