diff --git a/docs/zh/index.md b/docs/zh/index.md index 2f1121f0..11324bd3 100644 --- a/docs/zh/index.md +++ b/docs/zh/index.md @@ -13,8 +13,8 @@ !!! warning - 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
- 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布. + 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+ 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.

@@ -33,23 +33,23 @@ Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法 1. 解压项目压缩包。 2. 点击 `install_env.bat` 安装环境。 - - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。 - - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。 - - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。 - - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。 + - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。 + - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。 + - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。 + - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。 3. 若第 2 步 `INSTALL_TYPE=preview` 则执行这一步(可跳过,此步为激活编译模型环境) - 1. 使用如下链接下载 LLVM 编译器。 - - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) - - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) - - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。 - - 确认安装完成。 - 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。 - - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe) - 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。 - - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/) - - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022 - - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载 - 4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64) + 1. 使用如下链接下载 LLVM 编译器。 + - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) + - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true) + - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。 + - 确认安装完成。 + 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。 + - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe) + 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。 + - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/) + - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022 + - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载 + 4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64) 4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`. !!! info "可选" diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py index 4047aa53..78c82640 100644 --- a/fish_speech/utils/file.py +++ b/fish_speech/utils/file.py @@ -1,55 +1,5 @@ import os -from glob import glob from pathlib import Path -from typing import Union - -from loguru import logger -from natsort import natsorted - -AUDIO_EXTENSIONS = { - ".mp3", - ".wav", - ".flac", - ".ogg", - ".m4a", - ".wma", - ".aac", - ".aiff", - ".aif", - ".aifc", -} - - -def list_files( - path: Union[Path, str], - extensions: set[str] = None, - recursive: bool = False, - sort: bool = True, -) -> list[Path]: - """List files in a directory. - - Args: - path (Path): Path to the directory. - extensions (set, optional): Extensions to filter. Defaults to None. - recursive (bool, optional): Whether to search recursively. Defaults to False. - sort (bool, optional): Whether to sort the files. Defaults to True. - - Returns: - list: List of files. 
- """ - - if isinstance(path, str): - path = Path(path) - - if not path.exists(): - raise FileNotFoundError(f"Directory {path} does not exist.") - - files = [file for ext in extensions for file in path.rglob(f"*{ext}")] - - if sort: - files = natsorted(files) - - return files def get_latest_checkpoint(path: Path | str) -> Path | None: @@ -64,56 +14,3 @@ def get_latest_checkpoint(path: Path | str) -> Path | None: return None return ckpts[-1] - - -def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: - """ - Load a Bert-VITS2 style filelist. - """ - - files = set() - results = [] - count_duplicated, count_not_found = 0, 0 - - LANGUAGE_TO_LANGUAGES = { - "zh": ["zh", "en"], - "jp": ["jp", "en"], - "en": ["en"], - } - - with open(path, "r", encoding="utf-8") as f: - for line in f.readlines(): - splits = line.strip().split("|", maxsplit=3) - if len(splits) != 4: - logger.warning(f"Invalid line: {line}") - continue - - filename, speaker, language, text = splits - file = Path(filename) - language = language.strip().lower() - - if language == "ja": - language = "jp" - - assert language in ["zh", "jp", "en"], f"Invalid language {language}" - languages = LANGUAGE_TO_LANGUAGES[language] - - if file in files: - logger.warning(f"Duplicated file: {file}") - count_duplicated += 1 - continue - - if not file.exists(): - logger.warning(f"File not found: {file}") - count_not_found += 1 - continue - - results.append((file, speaker, languages, text)) - - if count_duplicated > 0: - logger.warning(f"Total duplicated files: {count_duplicated}") - - if count_not_found > 0: - logger.warning(f"Total files not found: {count_not_found}") - - return results diff --git a/pyproject.toml b/pyproject.toml index 80f54d3a..83518957 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,11 @@ dependencies = [ "zstandard>=0.22.0", "pydub", "faster_whisper", - "modelscope==1.16.1", - "funasr==1.1.2", + "modelscope==1.17.1", + "funasr==1.1.5", "opencc-python-reimplemented==0.1.7", + "audio-seperator[gpu]==0.18.3", + "silero-vad", ] [project.optional-dependencies] diff --git a/run_cmd.bat b/run_cmd.bat index 05fda82d..c2af8a9b 100644 --- a/run_cmd.bat +++ b/run_cmd.bat @@ -29,7 +29,7 @@ set INSTALL_ENV_DIR=%cd%\fishenv\env set PYTHONNOUSERSITE=1 -set PYTHONPATH= +set PYTHONPATH=%~dp0 set PYTHONHOME= diff --git a/tools/api.py b/tools/api.py index 9b6f3d2e..05b31338 100644 --- a/tools/api.py +++ b/tools/api.py @@ -3,6 +3,7 @@ import json import queue import random +import sys import traceback import wave from argparse import ArgumentParser @@ -10,11 +11,11 @@ from pathlib import Path from typing import Annotated, Literal, Optional -import librosa import numpy as np import pyrootutils import soundfile as sf import torch +import torchaudio from kui.asgi import ( Body, HTTPException, @@ -87,7 +88,18 @@ def load_audio(reference_audio, sr): except base64.binascii.Error: raise ValueError("Invalid path or base64 string") - audio, _ = librosa.load(reference_audio, sr=sr, mono=True) + waveform, original_sr = torchaudio.load( + reference_audio, backend="sox" if sys.platform == "linux" else "soundfile" + ) + + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if original_sr != sr: + resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr) + waveform = resampler(waveform) + + audio = waveform.squeeze().numpy() return audio diff --git a/tools/file.py b/tools/file.py new file mode 100644 index 00000000..b4b8051d --- /dev/null +++ b/tools/file.py @@ -0,0 +1,108 @@ 
+from pathlib import Path +from typing import Union + +from loguru import logger +from natsort import natsorted + +AUDIO_EXTENSIONS = { + ".mp3", + ".wav", + ".flac", + ".ogg", + ".m4a", + ".wma", + ".aac", + ".aiff", + ".aif", + ".aifc", +} + +VIDEO_EXTENSIONS = { + ".mp4", + ".avi", +} + + +def list_files( + path: Union[Path, str], + extensions: set[str] = None, + recursive: bool = False, + sort: bool = True, +) -> list[Path]: + """List files in a directory. + + Args: + path (Path): Path to the directory. + extensions (set, optional): Extensions to filter. Defaults to None. + recursive (bool, optional): Whether to search recursively. Defaults to False. + sort (bool, optional): Whether to sort the files. Defaults to True. + + Returns: + list: List of files. + """ + + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Directory {path} does not exist.") + + files = [file for ext in extensions for file in path.rglob(f"*{ext}")] + + if sort: + files = natsorted(files) + + return files + + +def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: + """ + Load a Bert-VITS2 style filelist. + """ + + files = set() + results = [] + count_duplicated, count_not_found = 0, 0 + + LANGUAGE_TO_LANGUAGES = { + "zh": ["zh", "en"], + "jp": ["jp", "en"], + "en": ["en"], + } + + with open(path, "r", encoding="utf-8") as f: + for line in f.readlines(): + splits = line.strip().split("|", maxsplit=3) + if len(splits) != 4: + logger.warning(f"Invalid line: {line}") + continue + + filename, speaker, language, text = splits + file = Path(filename) + language = language.strip().lower() + + if language == "ja": + language = "jp" + + assert language in ["zh", "jp", "en"], f"Invalid language {language}" + languages = LANGUAGE_TO_LANGUAGES[language] + + if file in files: + logger.warning(f"Duplicated file: {file}") + count_duplicated += 1 + continue + + if not file.exists(): + logger.warning(f"File not found: {file}") + count_not_found += 1 + continue + + results.append((file, speaker, languages, text)) + + if count_duplicated > 0: + logger.warning(f"Total duplicated files: {count_duplicated}") + + if count_not_found > 0: + logger.warning(f"Total files not found: {count_not_found}") + + return results diff --git a/tools/merge_asr_files.py b/tools/merge_asr_files.py index d86d29a7..cc120620 100644 --- a/tools/merge_asr_files.py +++ b/tools/merge_asr_files.py @@ -4,7 +4,7 @@ from pydub import AudioSegment from tqdm import tqdm -from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files +from tools.file import AUDIO_EXTENSIONS, list_files def merge_and_delete_files(save_dir, original_files): diff --git a/tools/sensevoice/README.md b/tools/sensevoice/README.md new file mode 100644 index 00000000..9a2078aa --- /dev/null +++ b/tools/sensevoice/README.md @@ -0,0 +1,59 @@ +# FunASR Command Line Interface + +This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files. + +## Requirements + +- Python >= 3.10 +- PyTorch <= 2.3.1 +- ffmpeg, pydub, audio-separator[gpu]. + +## Installation + +Install the required packages: + +```bash +pip install -e .[stable] +``` + +Make sure you have `ffmpeg` installed and available in your `PATH`. 
+ +## Usage + +### Basic Usage + +To run the tool with default settings: + +```bash +python tools/sensevoice/fun_asr.py --audio-dir --save-dir +``` + +## Options + +| Option | Description | +| :-----------------------: | :---------------------------------------------------------------------------: | +| --audio-dir | Directory containing audio or video files. | +| --save-dir | Directory to save processed audio files. | +| --device | Device to use for processing. Options: cuda (default) or cpu. | +| --language | Language of the transcription. Default is auto. | +| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. | +| --punc | Enable punctuation prediction. | +| --denoise | Enable noise reduction (vocal separation). | + +## Example + +To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled: + +```bash +python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise +``` + +## Additional Notes + +- The tool supports `both audio and video files`. Videos will be converted to audio automatically. +- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks. +- The script will automatically create necessary directories in the `--save-dir`. + +## Troubleshooting + +If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency. diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py new file mode 100644 index 00000000..dd2e1866 --- /dev/null +++ b/tools/sensevoice/auto_model.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +import copy +import json +import logging +import os.path +import random +import re +import string +import time + +import numpy as np +import torch +from funasr.download.download_model_from_hub import download_model +from funasr.download.file import download_from_url +from funasr.register import tables +from funasr.train_utils.load_pretrained_model import load_pretrained_model +from funasr.train_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import export_utils, misc +from funasr.utils.load_utils import load_audio_text_image_video, load_bytes +from funasr.utils.misc import deep_update +from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en +from tqdm import tqdm + +from .vad_utils import merge_vad, slice_padding_audio_samples + +try: + from funasr.models.campplus.cluster_backend import ClusterBackend + from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk +except: + pass + + +def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): + """ """ + data_list = [] + key_list = [] + filelist = [".scp", ".txt", ".json", ".jsonl", ".text"] + + chars = string.ascii_letters + string.digits + if isinstance(data_in, str): + if data_in.startswith("http://") or data_in.startswith("https://"): # url + data_in = download_from_url(data_in) + + if isinstance(data_in, str) and os.path.exists( + data_in + ): # wav_path; filelist: wav.scp, file.jsonl;text.txt; + _, file_extension = os.path.splitext(data_in) + file_extension = file_extension.lower() + if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt; + with open(data_in, encoding="utf-8") as fin: + for line in fin: + key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + if data_in.endswith( + ".jsonl" + ): # file.jsonl: json.dumps({"source": data}) + lines = json.loads(line.strip()) + data = lines["source"] + key = data["key"] if "key" in data else key + else: # filelist, wav.scp, text.txt: id \t data or data + lines = line.strip().split(maxsplit=1) + data = lines[1] if len(lines) > 1 else lines[0] + key = lines[0] if len(lines) > 1 else key + + data_list.append(data) + key_list.append(key) + else: + if key is None: + # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + key = misc.extract_filename_without_extension(data_in) + data_list = [data_in] + key_list = [key] + elif isinstance(data_in, (list, tuple)): + if data_type is not None and isinstance( + data_type, (list, tuple) + ): # mutiple inputs + data_list_tmp = [] + for data_in_i, data_type_i in zip(data_in, data_type): + key_list, data_list_i = prepare_data_iterator( + data_in=data_in_i, data_type=data_type_i + ) + data_list_tmp.append(data_list_i) + data_list = [] + for item in zip(*data_list_tmp): + data_list.append(item) + else: + # [audio sample point, fbank, text] + data_list = data_in + key_list = [] + for data_i in data_in: + if isinstance(data_i, str) and os.path.exists(data_i): + key = misc.extract_filename_without_extension(data_i) + else: + if key is None: + key = "rand_key_" + "".join( + random.choice(chars) for _ in range(13) + ) + key_list.append(key) + + else: # raw text; audio sample point, fbank; bytes + if isinstance(data_in, bytes): # audio bytes + data_in = load_bytes(data_in) + if key is None: + key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + data_list = [data_in] + key_list = [key] + + return key_list, data_list + + +class AutoModel: + + def 
__init__(self, **kwargs): + + try: + from funasr.utils.version_checker import check_for_update + + print( + "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel" + ) + check_for_update(disable=kwargs.get("disable_update", False)) + except: + pass + + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + logging.basicConfig(level=log_level) + + model, kwargs = self.build_model(**kwargs) + + # if vad_model is not None, build vad model else None + vad_model = kwargs.get("vad_model", None) + vad_kwargs = ( + {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {}) + ) + if vad_model is not None: + logging.info("Building VAD model.") + vad_kwargs["model"] = vad_model + vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master") + vad_kwargs["device"] = kwargs["device"] + vad_model, vad_kwargs = self.build_model(**vad_kwargs) + + # if punc_model is not None, build punc model else None + punc_model = kwargs.get("punc_model", None) + punc_kwargs = ( + {} + if kwargs.get("punc_kwargs", {}) is None + else kwargs.get("punc_kwargs", {}) + ) + if punc_model is not None: + logging.info("Building punc model.") + punc_kwargs["model"] = punc_model + punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master") + punc_kwargs["device"] = kwargs["device"] + punc_model, punc_kwargs = self.build_model(**punc_kwargs) + + # if spk_model is not None, build spk model else None + spk_model = kwargs.get("spk_model", None) + spk_kwargs = ( + {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) + ) + if spk_model is not None: + logging.info("Building SPK model.") + spk_kwargs["model"] = spk_model + spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") + spk_kwargs["device"] = kwargs["device"] + spk_model, spk_kwargs = self.build_model(**spk_kwargs) + self.cb_model = ClusterBackend().to(kwargs["device"]) + spk_mode = kwargs.get("spk_mode", "punc_segment") + if spk_mode not in ["default", "vad_segment", "punc_segment"]: + logging.error( + "spk_mode should be one of default, vad_segment and punc_segment." 
+ ) + self.spk_mode = spk_mode + + self.kwargs = kwargs + self.model = model + self.vad_model = vad_model + self.vad_kwargs = vad_kwargs + self.punc_model = punc_model + self.punc_kwargs = punc_kwargs + self.spk_model = spk_model + self.spk_kwargs = spk_kwargs + self.model_path = kwargs.get("model_path") + + @staticmethod + def build_model(**kwargs): + assert "model" in kwargs + if "model_conf" not in kwargs: + logging.info( + "download models from model hub: {}".format(kwargs.get("hub", "ms")) + ) + kwargs = download_model(**kwargs) + + set_all_random_seed(kwargs.get("seed", 0)) + + device = kwargs.get("device", "cuda") + if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: + device = "cpu" + kwargs["batch_size"] = 1 + kwargs["device"] = device + + torch.set_num_threads(kwargs.get("ncpu", 4)) + + # build tokenizer + tokenizer = kwargs.get("tokenizer", None) + if tokenizer is not None: + tokenizer_class = tables.tokenizer_classes.get(tokenizer) + tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {})) + kwargs["token_list"] = ( + tokenizer.token_list if hasattr(tokenizer, "token_list") else None + ) + kwargs["token_list"] = ( + tokenizer.get_vocab() + if hasattr(tokenizer, "get_vocab") + else kwargs["token_list"] + ) + vocab_size = ( + len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + ) + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() + else: + vocab_size = -1 + kwargs["tokenizer"] = tokenizer + + # build frontend + frontend = kwargs.get("frontend", None) + kwargs["input_size"] = None + if frontend is not None: + frontend_class = tables.frontend_classes.get(frontend) + frontend = frontend_class(**kwargs.get("frontend_conf", {})) + kwargs["input_size"] = ( + frontend.output_size() if hasattr(frontend, "output_size") else None + ) + kwargs["frontend"] = frontend + # build model + model_class = tables.model_classes.get(kwargs["model"]) + assert model_class is not None, f'{kwargs["model"]} is not registered' + model_conf = {} + deep_update(model_conf, kwargs.get("model_conf", {})) + deep_update(model_conf, kwargs) + model = model_class(**model_conf, vocab_size=vocab_size) + + # init_param + init_param = kwargs.get("init_param", None) + if init_param is not None: + if os.path.exists(init_param): + logging.info(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=model, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + else: + print(f"error, init_param does not exist!: {init_param}") + + # fp16 + if kwargs.get("fp16", False): + model.to(torch.float16) + elif kwargs.get("bf16", False): + model.to(torch.bfloat16) + model.to(device) + + if not kwargs.get("disable_log", True): + tables.print() + + return model, kwargs + + def __call__(self, *args, **cfg): + kwargs = self.kwargs + deep_update(kwargs, cfg) + res = self.model(*args, kwargs) + return res + + def generate(self, input, input_len=None, **cfg): + if self.vad_model is None: + return self.inference(input, input_len=input_len, **cfg) + + else: + return self.inference_with_vad(input, input_len=input_len, **cfg) + + def inference( + self, input, input_len=None, model=None, kwargs=None, key=None, **cfg + ): + kwargs = self.kwargs if kwargs is None else kwargs + if "cache" in kwargs: + kwargs.pop("cache") + deep_update(kwargs, cfg) + model = self.model if 
model is None else model + model.eval() + + batch_size = kwargs.get("batch_size", 1) + # if kwargs.get("device", "cpu") == "cpu": + # batch_size = 1 + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key + ) + + speed_stats = {} + asr_result_list = [] + num_samples = len(data_list) + disable_pbar = self.kwargs.get("disable_pbar", False) + pbar = ( + tqdm(colour="blue", total=num_samples, dynamic_ncols=True) + if not disable_pbar + else None + ) + time_speech_total = 0.0 + time_escape_total = 0.0 + for beg_idx in range(0, num_samples, batch_size): + end_idx = min(num_samples, beg_idx + batch_size) + data_batch = data_list[beg_idx:end_idx] + key_batch = key_list[beg_idx:end_idx] + batch = {"data_in": data_batch, "key": key_batch} + + if (end_idx - beg_idx) == 1 and kwargs.get( + "data_type", None + ) == "fbank": # fbank + batch["data_in"] = data_batch[0] + batch["data_lengths"] = input_len + + time1 = time.perf_counter() + with torch.no_grad(): + res = model.inference(**batch, **kwargs) + if isinstance(res, (list, tuple)): + results = res[0] if len(res) > 0 else [{"text": ""}] + meta_data = res[1] if len(res) > 1 else {} + time2 = time.perf_counter() + + asr_result_list.extend(results) + + # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item() + batch_data_time = meta_data.get("batch_data_time", -1) + time_escape = time2 - time1 + speed_stats["load_data"] = meta_data.get("load_data", 0.0) + speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0) + speed_stats["forward"] = f"{time_escape:0.3f}" + speed_stats["batch_size"] = f"{len(results)}" + speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}" + description = f"{speed_stats}, " + if pbar: + pbar.update(end_idx - beg_idx) + pbar.set_description(description) + time_speech_total += batch_data_time + time_escape_total += time_escape + + if pbar: + # pbar.update(1) + pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") + torch.cuda.empty_cache() + return asr_result_list + + def vad(self, input, input_len=None, **cfg): + kwargs = self.kwargs + # step.1: compute the vad model + deep_update(self.vad_kwargs, cfg) + beg_vad = time.time() + res = self.inference( + input, + input_len=input_len, + model=self.vad_model, + kwargs=self.vad_kwargs, + **cfg, + ) + end_vad = time.time() + # FIX(gcf): concat the vad clips for sense vocie model for better aed + if cfg.get("merge_vad", False): + for i in range(len(res)): + res[i]["value"] = merge_vad( + res[i]["value"], kwargs.get("merge_length_s", 15) * 1000 + ) + elapsed = end_vad - beg_vad + return elapsed, res + + def inference_with_vadres(self, input, vad_res, input_len=None, **cfg): + + kwargs = self.kwargs + + # step.2 compute asr model + model = self.model + deep_update(kwargs, cfg) + batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1) + batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000 + kwargs["batch_size"] = batch_size + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None) + ) + results_ret_list = [] + time_speech_total_all_samples = 1e-6 + + beg_total = time.time() + pbar_total = ( + tqdm(colour="red", total=len(vad_res), dynamic_ncols=True) + if not kwargs.get("disable_pbar", False) + else None + ) + + for i in range(len(vad_res)): + key = vad_res[i]["key"] + vadsegments = vad_res[i]["value"] + input_i = data_list[i] + fs = kwargs["frontend"].fs if 
hasattr(kwargs["frontend"], "fs") else 16000 + speech = load_audio_text_image_video( + input_i, fs=fs, audio_fs=kwargs.get("fs", 16000) + ) + speech_lengths = len(speech) + n = len(vadsegments) + data_with_index = [(vadsegments[i], i) for i in range(n)] + sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) + results_sorted = [] + + if not len(sorted_data): + results_ret_list.append({"key": key, "text": "", "timestamp": []}) + logging.info("decoding, utt: {}, empty speech".format(key)) + continue + + if len(sorted_data) > 0 and len(sorted_data[0]) > 0: + batch_size = max( + batch_size, sorted_data[0][0][1] - sorted_data[0][0][0] + ) + + if kwargs["device"] == "cpu": + batch_size = 0 + + beg_idx = 0 + beg_asr_total = time.time() + time_speech_total_per_sample = speech_lengths / 16000 + time_speech_total_all_samples += time_speech_total_per_sample + + # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True) + + all_segments = [] + max_len_in_batch = 0 + end_idx = 1 + + for j, _ in enumerate(range(0, n)): + # pbar_sample.update(1) + sample_length = sorted_data[j][0][1] - sorted_data[j][0][0] + potential_batch_length = max(max_len_in_batch, sample_length) * ( + j + 1 - beg_idx + ) + # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0] + if ( + j < n - 1 + and sample_length < batch_size_threshold_ms + and potential_batch_length < batch_size + ): + max_len_in_batch = max(max_len_in_batch, sample_length) + end_idx += 1 + continue + + speech_j, speech_lengths_j, intervals = slice_padding_audio_samples( + speech, speech_lengths, sorted_data[beg_idx:end_idx] + ) + results = self.inference( + speech_j, input_len=None, model=model, kwargs=kwargs, **cfg + ) + + for _b in range(len(speech_j)): + results[_b]["interval"] = intervals[_b] + + if self.spk_model is not None: + # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] + for _b in range(len(speech_j)): + vad_segments = [ + [ + sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0, + sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0, + np.array(speech_j[_b]), + ] + ] + segments = sv_chunk(vad_segments) + all_segments.extend(segments) + speech_b = [i[2] for i in segments] + spk_res = self.inference( + speech_b, + input_len=None, + model=self.spk_model, + kwargs=kwargs, + **cfg, + ) + results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"] + + beg_idx = end_idx + end_idx += 1 + max_len_in_batch = sample_length + if len(results) < 1: + continue + results_sorted.extend(results) + + # end_asr_total = time.time() + # time_escape_total_per_sample = end_asr_total - beg_asr_total + # pbar_sample.update(1) + # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, " + # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}") + + restored_data = [0] * n + for j in range(n): + index = sorted_data[j][1] + cur = results_sorted[j] + pattern = r"<\|([^|]+)\|>" + emotion_string = re.findall(pattern, cur["text"]) + cur["text"] = re.sub(pattern, "", cur["text"]) + cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string]) + if self.punc_model is not None and len(cur["text"].strip()) > 0: + deep_update(self.punc_kwargs, cfg) + punc_res = self.inference( + cur["text"], + model=self.punc_model, + kwargs=self.punc_kwargs, + **cfg, + ) + cur["text"] = punc_res[0]["text"] + + restored_data[index] = cur + + end_asr_total = time.time() + time_escape_total_per_sample = 
end_asr_total - beg_asr_total + if pbar_total: + pbar_total.update(1) + pbar_total.set_description( + f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + f"time_speech: {time_speech_total_per_sample: 0.3f}, " + f"time_escape: {time_escape_total_per_sample:0.3f}" + ) + + # end_total = time.time() + # time_escape_total_all_samples = end_total - beg_total + # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, " + # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, " + # f"time_escape_all: {time_escape_total_all_samples:0.3f}") + return restored_data + + def export(self, input=None, **cfg): + """ + + :param input: + :param type: + :param quantize: + :param fallback_num: + :param calib_num: + :param opset_version: + :param cfg: + :return: + """ + + device = cfg.get("device", "cpu") + model = self.model.to(device=device) + kwargs = self.kwargs + deep_update(kwargs, cfg) + kwargs["device"] = device + del kwargs["model"] + model.eval() + + type = kwargs.get("type", "onnx") + + key_list, data_list = prepare_data_iterator( + input, input_len=None, data_type=kwargs.get("data_type", None), key=None + ) + + with torch.no_grad(): + export_dir = export_utils.export(model=model, data_in=data_list, **kwargs) + + return export_dir diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py new file mode 100644 index 00000000..02c15a59 --- /dev/null +++ b/tools/sensevoice/fun_asr.py @@ -0,0 +1,332 @@ +import gc +import os +import re + +from audio_separator.separator import Separator + +os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr" +os.environ["UVR5_CACHE"] = "./.cache/uvr5-models" +import json +import subprocess +from pathlib import Path + +import click +import torch +from loguru import logger +from pydub import AudioSegment +from silero_vad import get_speech_timestamps, load_silero_vad, read_audio +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files +from tools.sensevoice.auto_model import AutoModel + + +def uvr5_cli( + audio_dir: Path, + output_folder: Path, + audio_files: list[Path] | None = None, + output_format: str = "flac", + model: str = "BS-Roformer-Viperx-1296.ckpt", +): + # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"] + sepr = Separator( + model_file_dir=os.environ["UVR5_CACHE"], + output_dir=output_folder, + output_format=output_format, + ) + dictmodel = { + "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt", + "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt", + "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt", + "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt", + } + roformer_model = dictmodel[model] + sepr.load_model(roformer_model) + if audio_files is None: + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + total_files = len(audio_files) + + print(f"{total_files} audio files found") + + res = [] + for audio in tqdm(audio_files, desc="Denoising: "): + file_path = str(audio_dir / audio) + sep_out = sepr.separate(file_path) + if isinstance(sep_out, str): + res.append(sep_out) + elif isinstance(sep_out, list): + res.extend(sep_out) + del sepr + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return res, roformer_model + + +def get_sample_rate(media_path: Path): + result = subprocess.run( + [ + 
"ffprobe", + "-v", + "quiet", + "-print_format", + "json", + "-show_streams", + str(media_path), + ], + capture_output=True, + text=True, + check=True, + ) + media_info = json.loads(result.stdout) + for stream in media_info.get("streams", []): + if stream.get("codec_type") == "audio": + return stream.get("sample_rate") + return "44100" # Default sample rate if not found + + +def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"): + sr = get_sample_rate(src_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + if src_path.resolve() == out_path.resolve(): + output = str(out_path.with_stem(out_path.stem + f"_{sr}")) + else: + output = str(out_path) + subprocess.run( + [ + "ffmpeg", + "-loglevel", + "error", + "-i", + str(src_path), + "-acodec", + "pcm_s16le" if out_fmt == "wav" else "flac", + "-ar", + sr, + "-ac", + "1", + "-y", + output, + ], + check=True, + ) + return out_path + + +def convert_video_to_audio(video_path: Path, audio_dir: Path): + cur_dir = audio_dir / video_path.relative_to(audio_dir).parent + vocals = [ + p + for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*") + if p.suffix in AUDIO_EXTENSIONS + ] + if len(vocals) > 0: + return vocals[0] + audio_path = cur_dir / f"{video_path.stem}.wav" + convert_to_mono(video_path, audio_path) + return audio_path + + +@click.command() +@click.option("--audio-dir", required=True, help="Directory containing audio files") +@click.option( + "--save-dir", required=True, help="Directory to save processed audio files" +) +@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") +@click.option("--language", default="auto", help="Language of the transcription") +@click.option( + "--max_single_segment_time", + default=20000, + type=int, + help="Maximum of Output single audio duration(ms)", +) +@click.option("--fsmn-vad/--silero-vad", default=False) +@click.option("--punc/--no-punc", default=False) +@click.option("--denoise/--no-denoise", default=False) +@click.option("--save_emo/--no_save_emo", default=False) +def main( + audio_dir: str, + save_dir: str, + device: str, + language: str, + max_single_segment_time: int, + fsmn_vad: bool, + punc: bool, + denoise: bool, + save_emo: bool, +): + + audios_path = Path(audio_dir) + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + video_files = list_files( + path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True + ) + v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files] + + if denoise: + VOCAL = "_(Vocals)" + original_files = [ + p + for p in audios_path.glob("**/*") + if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem + ] + + _, cur_model = uvr5_cli( + audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files + ) + need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")] + need_remove.extend(original_files) + for _ in need_remove: + _.unlink() + vocal_files = [ + p + for p in audios_path.glob("**/*") + if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem + ] + for f in vocal_files: + fn, ext = f.stem, f.suffix + + v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0]) + if v_pos != -1: + new_fn = fn[: v_pos + len(VOCAL)] + new_f = f.with_name(new_fn + ext) + f = f.rename(new_f) + convert_to_mono(f, f, "flac") + f.unlink() + + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + + logger.info("Loading / Downloading Funasr model...") + + model_dir = "iic/SenseVoiceSmall" + + vad_model = "fsmn-vad" if fsmn_vad else None + vad_kwargs = 
{"max_single_segment_time": max_single_segment_time} + punc_model = "ct-punc" if punc else None + + manager = AutoModel( + model=model_dir, + trust_remote_code=False, + vad_model=vad_model, + vad_kwargs=vad_kwargs, + punc_model=punc_model, + device=device, + ) + + if not fsmn_vad and vad_model is None: + vad_model = load_silero_vad() + + logger.info("Model loaded.") + + pattern = re.compile(r"_\d{3}\.") + + for file_path in tqdm(audio_files, desc="Processing audio file"): + + if pattern.search(file_path.name): + # logger.info(f"Skipping {file_path} as it has already been processed.") + continue + + file_stem = file_path.stem + file_suffix = file_path.suffix + + rel_path = Path(file_path).relative_to(audio_dir) + (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) + + audio = AudioSegment.from_file(file_path) + + cfg = dict( + cache={}, + language=language, # "zh", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + batch_size_s=60, + ) + + if fsmn_vad: + elapsed, vad_res = manager.vad(input=str(file_path), **cfg) + else: + wav = read_audio( + str(file_path) + ) # backend (sox, soundfile, or ffmpeg) required! + audio_key = file_path.stem + audio_val = [] + speech_timestamps = get_speech_timestamps( + wav, + vad_model, + max_speech_duration_s=max_single_segment_time // 1000, + return_seconds=True, + ) + + audio_val = [ + [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)] + for timestamp in speech_timestamps + ] + vad_res = [] + vad_res.append(dict(key=audio_key, value=audio_val)) + + res = manager.inference_with_vadres( + input=str(file_path), vad_res=vad_res, **cfg + ) + + for i, info in enumerate(res): + [start_ms, end_ms] = info["interval"] + text = info["text"] + emo = info["emo"] + sliced_audio = audio[start_ms:end_ms] + audio_save_path = ( + save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}" + ) + sliced_audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}: {text}") + + transcript_save_path = ( + save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab" + ) + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(text) + + if save_emo: + emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo" + with open( + emo_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(emo) + + if audios_path.resolve() == save_path.resolve(): + file_path.unlink() + + +if __name__ == "__main__": + main() + exit(0) + from funasr.utils.postprocess_utils import rich_transcription_postprocess + + # Load the audio file + audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav") + model_dir = "iic/SenseVoiceSmall" + m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") + m.eval() + + res = m.inference( + data_in=f"{kwargs['model_path']}/example/zh.mp3", + language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + ban_emo_unk=False, + **kwargs, + ) + + print(res) + text = rich_transcription_postprocess(res[0][0]["text"]) + print(text) diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py new file mode 100644 index 00000000..3bef75ed --- /dev/null +++ b/tools/sensevoice/vad_utils.py @@ -0,0 +1,61 @@ +import torch +from torch.nn.utils.rnn import pad_sequence + + +def slice_padding_fbank(speech, speech_lengths, vad_segments): + speech_list = [] + speech_lengths_list = [] + for i, segment in enumerate(vad_segments): + + bed_idx = int(segment[0][0] * 16) + end_idx = min(int(segment[0][1] * 16), speech_lengths[0]) + speech_i 
= speech[0, bed_idx:end_idx] + speech_lengths_i = end_idx - bed_idx + speech_list.append(speech_i) + speech_lengths_list.append(speech_lengths_i) + feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0) + speech_lengths_pad = torch.Tensor(speech_lengths_list).int() + return feats_pad, speech_lengths_pad + + +def slice_padding_audio_samples(speech, speech_lengths, vad_segments): + speech_list = [] + speech_lengths_list = [] + intervals = [] + for i, segment in enumerate(vad_segments): + bed_idx = int(segment[0][0] * 16) + end_idx = min(int(segment[0][1] * 16), speech_lengths) + speech_i = speech[bed_idx:end_idx] + speech_lengths_i = end_idx - bed_idx + speech_list.append(speech_i) + speech_lengths_list.append(speech_lengths_i) + intervals.append([bed_idx // 16, end_idx // 16]) + + return speech_list, speech_lengths_list, intervals + + +def merge_vad(vad_result, max_length=15000, min_length=0): + new_result = [] + if len(vad_result) <= 1: + return vad_result + time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result] + time_step = sorted(list(set(time_step))) + if len(time_step) == 0: + return [] + bg = 0 + for i in range(len(time_step) - 1): + time = time_step[i] + if time_step[i + 1] - bg < max_length: + continue + if time - bg > min_length: + new_result.append([bg, time]) + # if time - bg < max_length * 1.5: + # new_result.append([bg, time]) + # else: + # split_num = int(time - bg) // max_length + 1 + # spl_l = int(time - bg) // split_num + # for j in range(split_num): + # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l]) + bg = time + new_result.append([bg, time_step[-1]]) + return new_result diff --git a/tools/smart_pad.py b/tools/smart_pad.py index 7fb55d9a..9772168f 100644 --- a/tools/smart_pad.py +++ b/tools/smart_pad.py @@ -8,7 +8,7 @@ import torchaudio from tqdm import tqdm -from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files +from tools.file import AUDIO_EXTENSIONS, list_files threshold = 10 ** (-50 / 20.0) diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py index 977afdf3..d24a5f39 100644 --- a/tools/vqgan/create_train_split.py +++ b/tools/vqgan/create_train_split.py @@ -7,7 +7,7 @@ from pydub import AudioSegment from tqdm import tqdm -from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist +from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist @click.command() diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py index fc84f7e6..bc6bc408 100644 --- a/tools/vqgan/extract_vq.py +++ b/tools/vqgan/extract_vq.py @@ -17,7 +17,7 @@ from loguru import logger from omegaconf import OmegaConf -from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist +from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist # register eval resolver OmegaConf.register_new_resolver("eval", eval) diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py index 6070e735..42e7de8a 100644 --- a/tools/whisper_asr.py +++ b/tools/whisper_asr.py @@ -32,7 +32,7 @@ from pydub import AudioSegment from tqdm import tqdm -from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files +from tools.file import AUDIO_EXTENSIONS, list_files @click.command() @@ -83,8 +83,6 @@ def main( path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True ) - numbered_suffix_pattern = re.compile(r"-\d{3}$") - for file_path in tqdm(audio_files, desc="Processing audio file"): file_stem = file_path.stem file_suffix = file_path.suffix @@ -92,18 +90,6 @@ def main( rel_path = 
Path(file_path).relative_to(audio_dir) (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) - # Skip files that already have a .lab file or a -{3-digit number} suffix - numbered_suffix = numbered_suffix_pattern.search(file_stem) - lab_file = file_path.with_suffix(".lab") - - if numbered_suffix and lab_file.exists(): - continue - - if not numbered_suffix and lab_file.with_stem(lab_file.stem + "-001").exists(): - if file_path.exists(): - file_path.unlink() - continue - audio = AudioSegment.from_file(file_path) segments, info = model.transcribe( @@ -119,6 +105,7 @@ def main( ) print("Total len(ms): ", len(audio)) + whole_text = None for segment in segments: id, start, end, text = ( segment.id, @@ -127,26 +114,24 @@ def main( segment.text, ) print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text)) - start_ms = int(start * 1000) - end_ms = int(end * 1000) + 200 # add 0.2s avoid truncating - segment_audio = audio[start_ms:end_ms] - audio_save_path = ( - save_path / rel_path.parent / f"{file_stem}-{id:03d}{file_suffix}" - ) - segment_audio.export(audio_save_path, format=file_suffix[1:]) - print(f"Exported {audio_save_path}") - - transcript_save_path = ( - save_path / rel_path.parent / f"{file_stem}-{id:03d}.lab" - ) - with open( - transcript_save_path, - "w", - encoding="utf-8", - ) as f: - f.write(segment.text) - - file_path.unlink() + if not whole_text: + whole_text = text + else: + whole_text += ", " + text + + whole_text += "." + + audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}" + audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}") + + transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab" + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(whole_text) if __name__ == "__main__":
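
For reference, here is a minimal, self-contained sketch of the `--silero-vad` path that the new `tools/sensevoice/fun_asr.py` script follows: it builds the `vad_res` list of `{"key": ..., "value": [[start_ms, end_ms], ...]}` entries that `AutoModel.inference_with_vadres()` consumes. It assumes the `silero-vad` dependency added to `pyproject.toml` in this diff is installed; `data/demo.wav` is a hypothetical input file used only for illustration.

```python
from pathlib import Path

from silero_vad import get_speech_timestamps, load_silero_vad, read_audio

# Hypothetical input file; replace with any audio clip on disk.
audio_path = Path("data/demo.wav")

vad_model = load_silero_vad()
wav = read_audio(str(audio_path))  # requires a torchaudio backend (sox, soundfile, or ffmpeg)

# Segment into speech chunks of at most 20 s, matching the script's default
# --max_single_segment_time of 20000 ms.
speech_timestamps = get_speech_timestamps(
    wav,
    vad_model,
    max_speech_duration_s=20,
    return_seconds=True,
)

# Convert the second-based timestamps into the millisecond intervals that
# inference_with_vadres() expects.
vad_res = [
    {
        "key": audio_path.stem,
        "value": [
            [int(ts["start"] * 1000), int(ts["end"] * 1000)]
            for ts in speech_timestamps
        ],
    }
]
print(vad_res)
```

The resulting `vad_res` can then be passed to `AutoModel.inference_with_vadres(input=str(audio_path), vad_res=vad_res, ...)` together with the SenseVoice model configuration shown in `fun_asr.py`.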