
Batching everything.
Tomiinek committed Feb 28, 2025
1 parent 32f8169 commit a6afc4a
Showing 10 changed files with 146 additions and 61 deletions.
3 changes: 3 additions & 0 deletions src/tidytunes/bin/process_audio.py
@@ -36,6 +36,9 @@ def process_audio(audios, device, pipeline_components):

     for name, func, kwargs, filter_fn in pipeline_components:

+        # to improve batching efficiency
+        audio_segments = sorted(audio_segments, key=lambda x: x.duration)
+
         values = func(audio_segments, device=device, **kwargs)
         if filter_fn:
             audio_segments, _ = partition(audio_segments, by=filter_fn(values))
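The inserted sort keeps segments of similar duration together, so the batches formed downstream pad each clip to a nearby maximum instead of a global one. A toy illustration of the padding saved (hypothetical numbers, not from the repository):

    durations = [1.0, 30.0, 1.2, 29.0, 0.9, 31.0]

    def padding_waste(ds, batch_size=2):
        # Each batch pads its clips to the batch maximum; waste is the
        # padded time beyond the real audio.
        batches = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)]
        return sum(max(b) * len(b) - sum(b) for b in batches)

    print(padding_waste(durations))          # ~86.9 seconds wasted
    print(padding_waste(sorted(durations)))  # ~28.9 seconds wasted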
6 changes: 5 additions & 1 deletion src/tidytunes/models/external/silerovad.py
@@ -103,7 +103,11 @@ def forward(self, audio_chunk_16khz: torch.Tensor, state: list[torch.Tensor]):

         x = audio_chunk_16khz.unsqueeze(1)

-        x = F.pad(x, (self.audio_padding_size, self.audio_padding_size), mode="reflect")
+        x = F.pad(
+            x.float().contiguous(),
+            (self.audio_padding_size, self.audio_padding_size),
+            mode="reflect",
+        )
         x = self.input_conv(x)

         a, b = torch.pow(x, 2).chunk(2, dim=1)
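The padding call now casts to float32 and forces a contiguous memory layout before reflective padding. The same pattern in isolation (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 1, 512)  # (batch, channels, samples)
    # mode="reflect" requires a floating-point input, and non-contiguous
    # tensors can trip some backends, so the cast plus .contiguous() keeps
    # the call robust regardless of how x was produced upstream.
    y = F.pad(x.float().contiguous(), (64, 64), mode="reflect")
    print(y.shape)  # torch.Size([2, 1, 640])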
12 changes: 8 additions & 4 deletions src/tidytunes/pipeline_components/denoising.py
@@ -5,31 +5,35 @@

 from tidytunes.utils import (
     Audio,
-    chunk_list,
     collate_audios,
     decollate_audios,
     sequence_mask,
+    to_batches,
 )


 def denoise(
-    audio: list[Audio], device: str = "cpu", batch_size: int = 32
+    audio: list[Audio],
+    device: str = "cpu",
+    batch_size: int = 64,
+    batch_duration: float = 1280.0,
 ) -> list[Audio]:
     """
     Apply denoising to a list of audio samples using a pre-trained model.

     Args:
         audio (list[Audio]): List of audio objects to be denoised.
         device (str): The device to run the denoising model on (default: "cpu").
-        batch_size (int): Number of audio samples to process in a batch (default: 32).
+        batch_size (int): Maximum number of audio samples to process in a batch (default: 64).
+        batch_duration (float): Maximum total duration of audio samples to process in a batch (default: 1280.0).

     Returns:
         list[Audio]: List of denoised audio objects.
     """
     denoiser = load_denoiser(device)
     denoised = []

-    for audio_batch in chunk_list(audio, batch_size):
+    for audio_batch in to_batches(audio, batch_size, batch_duration):
         audio_tensor, audio_lengths = collate_audios(
             audio_batch, denoiser.sampling_rate
         )
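The to_batches helper is imported from tidytunes.utils but its body is outside this diff. A minimal sketch of what a count- and duration-capped batcher could look like, assuming each item exposes a duration attribute in the same units as batch_duration (hypothetical code, not the repository's implementation):

    def to_batches(items, batch_size, batch_duration):
        # Yield consecutive batches, closing a batch once adding the next
        # item would exceed either the count cap or the duration cap.
        batch, total = [], 0.0
        for item in items:
            too_many = len(batch) >= batch_size
            too_long = total + item.duration > batch_duration
            if batch and (too_many or too_long):
                yield batch
                batch, total = [], 0.0
            batch.append(item)
            total += item.duration
        if batch:
            yield batch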
8 changes: 5 additions & 3 deletions src/tidytunes/pipeline_components/dnsmos.py
@@ -2,7 +2,7 @@

 import torch

-from tidytunes.utils import Audio, chunk_list, collate_audios
+from tidytunes.utils import Audio, collate_audios, to_batches


 def get_dnsmos(
@@ -11,6 +11,7 @@ def get_dnsmos(
     device: str = "cpu",
     num_threads: int | None = 8,
     batch_size: int = 32,
+    batch_duration: float = 640.0,
 ) -> torch.Tensor:
     """
     Computes DNSMOS (Deep Noise Suppression Mean Opinion Score) for a batch of audio clips.
@@ -20,15 +21,16 @@
         personalized (bool): Whether to use a personalized model (default: True).
         device (str): The device to run the model on (default: "cpu").
         num_threads (int | None): Number of threads to use for ONNX inference (default: 8).
-        batch_size (int): Batch size for processing (default: 32).
+        batch_size (int): Maximum number of audio samples to process in a batch (default: 32).
+        batch_duration (float): Maximum total duration of audio samples to process in a batch (default: 640.0).

     Returns:
         torch.Tensor: Tensor containing DNSMOS scores for each input audio clip.
     """
     model = load_dnsmos_model(device, personalized, num_threads)
     mos_scores = []

-    for audio_batch in chunk_list(audio, batch_size):
+    for audio_batch in to_batches(audio, batch_size, batch_duration):
         a, al = collate_audios(audio_batch, model.sampling_rate)
         with torch.no_grad():
             _, _, _, mos = model(a.to(device), al.to(device))
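A usage sketch under the new signature; the clips list and the 3.5 cutoff are illustrative, not part of the commit:

    scores = get_dnsmos(clips, personalized=True, device="cuda", batch_duration=640.0)
    # One score per input clip, in input order, so results can be zipped back.
    clean_clips = [c for c, s in zip(clips, scores) if s >= 3.5]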
22 changes: 15 additions & 7 deletions src/tidytunes/pipeline_components/language_id.py
@@ -2,11 +2,15 @@

 import torch

-from tidytunes.utils import Audio, collate_audios
+from tidytunes.utils import Audio, collate_audios, to_batches


 def get_language_probabilities(
-    audio: list[Audio], language_code: str, device: str = "cpu"
+    audio: list[Audio],
+    language_code: str,
+    device: str = "cpu",
+    batch_size: int = 64,
+    batch_duration: float = 1280.0,
 ) -> torch.Tensor:
     """
     Compute the probability of a given language being spoken in the audio.
@@ -15,18 +19,22 @@
         audio (list[Audio]): List of Audio objects to analyze.
         language_code (str): The target language code to check probabilities for.
         device (str): The device to run the model on (default: "cpu").
+        batch_size (int): Maximum number of audio samples to process in a batch (default: 64).
+        batch_duration (float): Maximum total duration of audio samples to process in a batch (default: 1280.0).

     Returns:
         Tensor (B,) of probabilities for the specified language.
     """
     model, lab2ind = load_langid_voxlingua107_ecapa(device)
+    lang_probs = []

-    audio_16khz, audio_16khz_lens = collate_audios(audio, 16000)
-    with torch.no_grad():
-        out_prob, _, _ = model(audio_16khz.to(device), audio_16khz_lens.to(device))
-    lang_prob = torch.stack([p[lab2ind[language_code]] for p in out_prob])
+    for audio_batch in to_batches(audio, batch_size, batch_duration):
+        audio_16khz, audio_16khz_lens = collate_audios(audio_batch, 16000)
+        with torch.no_grad():
+            out_prob, _, _ = model(audio_16khz.to(device), audio_16khz_lens.to(device))
+        lang_probs.extend([p[lab2ind[language_code]] for p in out_prob])

-    return lang_prob
+    return torch.stack(lang_probs)


 @lru_cache(1)
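Besides batching the model calls, the hunk changes how results are returned: per-clip probabilities are accumulated with extend and stacked once at the end, which keeps the output aligned one-to-one with the input as long as to_batches yields clips in order. The accumulate-then-stack pattern in isolation, with dummy values:

    import torch

    lang_probs = []
    for batch_probs in ([torch.tensor(0.91), torch.tensor(0.15)], [torch.tensor(0.73)]):
        lang_probs.extend(batch_probs)

    print(torch.stack(lang_probs))  # tensor([0.9100, 0.1500, 0.7300]), one per clip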
47 changes: 31 additions & 16 deletions src/tidytunes/pipeline_components/source_separation.py
@@ -2,13 +2,20 @@

 import torch

-from tidytunes.utils import Audio, collate_audios, frame_labels_to_time_segments
+from tidytunes.utils import (
+    Audio,
+    collate_audios,
+    frame_labels_to_time_segments,
+    to_batches,
+)


 def find_segments_without_music(
-    audios: list[Audio],
+    audio: list[Audio],
     min_duration: float = 6.4,
     device: str = "cpu",
+    batch_size: int = 1,
+    batch_duration: float = 36000.0,
 ):
     """
     Identifies segments in audio where speech is present but music is absent.
@@ -20,26 +27,34 @@
         min_speech_energy (float): Minimum required energy for vocal sources to be considered speech (default: 0.99).
         min_duration (float): Minimum duration (in seconds) for valid speech segments (default: 6.4).
         device (str): The device to run the model on (default: "cpu").
+        batch_size (int): Maximum number of audio samples to process in a batch (default: 1).
+        batch_duration (float): Maximum total duration of audio samples to process in a batch (default: 36000.0).

     Returns:
         list[list[Segment]]: List of speech segments without music for each input Audio.
     """
     demucs = load_demucs(device)
+    time_segments = []

-    audio, audio_lens = collate_audios(audios, demucs.sampling_rate)
-    audio, audio_lens = audio.to(device), audio_lens.to(device)
+    for audio_batch in to_batches(audio, batch_size, batch_duration):

-    with torch.no_grad():
-        speech_without_music_mask = demucs(audio, audio_lens)
+        a, al = collate_audios(audio_batch, demucs.sampling_rate)
+        with torch.no_grad():
+            speech_without_music_mask = demucs(a.to(device), al.to(device))

-    return [
-        frame_labels_to_time_segments(
-            m,
-            demucs.frame_shift,
-            filter_with=lambda x: (x.symbol is True) & (x.duration >= min_duration),
-        )
-        for m in speech_without_music_mask
-    ]
+        time_segments.extend(
+            [
+                frame_labels_to_time_segments(
+                    m,
+                    demucs.frame_shift,
+                    filter_with=lambda x: (x.symbol is True)
+                    & (x.duration >= min_duration),
+                )
+                for m in speech_without_music_mask
+            ]
+        )
+
+    return time_segments


 @lru_cache(1)
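Note the defaults here: with batch_size=1 each recording goes through Demucs alone, and the very large batch_duration (36000.0) effectively never binds, presumably because source separation dominates memory. Callers with headroom can opt in to larger batches; a hypothetical call:

    segments = find_segments_without_music(clips, device="cuda", batch_size=4)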
27 changes: 18 additions & 9 deletions src/tidytunes/pipeline_components/speaker_segmentation.py
@@ -4,7 +4,12 @@
 import torch.nn.functional as F
 from sklearn.cluster import AgglomerativeClustering, KMeans

-from tidytunes.utils import Audio, collate_audios, frame_labels_to_time_segments
+from tidytunes.utils import (
+    Audio,
+    collate_audios,
+    frame_labels_to_time_segments,
+    to_batches,
+)


 def find_segments_with_single_speaker(
@@ -15,9 +20,12 @@
     frame_shift: int = 64,
     num_clusters: int = 10,
     device: str = "cpu",
+    batch_size: int = 64,
+    batch_duration: float = 1280.0,
 ):
     """
     Identifies segments in the audio where only a single speaker is present.
+    *** All input segments are assumed to come from a single source. ***

     Args:
         audio (list[Audio]): List of audio objects.
@@ -27,23 +35,24 @@
         frame_shift (float): Number of model input frames per one output speaker label (default: 64).
         num_clusters (int): Initial number of clusters before agglomerative clustering (default: 10).
         device (str): Device to run the model on (default: "cpu").
+        batch_size (int): Maximum number of audio samples to process in a batch (default: 64).
+        batch_duration (float): Maximum total duration of audio samples to process in a batch (default: 1280.0).

     Returns:
         list[list[Segment]]: List of speaker segments for each input audio.
     """
     speaker_encoder = load_speaker_encoder(num_frames=frame_shift, device=device)
+    embeddings = []

-    audio, audio_lens = collate_audios(
-        audio, sampling_rate=speaker_encoder.sampling_rate
-    )
-    audio, audio_lens = audio.to(device), audio_lens.to(device)
+    for audio_batch in to_batches(audio, batch_size, batch_duration):

-    with torch.no_grad():
-        embeddings = speaker_encoder(audio, audio_lens)
-    embeddings_all = torch.cat(embeddings, dim=0)
+        a, al = collate_audios(audio_batch, sampling_rate=speaker_encoder.sampling_rate)
+        with torch.no_grad():
+            e = speaker_encoder(a.to(device), al.to(device))
+        embeddings.extend(e)
+
+    embeddings_all = torch.cat(embeddings, dim=0)
     centroids = find_cluster_centers(embeddings_all, num_clusters)

     labels = [
         F.cosine_similarity(e.unsqueeze(1), centroids.unsqueeze(0), dim=-1).argmax(
             dim=-1
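find_cluster_centers sits outside this view, but the visible labeling step assigns every frame embedding to its nearest centroid by cosine similarity. The same assignment in isolation, with made-up shapes:

    import torch
    import torch.nn.functional as F

    embeddings = torch.randn(100, 192)  # hypothetical: 100 frame embeddings
    centroids = torch.randn(10, 192)    # hypothetical: 10 cluster centers

    # Broadcasting (100, 1, 192) against (1, 10, 192) yields a (100, 10)
    # similarity matrix; argmax picks the closest centroid per embedding.
    labels = F.cosine_similarity(
        embeddings.unsqueeze(1), centroids.unsqueeze(0), dim=-1
    ).argmax(dim=-1)
    print(labels.shape)  # torch.Size([100])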
49 changes: 34 additions & 15 deletions src/tidytunes/pipeline_components/vad.py
@@ -2,14 +2,21 @@

 import torch

-from tidytunes.utils import Audio, collate_audios, frame_labels_to_time_segments
+from tidytunes.utils import (
+    Audio,
+    collate_audios,
+    frame_labels_to_time_segments,
+    to_batches,
+)


 def find_segments_with_speech(
     audio: list[Audio],
     min_duration: float = 3.2,
     max_duration: float = 30.0,
     device: str = "cpu",
+    batch_size: int = 64,
+    batch_duration: float = 1280.0,
 ):
     """
     Identifies speech segments in the given audio using a Voice Activity Detector (VAD).
@@ -19,25 +26,36 @@
         min_duration (float): Minimum duration for a valid speech segment (default: 3.2).
         max_duration (float): Maximum duration for a valid speech segment (default: 30.0).
         device (str): The device to run the VAD model on (default: "cpu").
+        batch_size (int): Maximum number of audio samples to process in a batch (default: 64).
+        batch_duration (float): Maximum total duration of audio samples to process in a batch (default: 1280.0).

     Returns:
         list[list[Segment]]: Time segments containing speech for each input Audio.
     """
     vad = load_vad(device)
-    audio_tensor, _ = collate_audios(audio, vad.sampling_rate)
-    with torch.no_grad():
-        speech_mask = vad(audio_tensor.to(device))
-    speech_mask[..., :-1] += speech_mask[..., 1:].clone()  # Pre-bounce speech starts
-
-    time_segments = [
-        frame_labels_to_time_segments(
-            m,
-            vad.frame_shift,
-            filter_with=lambda x: (x.symbol is True)
-            and (min_duration <= x.duration <= max_duration),
-        )
-        for m in speech_mask
-    ]
+    time_segments = []
+
+    for audio_batch in to_batches(audio, batch_size, batch_duration):
+
+        audio_tensor, _ = collate_audios(audio_batch, vad.sampling_rate)
+        with torch.no_grad():
+            speech_mask = vad(audio_tensor.to(device))
+        speech_mask[..., :-1] += speech_mask[
+            ..., 1:
+        ].clone()  # Pre-bounce speech starts
+
+        time_segments.extend(
+            [
+                frame_labels_to_time_segments(
+                    m,
+                    vad.frame_shift,
+                    filter_with=lambda x: (x.symbol is True)
+                    and (min_duration <= x.duration <= max_duration),
+                )
+                for m in speech_mask
+            ]
+        )

     return time_segments
@@ -59,5 +77,6 @@ def load_vad(device: str = "cpu", tag: str = "v1.0.0"):

     model_weights_path = download_github(tag, "silerovad_weights.pt")
     vad = SileroVAD.from_files(model_weights_path)
-    vad_trace = vad.to_jit_trace(device)
+    # vad_trace = vad.to_jit_trace(device)
+    vad_trace = vad.eval().to(device)
     return VoiceActivityDetector(vad_trace).to(device)
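load_vad now moves the eager module to the target device instead of JIT-tracing it, and the traced path stays behind a comment. A small sketch of the trade-off with a stand-in module (not the Silero model):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(8, 8)).eval()
    example = torch.randn(1, 8)

    eager = model.to("cpu")                   # keeps Python control flow intact
    traced = torch.jit.trace(model, example)  # records a single execution path

    # Tracing bakes in one path, which can silently misbehave for stateful or
    # shape-dependent models such as streaming VADs; eager mode avoids that.
    with torch.no_grad():
        assert torch.allclose(eager(example), traced(example))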
2 changes: 1 addition & 1 deletion src/tidytunes/utils/__init__.py
@@ -1,6 +1,6 @@
 from .audio import Audio, collate_audios, decollate_audios, trim_audios
 from .download import download_github
-from .etc import chunk_list, frame_labels_to_time_segments, partition
+from .etc import frame_labels_to_time_segments, partition, to_batches
 from .logging import setup_logger
 from .tensors import masked_mean, masked_std, sequence_mask
 from .trace import TraceMixin