Commit
1. Supports FLAC, WAV, and MP3. 2. Fixes the conversion path issue. (#22)
* 1. Supports FLAC, WAV, and MP3. 2. Fixes the conversion path issue.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1. Use list_files to filter audio files. 2. Use the click library. 3. Implement sample rate conversion.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
kenwaytis and pre-commit-ci[bot] authored Dec 21, 2023
1 parent c36a375 commit fc5b9f5
Showing 1 changed file with 56 additions and 78 deletions.
134 changes: 56 additions & 78 deletions tools/whisper_asr.py
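Before the diff itself, one piece of context: the new version discovers its inputs with fish_speech.utils.file.list_files instead of globbing for .wav only, which is what enables the FLAC/WAV/MP3 support in the commit title. A minimal sketch of that call as the new code uses it (the directory name data/raw is a placeholder):

from fish_speech.utils.file import list_files

# Recursively collect every .wav/.mp3/.flac under the input directory.
# The returned entries behave like pathlib.Path objects, so the .stem and
# .suffix attributes the loop in main() relies on work directly.
audio_files = list_files(
    path="data/raw", extensions=[".wav", ".mp3", ".flac"], recursive=True
)
print([p.name for p in audio_files])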
@@ -21,42 +21,17 @@
 Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
 """
 
-import argparse
-import os
 from pathlib import Path
 
-import librosa
-import numpy as np
+import click
 import whisper
-from scipy.io import wavfile
 from tqdm import tqdm
 
 
-def load_and_normalize_audio(filepath, target_sr):
-    wav, sr = librosa.load(filepath, sr=None, mono=True)
-    wav, _ = librosa.effects.trim(wav, top_db=20)
-    peak = np.abs(wav).max()
-    if peak > 1.0:
-        wav /= peak / 0.98
-    return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr
-
+from pydub import AudioSegment
 
-def transcribe_audio(model, filepath):
-    return model.transcribe(
-        filepath, word_timestamps=True, task="transcribe", beam_size=5, best_of=5
-    )
+from fish_speech.utils.file import list_files
 
 
-def save_audio_segments(segments, wav, sr, save_path):
-    for i, seg in enumerate(segments):
-        start_time, end_time = seg["start"], seg["end"]
-        wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
-        wav_seg_name = f"{save_path.stem}_{i}.wav"
-        out_fpath = save_path / wav_seg_name
-        wavfile.write(
-            out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
-        )
+def transcribe_audio(model, filepath, language):
+    return model.transcribe(filepath, language=language)
 
 
 def transcribe_segment(model, filepath):
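The new transcribe_audio above is deliberately thin: it forwards the file path and the language hint straight to Whisper. A minimal standalone sketch of that call path (the model size and file name are placeholders; the result keys are the ones the new main() loop consumes):

import whisper

model = whisper.load_model("base")  # the script defaults to "large"
result = model.transcribe("sample.wav", language="zh")

# The result dict carries the full text plus timed segments,
# which main() uses to slice the audio.
for seg in result["segments"]:
    print(seg["id"], seg["start"], seg["end"], seg["text"])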
@@ -70,57 +45,60 @@ def transcribe_segment(model, filepath)
     return result.text, lang


 def process_output(save_dir, language, out_file):
     with open(out_file, "w", encoding="utf-8") as wf:
         ch_name = save_dir.stem
         for file in save_dir.glob("*.lab"):
             with open(file, "r", encoding="utf-8") as perFile:
                 line = perFile.readline().strip()
                 result = (
                     f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
                 )
                 wf.write(f"{result}\n")
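For reference, each line that process_output emits is pipe-delimited as wav_path|speaker|language|text; with save_dir set to data/SP_1, a line would look something like (the transcript is a placeholder):

data/SP_1/SP_1/SP_1_0.wav|SP_1|ZH|transcribed text here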


+def load_audio(file_path, file_suffix):
+    try:
+        if file_suffix == ".wav":
+            audio = AudioSegment.from_wav(file_path)
+        elif file_suffix == ".mp3":
+            audio = AudioSegment.from_mp3(file_path)
+        elif file_suffix == ".flac":
+            audio = AudioSegment.from_file(file_path, format="flac")
+        return audio
+    except Exception as e:
+        print(f"Error processing file {file_path}: {e}")
+        return None
+
+
+@click.command()
+@click.option("--model_size", default="large", help="Size of the Whisper model")
+@click.option("--audio_dir", required=True, help="Directory containing audio files")
+@click.option(
+    "--save_dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--language", default="ZH", help="Language of the transcription")
+@click.option("--out_sr", default=44100, type=int, help="Output sample rate")
 def main(model_size, audio_dir, save_dir, out_sr, language):
     print("Loading/Downloading OpenAI Whisper model...")
     model = whisper.load_model(model_size)
-    audio_dir, save_dir = Path(audio_dir), Path(save_dir)
-    save_dir.mkdir(exist_ok=True)
 
-    for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
-        wav, sr = load_and_normalize_audio(filepath, out_sr)
-        transcription = transcribe_audio(model, filepath)
-        save_path = save_dir / filepath.stem
-        save_audio_segments(transcription["segments"], wav, sr, save_path)
-
-        for segment_file in tqdm(
-            list(save_path.glob("*.wav")), desc="Transcribing segments"
-        ):
-            text, _ = transcribe_segment(model, segment_file)
-            with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
+    save_path = Path(save_dir)
+    save_path.mkdir(parents=True, exist_ok=True)
+    audio_files = list_files(
+        path=audio_dir, extensions=[".wav", ".mp3", ".flac"], recursive=True
+    )
+    for file_path in tqdm(audio_files, desc="Processing audio file"):
+        file_stem = file_path.stem
+        file_suffix = file_path.suffix
+        file_path = str(file_path)
+        audio = load_audio(file_path, file_suffix)
+        if not audio:
+            continue
+        transcription = transcribe_audio(model, file_path, language)
+        for segment in transcription.get("segments", []):
+            print(segment)
+            id, text, start, end = (
+                segment["id"],
+                segment["text"],
+                segment["start"],
+                segment["end"],
+            )
+            extract = audio[int(start * 1000) : int(end * 1000)].set_frame_rate(out_sr)
+            extract.export(
+                save_path / f"{file_stem}_{id}{file_suffix}",
+                format=file_suffix.lower().strip("."),
+            )
+            with open(save_path / f"{file_stem}_{id}.lab", "w", encoding="utf-8") as f:
+                f.write(text)
+
+    # process_output(save_dir, language, save_dir / "output.txt")  # Don't need to summarize into one file


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
-    parser.add_argument(
-        "--model_size", type=str, default="large", help="Size of the Whisper model"
-    )
-    parser.add_argument(
-        "--audio_dir", type=str, required=True, help="Directory containing audio files"
-    )
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        required=True,
-        help="Directory to save processed audio files",
-    )
-    parser.add_argument(
-        "--language", type=str, default="ZH", help="Language of the transcription"
-    )
-    parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
-    args = parser.parse_args()
-
-    main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)
+    main()
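The heavy lifting in the new loop is pydub: an AudioSegment slices by milliseconds, set_frame_rate performs the sample rate conversion mentioned in the commit message, and export writes each segment back out in its source format. A minimal sketch of that sequence (file names are placeholders; pydub needs ffmpeg on the PATH for mp3 and flac):

from pydub import AudioSegment

audio = AudioSegment.from_file("speech.flac")  # format inferred by ffmpeg
clip = audio[1500:4200].set_frame_rate(44100)  # pydub slices in milliseconds
clip.export("speech_0.flac", format="flac")  # resampled segment, same container

With the click options defined above, a typical invocation of the updated script would look something like (directories are placeholders):

python tools/whisper_asr.py --audio_dir data/demo --save_dir data/demo_out --model_size medium --language ZH --out_sr 44100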
