diff --git a/docs/zh/index.md b/docs/zh/index.md
index 2f1121f0..11324bd3 100644
--- a/docs/zh/index.md
+++ b/docs/zh/index.md
@@ -13,8 +13,8 @@
!!! warning
- 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
- 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
+ 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+ 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
@@ -33,23 +33,23 @@ Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法
1. 解压项目压缩包。
2. 点击 `install_env.bat` 安装环境。
- - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
- - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。
- - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。
- - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。
+ - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
+ - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。
+ - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。
+ - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。
3. 若第 2 步 `INSTALL_TYPE=preview` 则执行这一步(可跳过,此步为激活编译模型环境)
- 1. 使用如下链接下载 LLVM 编译器。
- - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
- - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
- - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
- - 确认安装完成。
- 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
- - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
- 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
- - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
- - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
- - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
- 4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+ 1. 使用如下链接下载 LLVM 编译器。
+ - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
+ - 确认安装完成。
+ 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
+ - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
+ - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
+ - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
+ 4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`.
!!! info "可选"
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
index 4047aa53..78c82640 100644
--- a/fish_speech/utils/file.py
+++ b/fish_speech/utils/file.py
@@ -1,55 +1,5 @@
import os
-from glob import glob
from pathlib import Path
-from typing import Union
-
-from loguru import logger
-from natsort import natsorted
-
-AUDIO_EXTENSIONS = {
- ".mp3",
- ".wav",
- ".flac",
- ".ogg",
- ".m4a",
- ".wma",
- ".aac",
- ".aiff",
- ".aif",
- ".aifc",
-}
-
-
-def list_files(
- path: Union[Path, str],
- extensions: set[str] = None,
- recursive: bool = False,
- sort: bool = True,
-) -> list[Path]:
- """List files in a directory.
-
- Args:
- path (Path): Path to the directory.
- extensions (set, optional): Extensions to filter. Defaults to None.
- recursive (bool, optional): Whether to search recursively. Defaults to False.
- sort (bool, optional): Whether to sort the files. Defaults to True.
-
- Returns:
- list: List of files.
- """
-
- if isinstance(path, str):
- path = Path(path)
-
- if not path.exists():
- raise FileNotFoundError(f"Directory {path} does not exist.")
-
- files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
-
- if sort:
- files = natsorted(files)
-
- return files
def get_latest_checkpoint(path: Path | str) -> Path | None:
@@ -64,56 +14,3 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
return None
return ckpts[-1]
-
-
-def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
- """
- Load a Bert-VITS2 style filelist.
- """
-
- files = set()
- results = []
- count_duplicated, count_not_found = 0, 0
-
- LANGUAGE_TO_LANGUAGES = {
- "zh": ["zh", "en"],
- "jp": ["jp", "en"],
- "en": ["en"],
- }
-
- with open(path, "r", encoding="utf-8") as f:
- for line in f.readlines():
- splits = line.strip().split("|", maxsplit=3)
- if len(splits) != 4:
- logger.warning(f"Invalid line: {line}")
- continue
-
- filename, speaker, language, text = splits
- file = Path(filename)
- language = language.strip().lower()
-
- if language == "ja":
- language = "jp"
-
- assert language in ["zh", "jp", "en"], f"Invalid language {language}"
- languages = LANGUAGE_TO_LANGUAGES[language]
-
- if file in files:
- logger.warning(f"Duplicated file: {file}")
- count_duplicated += 1
- continue
-
- if not file.exists():
- logger.warning(f"File not found: {file}")
- count_not_found += 1
- continue
-
- results.append((file, speaker, languages, text))
-
- if count_duplicated > 0:
- logger.warning(f"Total duplicated files: {count_duplicated}")
-
- if count_not_found > 0:
- logger.warning(f"Total files not found: {count_not_found}")
-
- return results
diff --git a/pyproject.toml b/pyproject.toml
index 80f54d3a..83518957 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,9 +38,11 @@ dependencies = [
"zstandard>=0.22.0",
"pydub",
"faster_whisper",
- "modelscope==1.16.1",
- "funasr==1.1.2",
+ "modelscope==1.17.1",
+ "funasr==1.1.5",
"opencc-python-reimplemented==0.1.7",
+ "audio-seperator[gpu]==0.18.3",
+ "silero-vad",
]
[project.optional-dependencies]
diff --git a/run_cmd.bat b/run_cmd.bat
index 05fda82d..c2af8a9b 100644
--- a/run_cmd.bat
+++ b/run_cmd.bat
@@ -29,7 +29,7 @@ set INSTALL_ENV_DIR=%cd%\fishenv\env
set PYTHONNOUSERSITE=1
-set PYTHONPATH=
+set PYTHONPATH=%~dp0
set PYTHONHOME=
diff --git a/tools/api.py b/tools/api.py
index 9b6f3d2e..05b31338 100644
--- a/tools/api.py
+++ b/tools/api.py
@@ -3,6 +3,7 @@
import json
import queue
import random
+import sys
import traceback
import wave
from argparse import ArgumentParser
@@ -10,11 +11,11 @@
from pathlib import Path
from typing import Annotated, Literal, Optional
-import librosa
import numpy as np
import pyrootutils
import soundfile as sf
import torch
+import torchaudio
from kui.asgi import (
Body,
HTTPException,
@@ -87,7 +88,18 @@ def load_audio(reference_audio, sr):
except base64.binascii.Error:
raise ValueError("Invalid path or base64 string")
- audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
+ waveform, original_sr = torchaudio.load(
+ reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
+ )
+
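+    # Downmix multi-channel audio to mono, then resample to the requested rate.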
+ if waveform.shape[0] > 1:
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+ if original_sr != sr:
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
+ waveform = resampler(waveform)
+
+ audio = waveform.squeeze().numpy()
return audio
diff --git a/tools/file.py b/tools/file.py
new file mode 100644
index 00000000..b4b8051d
--- /dev/null
+++ b/tools/file.py
@@ -0,0 +1,108 @@
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+ ".mp3",
+ ".wav",
+ ".flac",
+ ".ogg",
+ ".m4a",
+ ".wma",
+ ".aac",
+ ".aiff",
+ ".aif",
+ ".aifc",
+}
+
+VIDEO_EXTENSIONS = {
+ ".mp4",
+ ".avi",
+}
+
+
+def list_files(
+ path: Union[Path, str],
+ extensions: set[str] = None,
+ recursive: bool = False,
+ sort: bool = True,
+) -> list[Path]:
+ """List files in a directory.
+
+ Args:
+ path (Path): Path to the directory.
+ extensions (set, optional): Extensions to filter. Defaults to None.
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
+ sort (bool, optional): Whether to sort the files. Defaults to True.
+
+ Returns:
+ list: List of files.
+ """
+
+ if isinstance(path, str):
+ path = Path(path)
+
+ if not path.exists():
+ raise FileNotFoundError(f"Directory {path} does not exist.")
+
+    # Honour the `recursive` flag: fall back to a non-recursive glob when it is False.
+    glob = path.rglob if recursive else path.glob
+    files = [file for ext in extensions for file in glob(f"*{ext}")]
+
+ if sort:
+ files = natsorted(files)
+
+ return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, list[str], str]]:
+    """
+    Load a Bert-VITS2 style filelist.
+
+    Each non-empty line is expected to look like:
+        <audio path>|<speaker>|<language>|<text>
+    where language is zh, jp (ja is normalised to jp), or en.
+    """
+
+ files = set()
+ results = []
+ count_duplicated, count_not_found = 0, 0
+
+ LANGUAGE_TO_LANGUAGES = {
+ "zh": ["zh", "en"],
+ "jp": ["jp", "en"],
+ "en": ["en"],
+ }
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ splits = line.strip().split("|", maxsplit=3)
+ if len(splits) != 4:
+ logger.warning(f"Invalid line: {line}")
+ continue
+
+ filename, speaker, language, text = splits
+ file = Path(filename)
+ language = language.strip().lower()
+
+ if language == "ja":
+ language = "jp"
+
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+ languages = LANGUAGE_TO_LANGUAGES[language]
+
+ if file in files:
+ logger.warning(f"Duplicated file: {file}")
+ count_duplicated += 1
+ continue
+
+ if not file.exists():
+ logger.warning(f"File not found: {file}")
+ count_not_found += 1
+ continue
+
+ results.append((file, speaker, languages, text))
+
+ if count_duplicated > 0:
+ logger.warning(f"Total duplicated files: {count_duplicated}")
+
+ if count_not_found > 0:
+ logger.warning(f"Total files not found: {count_not_found}")
+
+ return results
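+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (hypothetical directory name): collect audio files
+    # recursively and print them in natural-sort order.
+    demo_dir = Path("data/demo")
+    if demo_dir.exists():
+        files = list_files(demo_dir, extensions=AUDIO_EXTENSIONS, recursive=True)
+        for audio_file in files:
+            print(audio_file)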
diff --git a/tools/merge_asr_files.py b/tools/merge_asr_files.py
index d86d29a7..cc120620 100644
--- a/tools/merge_asr_files.py
+++ b/tools/merge_asr_files.py
@@ -4,7 +4,7 @@
from pydub import AudioSegment
from tqdm import tqdm
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from tools.file import AUDIO_EXTENSIONS, list_files
def merge_and_delete_files(save_dir, original_files):
diff --git a/tools/sensevoice/README.md b/tools/sensevoice/README.md
new file mode 100644
index 00000000..9a2078aa
--- /dev/null
+++ b/tools/sensevoice/README.md
@@ -0,0 +1,59 @@
+# FunASR Command Line Interface
+
+This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch <= 2.3.1
+- ffmpeg, pydub, audio-separator[gpu].
+
+## Installation
+
+Install the required packages:
+
+```bash
+pip install -e .[stable]
+```
+
+Make sure you have `ffmpeg` installed and available in your `PATH`.
+
+## Usage
+
+### Basic Usage
+
+To run the tool with default settings:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir <audio_dir> --save-dir <save_dir>
+```
+
+## Options
+
+| Option | Description |
+| :-----------------------: | :---------------------------------------------------------------------------: |
+| --audio-dir | Directory containing audio or video files. |
+| --save-dir | Directory to save processed audio files. |
+| --device | Device to use for processing. Options: cuda (default) or cpu. |
+| --language | Language of the transcription. Default is auto. |
+| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
+| --punc | Enable punctuation prediction. |
+|         --denoise          |                  Enable noise reduction (vocal separation).                   |
+| --fsmn-vad / --silero-vad  |        VAD backend used to segment long audio. Default is silero-vad.         |
+|         --save_emo          |            Save detected emotion tags alongside each transcript.             |
+
+## Example
+
+To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
+```
+
+## Additional Notes
+
+- The tool supports both audio and video files. Videos will be converted to audio automatically.
+- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
+- The script will automatically create necessary directories in the `--save-dir`.
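+
+Each exported audio slice is paired with a `.lab` transcript (and, with `--save_emo`, an `.emo` file with the detected emotion tags). A minimal sketch (hypothetical output path) for walking the results in Python:
+
+```python
+from pathlib import Path
+
+save_dir = Path("path/to/output")  # hypothetical: the --save-dir you used
+for lab in sorted(save_dir.rglob("*.lab")):
+    # Each .lab transcript sits next to an audio slice with the same stem.
+    print(lab.stem, "->", lab.read_text(encoding="utf-8"))
+```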
+
+## Troubleshooting
+
+If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py
new file mode 100644
index 00000000..dd2e1866
--- /dev/null
+++ b/tools/sensevoice/auto_model.py
@@ -0,0 +1,573 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+ from funasr.models.campplus.cluster_backend import ClusterBackend
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+ pass
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+ """ """
+ data_list = []
+ key_list = []
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
+
+ chars = string.ascii_letters + string.digits
+ if isinstance(data_in, str):
+ if data_in.startswith("http://") or data_in.startswith("https://"): # url
+ data_in = download_from_url(data_in)
+
+ if isinstance(data_in, str) and os.path.exists(
+ data_in
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+ _, file_extension = os.path.splitext(data_in)
+ file_extension = file_extension.lower()
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
+ with open(data_in, encoding="utf-8") as fin:
+ for line in fin:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ if data_in.endswith(
+ ".jsonl"
+ ): # file.jsonl: json.dumps({"source": data})
+ lines = json.loads(line.strip())
+ data = lines["source"]
+ key = data["key"] if "key" in data else key
+ else: # filelist, wav.scp, text.txt: id \t data or data
+ lines = line.strip().split(maxsplit=1)
+ data = lines[1] if len(lines) > 1 else lines[0]
+ key = lines[0] if len(lines) > 1 else key
+
+ data_list.append(data)
+ key_list.append(key)
+ else:
+ if key is None:
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ key = misc.extract_filename_without_extension(data_in)
+ data_list = [data_in]
+ key_list = [key]
+ elif isinstance(data_in, (list, tuple)):
+ if data_type is not None and isinstance(
+ data_type, (list, tuple)
+ ): # mutiple inputs
+ data_list_tmp = []
+ for data_in_i, data_type_i in zip(data_in, data_type):
+ key_list, data_list_i = prepare_data_iterator(
+ data_in=data_in_i, data_type=data_type_i
+ )
+ data_list_tmp.append(data_list_i)
+ data_list = []
+ for item in zip(*data_list_tmp):
+ data_list.append(item)
+ else:
+ # [audio sample point, fbank, text]
+ data_list = data_in
+ key_list = []
+ for data_i in data_in:
+ if isinstance(data_i, str) and os.path.exists(data_i):
+ key = misc.extract_filename_without_extension(data_i)
+ else:
+ if key is None:
+ key = "rand_key_" + "".join(
+ random.choice(chars) for _ in range(13)
+ )
+ key_list.append(key)
+
+ else: # raw text; audio sample point, fbank; bytes
+ if isinstance(data_in, bytes): # audio bytes
+ data_in = load_bytes(data_in)
+ if key is None:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ data_list = [data_in]
+ key_list = [key]
+
+ return key_list, data_list
+
+
+class AutoModel:
+
+ def __init__(self, **kwargs):
+
+ try:
+ from funasr.utils.version_checker import check_for_update
+
+ print(
+ "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
+ )
+ check_for_update(disable=kwargs.get("disable_update", False))
+ except:
+ pass
+
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+ logging.basicConfig(level=log_level)
+
+ model, kwargs = self.build_model(**kwargs)
+
+ # if vad_model is not None, build vad model else None
+ vad_model = kwargs.get("vad_model", None)
+ vad_kwargs = (
+ {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
+ )
+ if vad_model is not None:
+ logging.info("Building VAD model.")
+ vad_kwargs["model"] = vad_model
+ vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
+ vad_kwargs["device"] = kwargs["device"]
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+ # if punc_model is not None, build punc model else None
+ punc_model = kwargs.get("punc_model", None)
+ punc_kwargs = (
+ {}
+ if kwargs.get("punc_kwargs", {}) is None
+ else kwargs.get("punc_kwargs", {})
+ )
+ if punc_model is not None:
+ logging.info("Building punc model.")
+ punc_kwargs["model"] = punc_model
+ punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
+ punc_kwargs["device"] = kwargs["device"]
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+ # if spk_model is not None, build spk model else None
+ spk_model = kwargs.get("spk_model", None)
+ spk_kwargs = (
+ {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
+ )
+ if spk_model is not None:
+ logging.info("Building SPK model.")
+ spk_kwargs["model"] = spk_model
+ spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
+ spk_kwargs["device"] = kwargs["device"]
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+ self.cb_model = ClusterBackend().to(kwargs["device"])
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+ logging.error(
+ "spk_mode should be one of default, vad_segment and punc_segment."
+ )
+ self.spk_mode = spk_mode
+
+ self.kwargs = kwargs
+ self.model = model
+ self.vad_model = vad_model
+ self.vad_kwargs = vad_kwargs
+ self.punc_model = punc_model
+ self.punc_kwargs = punc_kwargs
+ self.spk_model = spk_model
+ self.spk_kwargs = spk_kwargs
+ self.model_path = kwargs.get("model_path")
+
+ @staticmethod
+ def build_model(**kwargs):
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info(
+ "download models from model hub: {}".format(kwargs.get("hub", "ms"))
+ )
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ torch.set_num_threads(kwargs.get("ncpu", 4))
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+ tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
+ kwargs["token_list"] = (
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ )
+ kwargs["token_list"] = (
+ tokenizer.get_vocab()
+ if hasattr(tokenizer, "get_vocab")
+ else kwargs["token_list"]
+ )
+ vocab_size = (
+ len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+ )
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
+ else:
+ vocab_size = -1
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend
+ frontend = kwargs.get("frontend", None)
+ kwargs["input_size"] = None
+ if frontend is not None:
+ frontend_class = tables.frontend_classes.get(frontend)
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
+ kwargs["input_size"] = (
+ frontend.output_size() if hasattr(frontend, "output_size") else None
+ )
+ kwargs["frontend"] = frontend
+ # build model
+ model_class = tables.model_classes.get(kwargs["model"])
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
+ model_conf = {}
+ deep_update(model_conf, kwargs.get("model_conf", {}))
+ deep_update(model_conf, kwargs)
+ model = model_class(**model_conf, vocab_size=vocab_size)
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ if os.path.exists(init_param):
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ path=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ scope_map=kwargs.get("scope_map", []),
+ excludes=kwargs.get("excludes", None),
+ )
+ else:
+ print(f"error, init_param does not exist!: {init_param}")
+
+ # fp16
+ if kwargs.get("fp16", False):
+ model.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ model.to(torch.bfloat16)
+ model.to(device)
+
+ if not kwargs.get("disable_log", True):
+ tables.print()
+
+ return model, kwargs
+
+ def __call__(self, *args, **cfg):
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+        res = self.model(*args, **kwargs)
+ return res
+
+ def generate(self, input, input_len=None, **cfg):
+ if self.vad_model is None:
+ return self.inference(input, input_len=input_len, **cfg)
+
+ else:
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+ def inference(
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+ ):
+ kwargs = self.kwargs if kwargs is None else kwargs
+ if "cache" in kwargs:
+ kwargs.pop("cache")
+ deep_update(kwargs, cfg)
+ model = self.model if model is None else model
+ model.eval()
+
+ batch_size = kwargs.get("batch_size", 1)
+ # if kwargs.get("device", "cpu") == "cpu":
+ # batch_size = 1
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+ )
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ disable_pbar = self.kwargs.get("disable_pbar", False)
+ pbar = (
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ if not disable_pbar
+ else None
+ )
+ time_speech_total = 0.0
+ time_escape_total = 0.0
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+
+ if (end_idx - beg_idx) == 1 and kwargs.get(
+ "data_type", None
+ ) == "fbank": # fbank
+ batch["data_in"] = data_batch[0]
+ batch["data_lengths"] = input_len
+
+ time1 = time.perf_counter()
+ with torch.no_grad():
+ res = model.inference(**batch, **kwargs)
+ if isinstance(res, (list, tuple)):
+ results = res[0] if len(res) > 0 else [{"text": ""}]
+ meta_data = res[1] if len(res) > 1 else {}
+ time2 = time.perf_counter()
+
+ asr_result_list.extend(results)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ time_escape = time2 - time1
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+ speed_stats["forward"] = f"{time_escape:0.3f}"
+ speed_stats["batch_size"] = f"{len(results)}"
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+ description = f"{speed_stats}, "
+ if pbar:
+ pbar.update(end_idx - beg_idx)
+ pbar.set_description(description)
+ time_speech_total += batch_data_time
+ time_escape_total += time_escape
+
+ if pbar:
+ # pbar.update(1)
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+ torch.cuda.empty_cache()
+ return asr_result_list
+
+ def vad(self, input, input_len=None, **cfg):
+ kwargs = self.kwargs
+ # step.1: compute the vad model
+ deep_update(self.vad_kwargs, cfg)
+ beg_vad = time.time()
+ res = self.inference(
+ input,
+ input_len=input_len,
+ model=self.vad_model,
+ kwargs=self.vad_kwargs,
+ **cfg,
+ )
+ end_vad = time.time()
+ # FIX(gcf): concat the vad clips for sense vocie model for better aed
+ if cfg.get("merge_vad", False):
+ for i in range(len(res)):
+ res[i]["value"] = merge_vad(
+ res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
+ )
+ elapsed = end_vad - beg_vad
+ return elapsed, res
+
+ def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
+
+ kwargs = self.kwargs
+
+ # step.2 compute asr model
+ model = self.model
+ deep_update(kwargs, cfg)
+ batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+ kwargs["batch_size"] = batch_size
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
+ )
+ results_ret_list = []
+ time_speech_total_all_samples = 1e-6
+
+ beg_total = time.time()
+ pbar_total = (
+ tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
+ if not kwargs.get("disable_pbar", False)
+ else None
+ )
+
+ for i in range(len(vad_res)):
+ key = vad_res[i]["key"]
+ vadsegments = vad_res[i]["value"]
+ input_i = data_list[i]
+ fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
+ speech = load_audio_text_image_video(
+ input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
+ )
+ speech_lengths = len(speech)
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+
+ if not len(sorted_data):
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
+ logging.info("decoding, utt: {}, empty speech".format(key))
+ continue
+
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+ batch_size = max(
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+ )
+
+ if kwargs["device"] == "cpu":
+ batch_size = 0
+
+ beg_idx = 0
+ beg_asr_total = time.time()
+ time_speech_total_per_sample = speech_lengths / 16000
+ time_speech_total_all_samples += time_speech_total_per_sample
+
+ # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
+
+ all_segments = []
+ max_len_in_batch = 0
+ end_idx = 1
+
+ for j, _ in enumerate(range(0, n)):
+ # pbar_sample.update(1)
+ sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+ potential_batch_length = max(max_len_in_batch, sample_length) * (
+ j + 1 - beg_idx
+ )
+ # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+ if (
+ j < n - 1
+ and sample_length < batch_size_threshold_ms
+ and potential_batch_length < batch_size
+ ):
+ max_len_in_batch = max(max_len_in_batch, sample_length)
+ end_idx += 1
+ continue
+
+ speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
+ )
+ results = self.inference(
+ speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
+ )
+
+ for _b in range(len(speech_j)):
+ results[_b]["interval"] = intervals[_b]
+
+ if self.spk_model is not None:
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+ for _b in range(len(speech_j)):
+ vad_segments = [
+ [
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
+ np.array(speech_j[_b]),
+ ]
+ ]
+ segments = sv_chunk(vad_segments)
+ all_segments.extend(segments)
+ speech_b = [i[2] for i in segments]
+ spk_res = self.inference(
+ speech_b,
+ input_len=None,
+ model=self.spk_model,
+ kwargs=kwargs,
+ **cfg,
+ )
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
+
+ beg_idx = end_idx
+ end_idx += 1
+ max_len_in_batch = sample_length
+ if len(results) < 1:
+ continue
+ results_sorted.extend(results)
+
+ # end_asr_total = time.time()
+ # time_escape_total_per_sample = end_asr_total - beg_asr_total
+ # pbar_sample.update(1)
+ # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+ # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ cur = results_sorted[j]
+ pattern = r"<\|([^|]+)\|>"
+ emotion_string = re.findall(pattern, cur["text"])
+ cur["text"] = re.sub(pattern, "", cur["text"])
+ cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
+ if self.punc_model is not None and len(cur["text"].strip()) > 0:
+ deep_update(self.punc_kwargs, cfg)
+ punc_res = self.inference(
+ cur["text"],
+ model=self.punc_model,
+ kwargs=self.punc_kwargs,
+ **cfg,
+ )
+ cur["text"] = punc_res[0]["text"]
+
+ restored_data[index] = cur
+
+ end_asr_total = time.time()
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
+ if pbar_total:
+ pbar_total.update(1)
+ pbar_total.set_description(
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
+ )
+
+ # end_total = time.time()
+ # time_escape_total_all_samples = end_total - beg_total
+ # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+ # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
+ # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
+ return restored_data
+
+ def export(self, input=None, **cfg):
+ """
+
+ :param input:
+ :param type:
+ :param quantize:
+ :param fallback_num:
+ :param calib_num:
+ :param opset_version:
+ :param cfg:
+ :return:
+ """
+
+ device = cfg.get("device", "cpu")
+ model = self.model.to(device=device)
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ kwargs["device"] = device
+ del kwargs["model"]
+ model.eval()
+
+ type = kwargs.get("type", "onnx")
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=None, data_type=kwargs.get("data_type", None), key=None
+ )
+
+ with torch.no_grad():
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
+
+ return export_dir
diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py
new file mode 100644
index 00000000..02c15a59
--- /dev/null
+++ b/tools/sensevoice/fun_asr.py
@@ -0,0 +1,332 @@
+import gc
+import os
+import re
+
+from audio_separator.separator import Separator
+
+os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
+os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
+import json
+import subprocess
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+from pydub import AudioSegment
+from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from tools.sensevoice.auto_model import AutoModel
+
+
+def uvr5_cli(
+ audio_dir: Path,
+ output_folder: Path,
+ audio_files: list[Path] | None = None,
+ output_format: str = "flac",
+ model: str = "BS-Roformer-Viperx-1296.ckpt",
+):
+ # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
+ sepr = Separator(
+ model_file_dir=os.environ["UVR5_CACHE"],
+ output_dir=output_folder,
+ output_format=output_format,
+ )
+ dictmodel = {
+ "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
+ "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
+ "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
+ "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
+ }
+ roformer_model = dictmodel[model]
+ sepr.load_model(roformer_model)
+ if audio_files is None:
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+ total_files = len(audio_files)
+
+ print(f"{total_files} audio files found")
+
+ res = []
+ for audio in tqdm(audio_files, desc="Denoising: "):
+ file_path = str(audio_dir / audio)
+ sep_out = sepr.separate(file_path)
+ if isinstance(sep_out, str):
+ res.append(sep_out)
+ elif isinstance(sep_out, list):
+ res.extend(sep_out)
+ del sepr
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return res, roformer_model
+
+
+def get_sample_rate(media_path: Path):
+ result = subprocess.run(
+ [
+ "ffprobe",
+ "-v",
+ "quiet",
+ "-print_format",
+ "json",
+ "-show_streams",
+ str(media_path),
+ ],
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+ media_info = json.loads(result.stdout)
+ for stream in media_info.get("streams", []):
+ if stream.get("codec_type") == "audio":
+ return stream.get("sample_rate")
+ return "44100" # Default sample rate if not found
+
+
+def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
+ sr = get_sample_rate(src_path)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ if src_path.resolve() == out_path.resolve():
+ output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
+ else:
+ output = str(out_path)
+ subprocess.run(
+ [
+ "ffmpeg",
+ "-loglevel",
+ "error",
+ "-i",
+ str(src_path),
+ "-acodec",
+ "pcm_s16le" if out_fmt == "wav" else "flac",
+ "-ar",
+ sr,
+ "-ac",
+ "1",
+ "-y",
+ output,
+ ],
+ check=True,
+ )
+ return out_path
+
+
+def convert_video_to_audio(video_path: Path, audio_dir: Path):
+ cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
+ vocals = [
+ p
+ for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
+ if p.suffix in AUDIO_EXTENSIONS
+ ]
+ if len(vocals) > 0:
+ return vocals[0]
+ audio_path = cur_dir / f"{video_path.stem}.wav"
+ convert_to_mono(video_path, audio_path)
+ return audio_path
+
+
+@click.command()
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option(
+ "--max_single_segment_time",
+ default=20000,
+ type=int,
+ help="Maximum of Output single audio duration(ms)",
+)
+@click.option("--fsmn-vad/--silero-vad", default=False)
+@click.option("--punc/--no-punc", default=False)
+@click.option("--denoise/--no-denoise", default=False)
+@click.option("--save_emo/--no_save_emo", default=False)
+def main(
+ audio_dir: str,
+ save_dir: str,
+ device: str,
+ language: str,
+ max_single_segment_time: int,
+ fsmn_vad: bool,
+ punc: bool,
+ denoise: bool,
+ save_emo: bool,
+):
+
+ audios_path = Path(audio_dir)
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ video_files = list_files(
+ path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
+ )
+ v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
+
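+    # Optional denoising: separate vocals with UVR5, remove the instrumental
+    # stems and the originals, and keep mono vocal tracks for transcription.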
+ if denoise:
+ VOCAL = "_(Vocals)"
+ original_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
+ ]
+
+ _, cur_model = uvr5_cli(
+ audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
+ )
+ need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
+ need_remove.extend(original_files)
+ for _ in need_remove:
+ _.unlink()
+ vocal_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
+ ]
+ for f in vocal_files:
+ fn, ext = f.stem, f.suffix
+
+ v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
+ if v_pos != -1:
+ new_fn = fn[: v_pos + len(VOCAL)]
+ new_f = f.with_name(new_fn + ext)
+ f = f.rename(new_f)
+ convert_to_mono(f, f, "flac")
+ f.unlink()
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ logger.info("Loading / Downloading Funasr model...")
+
+ model_dir = "iic/SenseVoiceSmall"
+
+ vad_model = "fsmn-vad" if fsmn_vad else None
+ vad_kwargs = {"max_single_segment_time": max_single_segment_time}
+ punc_model = "ct-punc" if punc else None
+
+ manager = AutoModel(
+ model=model_dir,
+ trust_remote_code=False,
+ vad_model=vad_model,
+ vad_kwargs=vad_kwargs,
+ punc_model=punc_model,
+ device=device,
+ )
+
+ if not fsmn_vad and vad_model is None:
+ vad_model = load_silero_vad()
+
+ logger.info("Model loaded.")
+
+ pattern = re.compile(r"_\d{3}\.")
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+
+ if pattern.search(file_path.name):
+ # logger.info(f"Skipping {file_path} as it has already been processed.")
+ continue
+
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
+ cfg = dict(
+ cache={},
+ language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ batch_size_s=60,
+ )
+
+ if fsmn_vad:
+ elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
+ else:
+ wav = read_audio(
+ str(file_path)
+ ) # backend (sox, soundfile, or ffmpeg) required!
+ audio_key = file_path.stem
+ audio_val = []
+ speech_timestamps = get_speech_timestamps(
+ wav,
+ vad_model,
+ max_speech_duration_s=max_single_segment_time // 1000,
+ return_seconds=True,
+ )
+
+ audio_val = [
+ [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
+ for timestamp in speech_timestamps
+ ]
+ vad_res = []
+ vad_res.append(dict(key=audio_key, value=audio_val))
+
+ res = manager.inference_with_vadres(
+ input=str(file_path), vad_res=vad_res, **cfg
+ )
+
+ for i, info in enumerate(res):
+ [start_ms, end_ms] = info["interval"]
+ text = info["text"]
+ emo = info["emo"]
+ sliced_audio = audio[start_ms:end_ms]
+ audio_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
+ )
+ sliced_audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}: {text}")
+
+ transcript_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
+ )
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(text)
+
+ if save_emo:
+ emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
+ with open(
+ emo_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(emo)
+
+ if audios_path.resolve() == save_path.resolve():
+ file_path.unlink()
+
+
+if __name__ == "__main__":
+ main()
+    exit(0)
+
+    # The code below is unreachable and kept only as a manual smoke test.
+    # The SenseVoiceSmall import path is an assumption (funasr >= 1.1 layout).
+    from funasr.models.sense_voice.model import SenseVoiceSmall
+    from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+ # Load the audio file
+ audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
+ model_dir = "iic/SenseVoiceSmall"
+ m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
+ m.eval()
+
+ res = m.inference(
+ data_in=f"{kwargs['model_path']}/example/zh.mp3",
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ ban_emo_unk=False,
+ **kwargs,
+ )
+
+ print(res)
+ text = rich_transcription_postprocess(res[0][0]["text"])
+ print(text)
diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py
new file mode 100644
index 00000000..3bef75ed
--- /dev/null
+++ b/tools/sensevoice/vad_utils.py
@@ -0,0 +1,61 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
+ speech_i = speech[0, bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+ return feats_pad, speech_lengths_pad
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ intervals = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ intervals.append([bed_idx // 16, end_idx // 16])
+
+ return speech_list, speech_lengths_list, intervals
+
+
+def merge_vad(vad_result, max_length=15000, min_length=0):
+ new_result = []
+ if len(vad_result) <= 1:
+ return vad_result
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
+ time_step = sorted(list(set(time_step)))
+ if len(time_step) == 0:
+ return []
+ bg = 0
+ for i in range(len(time_step) - 1):
+ time = time_step[i]
+ if time_step[i + 1] - bg < max_length:
+ continue
+ if time - bg > min_length:
+ new_result.append([bg, time])
+ # if time - bg < max_length * 1.5:
+ # new_result.append([bg, time])
+ # else:
+ # split_num = int(time - bg) // max_length + 1
+ # spl_l = int(time - bg) // split_num
+ # for j in range(split_num):
+ # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
+ bg = time
+ new_result.append([bg, time_step[-1]])
+ return new_result
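+
+
+if __name__ == "__main__":
+    # Small sanity check with made-up millisecond timestamps: consecutive VAD
+    # segments are merged into chunks roughly bounded by max_length.
+    print(merge_vad([[0, 3000], [3500, 9000], [9500, 16000]], max_length=15000))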
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
index 7fb55d9a..9772168f 100644
--- a/tools/smart_pad.py
+++ b/tools/smart_pad.py
@@ -8,7 +8,7 @@
import torchaudio
from tqdm import tqdm
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from tools.file import AUDIO_EXTENSIONS, list_files
threshold = 10 ** (-50 / 20.0)
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
index 977afdf3..d24a5f39 100644
--- a/tools/vqgan/create_train_split.py
+++ b/tools/vqgan/create_train_split.py
@@ -7,7 +7,7 @@
from pydub import AudioSegment
from tqdm import tqdm
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
@click.command()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
index fc84f7e6..bc6bc408 100644
--- a/tools/vqgan/extract_vq.py
+++ b/tools/vqgan/extract_vq.py
@@ -17,7 +17,7 @@
from loguru import logger
from omegaconf import OmegaConf
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
index 6070e735..42e7de8a 100644
--- a/tools/whisper_asr.py
+++ b/tools/whisper_asr.py
@@ -32,7 +32,7 @@
from pydub import AudioSegment
from tqdm import tqdm
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from tools.file import AUDIO_EXTENSIONS, list_files
@click.command()
@@ -83,8 +83,6 @@ def main(
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
)
- numbered_suffix_pattern = re.compile(r"-\d{3}$")
-
for file_path in tqdm(audio_files, desc="Processing audio file"):
file_stem = file_path.stem
file_suffix = file_path.suffix
@@ -92,18 +90,6 @@ def main(
rel_path = Path(file_path).relative_to(audio_dir)
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
- # Skip files that already have a .lab file or a -{3-digit number} suffix
- numbered_suffix = numbered_suffix_pattern.search(file_stem)
- lab_file = file_path.with_suffix(".lab")
-
- if numbered_suffix and lab_file.exists():
- continue
-
- if not numbered_suffix and lab_file.with_stem(lab_file.stem + "-001").exists():
- if file_path.exists():
- file_path.unlink()
- continue
-
audio = AudioSegment.from_file(file_path)
segments, info = model.transcribe(
@@ -119,6 +105,7 @@ def main(
)
print("Total len(ms): ", len(audio))
+ whole_text = None
for segment in segments:
id, start, end, text = (
segment.id,
@@ -127,26 +114,24 @@ def main(
segment.text,
)
print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
- start_ms = int(start * 1000)
- end_ms = int(end * 1000) + 200 # add 0.2s avoid truncating
- segment_audio = audio[start_ms:end_ms]
- audio_save_path = (
- save_path / rel_path.parent / f"{file_stem}-{id:03d}{file_suffix}"
- )
- segment_audio.export(audio_save_path, format=file_suffix[1:])
- print(f"Exported {audio_save_path}")
-
- transcript_save_path = (
- save_path / rel_path.parent / f"{file_stem}-{id:03d}.lab"
- )
- with open(
- transcript_save_path,
- "w",
- encoding="utf-8",
- ) as f:
- f.write(segment.text)
-
- file_path.unlink()
+ if not whole_text:
+ whole_text = text
+ else:
+ whole_text += ", " + text
+
+ whole_text += "."
+
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
+ audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}")
+
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(whole_text)
if __name__ == "__main__":