Skip to content

Commit

Permalink
Optimize whisper asr (support maintain folder structure & optional resample)
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Dec 21, 2023
1 parent fc5b9f5 commit 0c56a07
Showing 1 changed file with 39 additions and 49 deletions.
88 changes: 39 additions & 49 deletions tools/whisper_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,79 +24,69 @@
from pathlib import Path

import click
import librosa
import soundfile as sf
import whisper
from pydub import AudioSegment
from loguru import logger
from tqdm import tqdm

from fish_speech.utils.file import list_files


def transcribe_audio(model, filepath, language):
    """Transcribe a whole audio file with an already-loaded model.

    Args:
        model: Object exposing ``transcribe(filepath, language=...)``.
        filepath: Path of the audio file; forwarded to the model unchanged.
        language: Language hint; forwarded as the ``language`` keyword.

    Returns:
        Whatever ``model.transcribe`` returns for the file.
    """
    return model.transcribe(filepath, language=language)


def transcribe_segment(model, filepath):
    """Detect the spoken language of a clip and decode its text.

    Args:
        model: A loaded Whisper model.
        filepath: Path of the audio clip to decode.

    Returns:
        A ``(text, lang)`` tuple: the decoded text and the language key with
        the highest probability reported by ``model.detect_language``.
    """
    audio = whisper.load_audio(filepath)
    audio = whisper.pad_or_trim(audio)

    # Take the mel-bin count from the model instead of hard-coding 128, so the
    # spectrogram matches whichever checkpoint was loaded.
    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(
        model.device
    )

    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)

    options = whisper.DecodingOptions(beam_size=5)
    result = whisper.decode(model, mel, options)
    return result.text, lang


def load_audio(file_path, file_suffix):
    """Load an audio file with pydub, choosing the decoder from its suffix.

    Args:
        file_path: Path of the audio file to read.
        file_suffix: File extension including the leading dot (".wav", ".mp3"
            or ".flac"); matched case-insensitively.

    Returns:
        The loaded ``AudioSegment``, or ``None`` when the suffix is not
        supported or the file could not be read.
    """
    suffix = file_suffix.lower()
    if suffix not in (".wav", ".mp3", ".flac"):
        # Report unsupported formats explicitly; previously this fell through
        # to an unbound local that the broad handler below turned into a
        # misleading error message.
        print(f"Unsupported audio format {file_suffix!r} for file {file_path}")
        return None

    try:
        if suffix == ".wav":
            return AudioSegment.from_wav(file_path)
        if suffix == ".mp3":
            return AudioSegment.from_mp3(file_path)
        return AudioSegment.from_file(file_path, format="flac")
    except Exception as e:
        # Deliberately best-effort: report the unreadable file and let the
        # caller skip it rather than abort the whole run.
        print(f"Error processing file {file_path}: {e}")
        return None
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files


@click.command()
@click.option("--model-size", default="large", help="Size of the Whisper model")
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option(
    "--sample-rate",
    default=None,
    type=int,
    help="Output sample rate, default to input sample rate",
)
@click.option("--language", default="ZH", help="Language of the transcription")
def main(model_size, audio_dir, save_dir, sample_rate, language):
    """Transcribe audio files with Whisper and split them into segments.

    Every audio file found under --audio-dir is transcribed. Each transcribed
    segment is written below --save-dir, mirroring the input folder structure,
    as an audio clip plus a .lab file holding the segment text.
    """
    logger.info("Loading / Downloading OpenAI Whisper model...")
    model = whisper.load_model(model_size)
    logger.info("Model loaded.")

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )
    for file_path in tqdm(audio_files, desc="Processing audio file"):
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        # Mirror the input folder structure below the save directory.
        rel_path = Path(file_path).relative_to(audio_dir)
        out_dir = save_path / rel_path.parent
        out_dir.mkdir(parents=True, exist_ok=True)

        # sr=None keeps the file's own sample rate; mono=False keeps channels.
        audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
        transcription = model.transcribe(str(file_path), language=language)

        for segment in transcription.get("segments", []):
            print(segment)
            segment_id = segment["id"]
            text = segment["text"]
            start = segment["start"]
            end = segment["end"]

            # Segment times are in seconds; the last axis of `audio` is samples.
            extract = audio[..., int(start * sr) : int(end * sr)]
            if extract.shape[-1] == 0:
                logger.warning(
                    f"Skipping empty segment {segment_id} of {file_path}"
                )
                continue

            # librosa yields (channels, samples) for multi-channel audio while
            # soundfile expects (frames, channels); .T is a no-op for mono.
            sf.write(
                str(out_dir / f"{file_stem}_{segment_id}{file_suffix}"),
                extract.T,
                samplerate=sr,
            )

            with open(
                out_dir / f"{file_stem}_{segment_id}.lab",
                "w",
                encoding="utf-8",
            ) as f:
                f.write(text)


Expand Down

0 comments on commit 0c56a07

Please sign in to comment.