Docker Compose and Data Preprocessing Script (#10)
* 1. Add Docker Compose for development; 2. Add pre_data for preprocessing the dataset

* 1. Add Docker Compose for development; 2. Add pre_data for preprocessing the dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change pre_dataset to whisper_asr.py

* change pre_dataset to whisper_asr.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
kenwaytis and pre-commit-ci[bot] authored Dec 19, 2023
1 parent d3c0dee commit cf69582
Showing 2 changed files with 125 additions and 164 deletions.
18 changes: 18 additions & 0 deletions docker-compose.dev.yml
@@ -0,0 +1,18 @@
version: '3.8'

services:
  fish-speech:
    build: .
    container_name: fish-speech
    volumes:
      - ./data:/exp/data
      - ./raw_data:/exp/raw_data
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    command: tail -f /dev/null
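
For context, a minimal development workflow with this file might look like the following, assuming Docker Compose v2 and the NVIDIA Container Toolkit are installed on the host (the exact invocation is a sketch, not part of the commit):

docker compose -f docker-compose.dev.yml up -d --build
docker compose -f docker-compose.dev.yml exec fish-speech bash

The tail -f /dev/null command keeps the container alive so it can be attached to interactively, and the two volume mounts expose ./data and ./raw_data inside the container under /exp.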

271 changes: 107 additions & 164 deletions tools/whisper_asr.py
@@ -1,183 +1,126 @@
# This file is used to convert the audio files to text files using the Whisper model.
# It's mainly used to generate the training data for the VQ model.

"""
Used to transcribe all audio files in one folder into another folder.
e.g.
Directory structure:
--pre_data_root
----SP_1
------01.wav
------02.wav
------......
----SP_2
------01.wav
------02.wav
------......
Use
python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
to transcribe the first speaker.
Use
python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
to transcribe the second speaker.
Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
"""

import argparse
import os
import subprocess as sp
import time
from datetime import timedelta
from functools import lru_cache
from pathlib import Path
from random import Random

import click
import librosa
import numpy as np
import torch
from loguru import logger
from transformers import WhisperProcessor
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
import whisper
from scipy.io import wavfile
from tqdm import tqdm

from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration

RANK_STR = ""
def load_and_normalize_audio(filepath, target_sr):
    wav, sr = librosa.load(filepath, sr=None, mono=True)
    wav, _ = librosa.effects.trim(wav, top_db=20)
    peak = np.abs(wav).max()
    if peak > 1.0:
        wav /= peak / 0.98
    return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr


@lru_cache(maxsize=1)
def get_whisper_model():
    model = FlashWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-medium"
    ).cuda()
    model.eval()
    logger.info(f"{RANK_STR}Loaded model")

    return model


def transcribe_audio(model, filepath):
    return model.transcribe(
        filepath, word_timestamps=True, task="transcribe", beam_size=5, best_of=5
    )


def save_audio_segments(segments, wav, sr, save_path):
    for i, seg in enumerate(segments):
        start_time, end_time = seg["start"], seg["end"]
        wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
        wav_seg_name = f"{save_path.stem}_{i}.wav"
        out_fpath = save_path / wav_seg_name
        wavfile.write(
            out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
        )

@lru_cache(maxsize=1)
def get_whisper_processor():
    return WhisperProcessor.from_pretrained("openai/whisper-medium")


def transcribe_segment(model, filepath):
    audio = whisper.load_audio(filepath)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    options = whisper.DecodingOptions(beam_size=5)
    result = whisper.decode(model, mel, options)
    return result.text, lang


def process_output(save_dir, language, out_file):
    with open(out_file, "w", encoding="utf-8") as wf:
        ch_name = save_dir.stem
        for file in save_dir.glob("*.lab"):
            with open(file, "r", encoding="utf-8") as perFile:
                line = perFile.readline().strip()
                result = (
                    f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
                )
                wf.write(f"{result}\n")

def transcribe_batch(files: list[str], language: str):
    wavs = [load_audio(file, 16000) for file in files]
    total_time = sum([len(wav) for wav in wavs]) / 16000
    wavs = [pad_or_trim(wav) for wav in wavs]

    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
    mels = log_mel_spectrogram(wavs).cuda()
    model = get_whisper_model()
    processor = get_whisper_processor()
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )

    with torch.no_grad():
        outputs = model.generate(
            input_features=mels,
            max_length=448,
            do_sample=False,
            forced_decoder_ids=forced_decoder_ids,
        )

    outputs = outputs.cpu().tolist()

    # Remove EOS token
    for output in outputs:
        while output[-1] in [
            processor.tokenizer.pad_token_id,
            processor.tokenizer.eos_token_id,
        ]:
            output.pop()
        output.append(processor.tokenizer.eos_token_id)

    transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
    tokens = [",".join(map(str, line)) for line in outputs]
    transcriptions = [
        f"{token}\t{transcription}"
        for token, transcription in zip(tokens, transcriptions)
    ]

    return transcriptions, total_time


def main(model_size, audio_dir, save_dir, out_sr, language):
    model = whisper.load_model(model_size)
    audio_dir, save_dir = Path(audio_dir), Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
        wav, sr = load_and_normalize_audio(filepath, out_sr)
        transcription = transcribe_audio(model, filepath)
        save_path = save_dir / filepath.stem
        save_audio_segments(transcription["segments"], wav, sr, save_path)


@click.command()
@click.argument("folder")
@click.option("--rank", default=0)
@click.option("--world-size", default=1)
@click.option("--num-workers", default=1)
@click.option("--language", default="english")
def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
    global RANK_STR

    if num_workers > 1 and world_size != num_workers:
        RANK_STR = "[Master] "
        logger.info(f"{RANK_STR}Spawning {num_workers} workers")

        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if visible_devices is None:
            visible_devices = list(range(torch.cuda.device_count()))
        else:
            visible_devices = visible_devices.split(",")

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            args = [
                "python",
                __file__,
                "--rank",
                str(i),
                "--world-size",
                str(num_workers),
                "--language",
                language,
                folder,
            ]
            processes.append(
                sp.Popen(
                    args,
                    env=env,
                )
            )

        for p in processes:
            p.wait()

        logger.info(f"{RANK_STR}All workers finished")
        return

    # This is a worker
    RANK_STR = f"[Rank: {rank}] "
    logger.info(f"{RANK_STR}Starting worker")

    files = [
        str(file)
        for file in Path(folder).rglob("*")
        if file.suffix in [".wav", ".flac"]
    ]

    logger.info(f"{RANK_STR}Found {len(files)} files")

    files = sorted(files)
    Random(42).shuffle(files)
    files = files[rank::world_size]
    logger.info(f"{RANK_STR}Processing {len(files)} files")

    # Batch size 64
    total_time = 0
    begin_time = time.time()
    processed_files = 0

    for n_batch, idx in enumerate(range(0, len(files), 64)):
        batch = files[idx : idx + 64]
        transcriptions, batch_time = transcribe_batch(batch, language)
        total_time += batch_time
        processed_files += len(batch)

        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time)
                / processed_files
                * (len(files) - processed_files)
            )
            logger.info(
                f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
            )

        # Write to file
        for file, transcription in zip(batch, transcriptions):
            Path(file).with_suffix(".whisper.txt").write_text(
                transcription, encoding="utf-8"
            )

        # Stop if total time is more than 1000 / world_size hours
        if total_time > 1000 / world_size * 3600:
            break

    logger.info(
        f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )
        for segment_file in tqdm(
            list(save_path.glob("*.wav")), desc="Transcribing segments"
        ):
            text, _ = transcribe_segment(model, segment_file)
            with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
                f.write(text)

    # process_output(save_dir, language, save_dir / "output.txt")  # Don't need to summarize into one file


if __name__ == "__main__":
    main()

    parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
    parser.add_argument(
        "--model_size", type=str, default="large", help="Size of the Whisper model"
    )
    parser.add_argument(
        "--audio_dir", type=str, required=True, help="Directory containing audio files"
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save processed audio files",
    )
    parser.add_argument(
        "--language", type=str, default="ZH", help="Language of the transcription"
    )
    parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
    args = parser.parse_args()

    main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)
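
A rough usage sketch for the rewritten script, with illustrative paths and model size taken from the docstring and defaults above:

python tools/whisper_asr.py --model_size medium --audio_dir pre_data_root/SP_1 --save_dir data/SP_1 --language ZH --out_sr 44100

Roughly, each input WAV in audio_dir is trimmed, peak-normalized, and resampled to out_sr, split on Whisper's segment timestamps into per-segment WAVs under save_dir/<input stem>/, and each segment is then transcribed into a matching .lab file.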
