Docker Compose and Data Preprocessing Script (#10)
* 1. add Docker Compose for development; 2. add pre_data for preprocessing the dataset
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* change pre_dataset to whisper_asr.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d3c0dee · commit cf69582 · Showing 2 changed files with 125 additions and 164 deletions.
Docker Compose file:
@@ -0,0 +1,18 @@
version: '3.8'

services:
  fish-speech:
    build: .
    container_name: fish-speech
    volumes:
      - ./data:/exp/data
      - ./raw_data:/exp/raw_data
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    command: tail -f /dev/null
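With this file in place, a typical development workflow (the exact commands are an assumption, not part of the commit) is docker compose up -d --build to build and start the container, then docker compose exec fish-speech bash for a shell inside it; command: tail -f /dev/null simply keeps the container alive so it can be attached to, and ./data and ./raw_data on the host appear as /exp/data and /exp/raw_data inside the container.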
tools/whisper_asr.py:
@@ -1,183 +1,126 @@
# This file is used to convert the audio files to text files using the Whisper model.
# It's mainly used to generate the training data for the VQ model.

"""
Used to transcribe all audio files in one folder into another folder.

e.g.
Directory structure:
--pre_data_root
----SP_1
------01.wav
------02.wav
------......
----SP_2
------01.wav
------02.wav
------......

Use
python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
to transcribe the first speaker.

Use
python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
to transcribe the second speaker.

Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
"""

import argparse
import os
import subprocess as sp
import time
from datetime import timedelta
from functools import lru_cache
from pathlib import Path
from random import Random

import click
import librosa
import numpy as np
import torch
from loguru import logger
from transformers import WhisperProcessor
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
import whisper
from scipy.io import wavfile
from tqdm import tqdm

from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration

RANK_STR = ""


def load_and_normalize_audio(filepath, target_sr):
    wav, sr = librosa.load(filepath, sr=None, mono=True)
    wav, _ = librosa.effects.trim(wav, top_db=20)
    peak = np.abs(wav).max()
    if peak > 1.0:
        wav /= peak / 0.98
    return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr
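As an aside, a minimal usage sketch for this helper (the input path and target rate here are hypothetical, not taken from the commit):

wav, sr = load_and_normalize_audio("raw_data/SP_1/01.wav", 44100)  # hypothetical path
# wav comes back mono, silence-trimmed (top_db=20), rescaled only if its absolute peak
# exceeds 1.0 (brought down to ~0.98 of full scale), then resampled to 44.1 kHz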

@lru_cache(maxsize=1)
def get_whisper_model():
    model = FlashWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-medium"
    ).cuda()
    model.eval()
    logger.info(f"{RANK_STR}Loaded model")


def transcribe_audio(model, filepath):
    return model.transcribe(
        filepath, word_timestamps=True, task="transcribe", beam_size=5, best_of=5
    )

    return model


def save_audio_segments(segments, wav, sr, save_path):
    for i, seg in enumerate(segments):
        start_time, end_time = seg["start"], seg["end"]
        wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
        wav_seg_name = f"{save_path.stem}_{i}.wav"
        out_fpath = save_path / wav_seg_name
        wavfile.write(
            out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
        )
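For context, openai-whisper's transcribe() returns a dict whose "segments" entries carry start and end times in seconds, which is what save_audio_segments slices on; a rough sketch with made-up values:

segments = [
    {"start": 0.0, "end": 2.4, "text": " first sentence"},   # illustrative values only
    {"start": 2.4, "end": 5.1, "text": " second sentence"},
]
# With sr = 44100, the first clip is wav[0:int(2.4 * 44100)] and is written out as
# <save_path.stem>_0.wav in 16-bit PCM via scipy.io.wavfile.write.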

@lru_cache(maxsize=1)
def get_whisper_processor():
    return WhisperProcessor.from_pretrained("openai/whisper-medium")


def transcribe_segment(model, filepath):
    audio = whisper.load_audio(filepath)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    options = whisper.DecodingOptions(beam_size=5)
    result = whisper.decode(model, mel, options)
    return result.text, lang
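A minimal usage sketch for this helper (the segment path is hypothetical; note that n_mels=128 matches the Whisper large-v3 family, while older checkpoints such as medium expect 80 mel bins):

model = whisper.load_model("large")                           # same call used in main() below
text, lang = transcribe_segment(model, "data/SP_1/01_0.wav")  # hypothetical segment file
print(lang, text)                                             # detected language code plus the decoded text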

def process_output(save_dir, language, out_file):
    with open(out_file, "w", encoding="utf-8") as wf:
        ch_name = save_dir.stem
        for file in save_dir.glob("*.lab"):
            with open(file, "r", encoding="utf-8") as perFile:
                line = perFile.readline().strip()
                result = (
                    f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
                )
                wf.write(f"{result}\n")
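Each line this helper would write is a pipe-separated audio|speaker|language|text record; since ch_name is save_dir.stem, the speaker name appears twice in the path as the f-string is written, e.g. (illustrative):

data/SP_1/SP_1/01_0.wav|SP_1|ZH|<transcript text>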

def transcribe_batch(files: list[str], language: str):
    wavs = [load_audio(file, 16000) for file in files]
    total_time = sum([len(wav) for wav in wavs]) / 16000
    wavs = [pad_or_trim(wav) for wav in wavs]

    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
    mels = log_mel_spectrogram(wavs).cuda()
    model = get_whisper_model()
    processor = get_whisper_processor()
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )


def main(model_size, audio_dir, save_dir, out_sr, language):
    model = whisper.load_model(model_size)
    audio_dir, save_dir = Path(audio_dir), Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    with torch.no_grad():
        outputs = model.generate(
            input_features=mels,
            max_length=448,
            do_sample=False,
            forced_decoder_ids=forced_decoder_ids,
        )

    for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
        wav, sr = load_and_normalize_audio(filepath, out_sr)
        transcription = transcribe_audio(model, filepath)
        save_path = save_dir / filepath.stem
        save_audio_segments(transcription["segments"], wav, sr, save_path)

    outputs = outputs.cpu().tolist()

    # Remove EOS token
    for output in outputs:
        while output[-1] in [
            processor.tokenizer.pad_token_id,
            processor.tokenizer.eos_token_id,
        ]:
            output.pop()
        output.append(processor.tokenizer.eos_token_id)

    transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
    tokens = [",".join(map(str, line)) for line in outputs]
    transcriptions = [
        f"{token}\t{transcription}"
        for token, transcription in zip(tokens, transcriptions)
    ]

    return transcriptions, total_time
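Each entry returned here pairs the comma-joined decoder token ids with the decoded text (special tokens kept, since skip_special_tokens=False), separated by a tab; this is what later gets written to the .whisper.txt files. An illustrative shape, with made-up ids:

entry = "50258,50260,50359,1234,5678\t<|startoftranscript|>...decoded text..."  # "token ids<TAB>text"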

@click.command()
@click.argument("folder")
@click.option("--rank", default=0)
@click.option("--world-size", default=1)
@click.option("--num-workers", default=1)
@click.option("--language", default="english")
def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
    global RANK_STR

    if num_workers > 1 and world_size != num_workers:
        RANK_STR = "[Master] "
        logger.info(f"{RANK_STR}Spawning {num_workers} workers")

        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if visible_devices is None:
            visible_devices = list(range(torch.cuda.device_count()))
        else:
            visible_devices = visible_devices.split(",")

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            args = [
                "python",
                __file__,
                "--rank",
                str(i),
                "--world-size",
                str(num_workers),
                "--language",
                language,
                folder,
            ]
            processes.append(
                sp.Popen(
                    args,
                    env=env,
                )
            )

        for p in processes:
            p.wait()

        logger.info(f"{RANK_STR}All workers finished")
        return

    # This is a worker
    RANK_STR = f"[Rank: {rank}] "
    logger.info(f"{RANK_STR}Starting worker")

    files = [
        str(file)
        for file in Path(folder).rglob("*")
        if file.suffix in [".wav", ".flac"]
    ]

    logger.info(f"{RANK_STR}Found {len(files)} files")

    files = sorted(files)
    Random(42).shuffle(files)
    files = files[rank::world_size]
    logger.info(f"{RANK_STR}Processing {len(files)} files")

    # Batch size 64
    total_time = 0
    begin_time = time.time()
    processed_files = 0

    for n_batch, idx in enumerate(range(0, len(files), 64)):
        batch = files[idx : idx + 64]
        transcriptions, batch_time = transcribe_batch(batch, language)
        total_time += batch_time
        processed_files += len(batch)

        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time)
                / processed_files
                * (len(files) - processed_files)
            )
            logger.info(
                f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}"
            )

        # Write to file
        for file, transcription in zip(batch, transcriptions):
            Path(file).with_suffix(".whisper.txt").write_text(
                transcription, encoding="utf-8"
            )

        # Stop if total time is more than 1000 / world_size hours
        if total_time > 1000 / world_size * 3600:
            break

    logger.info(
        f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )
        for segment_file in tqdm(
            list(save_path.glob("*.wav")), desc="Transcribing segments"
        ):
            text, _ = transcribe_segment(model, segment_file)
            with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
                f.write(text)
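After a run of the new script, each input recording ends up as a folder of segment WAVs with matching .lab transcripts; assuming --save_dir data/SP_1 and an input file 01.wav, the layout would look roughly like:

data/SP_1/01/01_0.wav
data/SP_1/01/01_0.lab
data/SP_1/01/01_1.wav
data/SP_1/01/01_1.lab
...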

    # process_output(save_dir, language, save_dir / "output.txt")  # Don't need to summarize into one file


if __name__ == "__main__":
    main()
    parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
    parser.add_argument(
        "--model_size", type=str, default="large", help="Size of the Whisper model"
    )
    parser.add_argument(
        "--audio_dir", type=str, required=True, help="Directory containing audio files"
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save processed audio files",
    )
    parser.add_argument(
        "--language", type=str, default="ZH", help="Language of the transcription"
    )
    parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
    args = parser.parse_args()

    main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)
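Putting the arguments together, a full invocation for one speaker folder (paths follow the docstring example; the flags are the ones defined above):

python tools/whisper_asr.py --model_size large --audio_dir pre_data_root/SP_1 --save_dir data/SP_1 --language ZH --out_sr 44100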