Initial implementation of automatic batching #9

Draft
wants to merge 7 commits into master
9 changes: 8 additions & 1 deletion src/tidytunes/bin/process_audio.py
@@ -36,9 +36,16 @@ def process_audio(audios, device, pipeline_components):

for name, func, kwargs, filter_fn in pipeline_components:

print(name)

# to improve batching efficiency
audio_segments = sorted(audio_segments, key=lambda x: x.duration)

values = func(audio_segments, device=device, **kwargs)
if filter_fn:
audio_segments, _ = partition(audio_segments, by=filter_fn(values))
audio_segments, _ = partition(
audio_segments, by=[filter_fn(v) for v in values]
)
else:
audio_segments = trim_audios(audio_segments, values)

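The sort keeps clips of similar duration next to each other, so each collated batch carries less padding, and the per-value list passed to partition keeps the keep/drop decision aligned with each segment. A minimal sketch of that flow, using a simplified stand-in for the tidytunes partition helper and a hypothetical Segment class:

from dataclasses import dataclass

@dataclass
class Segment:  # hypothetical stand-in for a tidytunes audio segment
    duration: float

def partition(items, by):
    # Simplified stand-in: split items into (kept, rejected) by a parallel list of flags.
    kept = [x for x, keep in zip(items, by) if keep]
    rejected = [x for x, keep in zip(items, by) if not keep]
    return kept, rejected

segments = [Segment(3.2), Segment(0.8), Segment(1.5)]
segments = sorted(segments, key=lambda s: s.duration)  # similar lengths batch together
values = [0.1, 0.7, 0.9]                               # one score per segment from a component
kept, dropped = partition(segments, by=[v > 0.5 for v in values])
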
6 changes: 5 additions & 1 deletion src/tidytunes/models/external/silerovad.py
@@ -103,7 +103,11 @@ def forward(self, audio_chunk_16khz: torch.Tensor, state: list[torch.Tensor]):

x = audio_chunk_16khz.unsqueeze(1)

x = F.pad(x, (self.audio_padding_size, self.audio_padding_size), mode="reflect")
x = F.pad(
x.float().contiguous(),
(self.audio_padding_size, self.audio_padding_size),
mode="reflect",
)
x = self.input_conv(x)

a, b = torch.pow(x, 2).chunk(2, dim=1)
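The extra cast matters because PyTorch's reflection padding is only implemented for floating-point dtypes, and the .contiguous() call is a cheap safeguard when the chunk arrives as a strided view after batching (an assumption based on the diff, not stated in it). A minimal sketch of the failure mode, with an assumed int16 chunk:

import torch
import torch.nn.functional as F

chunk = torch.randint(-32768, 32767, (1, 1, 512), dtype=torch.int16)
# F.pad(chunk, (64, 64), mode="reflect")  # raises: reflection pad expects a float dtype
padded = F.pad(chunk.float().contiguous(), (64, 64), mode="reflect")
print(padded.shape)  # torch.Size([1, 1, 640])
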
155 changes: 89 additions & 66 deletions src/tidytunes/models/source_separation.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import Fade

from tidytunes.utils import masked_mean, masked_std, sequence_mask
@@ -9,104 +10,126 @@ class SourceSeparator(nn.Module):
def __init__(
self,
model,
segment: float = 10.0,
overlap: float = 0.1,
frame_shift: float = 0.16,
segment_frames: int = 63,
overlap_frames: int = 5,
sampling_rate: int = 44100,
max_music_energy: float = 0.01,
min_speech_energy: float = 0.99,
window_frames: int = 2,
minimal_energy: float = 1e-7,
):
super().__init__()
self.model = model
self.frame_shift = frame_shift
self.sampling_rate = sampling_rate
self.chunk_len = int(sampling_rate * segment * (1 + overlap))
self.overlap_frames = int(overlap * sampling_rate)
self.fade = Fade(
fade_in_len=0, fade_out_len=self.overlap_frames, fade_shape="linear"
)
self.frame_samples = int(frame_shift * sampling_rate)
self.segment_samples = self.frame_samples * segment_frames
self.overlap_samples = self.frame_samples * overlap_frames
self.fade_in = Fade(fade_in_len=self.overlap_samples, fade_out_len=0)
self.fade_out = Fade(fade_in_len=0, fade_out_len=self.overlap_samples)
self.max_music_energy = max_music_energy
self.min_speech_energy = min_speech_energy
self.window_samples = self.frame_samples * window_frames
self.minimal_energy = minimal_energy

def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
audio: torch.Tensor,
audio_lens: torch.Tensor,
):
"""
Normalizes and processes input audio to separate sources.
Normalizes and processes input audio to separate sources, then computes
the energy of the vocal and non-vocal stems within a sliding window to
decide whether background music is present alongside speech.

Args:
x (B, T): Input audio tensor.
x_lens (B,): Lengths of each sequence in the batch.
audio (B, T): Input audio tensor.
audio_lens (B,): Lengths of each sequence in the batch.

Returns:
(B, sources, L): Separated audio sources.
(B, L): Mask indicating frames without music.
"""

mask = sequence_mask(x_lens)
mean = masked_mean(x, mask)
std = masked_std(x, mask, mean=mean)
B, T = audio.shape

x = (x - mean.unsqueeze(-1)) / std.unsqueeze(-1)
x[~mask] = 0.0
# Pad to nearest multiple of segment and add overlap
padded_lens = (
(T + self.segment_samples - 1) // self.segment_samples
) * self.segment_samples
pad_size = padded_lens - T + self.overlap_samples

y = self.get_sources(x)
y = y * std[:, None, None] + mean[:, None, None]
mask = mask.unsqueeze(1).repeat(1, 4, 1)
y[~mask] = 0.0
audio = F.pad(audio, (0, pad_size))

# (B, sources, T), sources are: drums, bass, other, vocals
return y
mask = sequence_mask(audio_lens, max_length=audio.shape[-1])
mean = masked_mean(audio, mask)
std = masked_std(audio, mask, mean=mean)

@torch.no_grad()
def get_sources(
self,
audio: torch.Tensor,
):
"""
Splits audio into overlapping chunks, processes with the model,
and applies fade-in/fade-out to smooth transitions.
x = (audio - mean.unsqueeze(-1)) / std.unsqueeze(-1)
x[~mask] = 0.0

Args:
audio (B, T): Normalized input audio.
audio_buffer = torch.zeros(x.shape[0], self.overlap_samples, device=x.device)
window_buffer = None
output = []

Returns:
(B, sources, T): Separated sources.
"""
for i in range((x.shape[-1] - self.overlap_samples) // self.segment_samples):

# The model expects stereo inputs
audio = audio.unsqueeze(1).expand(-1, 2, -1)
B, C, L = audio.shape
s = i * self.segment_samples
e = (i + 1) * self.segment_samples + self.overlap_samples
segment = x[..., s:e]

if L <= self.chunk_len:
return self.model(audio).mean(dim=-2)
assert segment.shape[-1] == self.segment_samples + self.overlap_samples

output = torch.zeros(B, len(self.model.sources), L, device=audio.device)
buffer = None
if i > 0:
segment = self.fade_in(segment)
segment[..., : self.overlap_samples] += audio_buffer
segment = self.fade_out(segment)
audio_buffer = segment[..., -self.overlap_samples :]
segment = segment[..., : -self.overlap_samples]

start, end = 0, self.chunk_len
while start < L - self.overlap_frames:
chunk = audio[:, :, start:end]
x = self.model(chunk)
x = self.fade(x)
y = self.forward_segment(segment)
y = y * std[:, None, None] + mean[:, None, None]

chunk_output = x[..., : x.shape[-1] - self.fade.fade_out_len]
if window_buffer is not None:
y = torch.cat([window_buffer, y], dim=-1)
window_buffer = y[..., -self.frame_samples :]

if self.fade.fade_in_len > 0:
chunk_output[..., : self.fade.fade_in_len] += buffer
buffer = x[..., x.shape[-1] - self.fade.fade_out_len :]
frames = y.unfold(-1, self.window_samples, self.frame_samples)

output[..., start : start + chunk_output.shape[-1]] = chunk_output.mean(
dim=-2
energy = (frames**2).mean(dim=-1) # b c t w -> b c t

# Merge all non-vocal channels into a single one
energy = torch.cat(
(energy[:, :3].sum(dim=-2, keepdim=True), energy[:, 3:]), dim=-2
)
energy_total = energy.sum(dim=-2, keepdim=True)
rel_energy = energy / energy_total

if start == 0:
self.fade.fade_in_len = self.overlap_frames
start += self.chunk_len - self.overlap_frames
else:
start += self.chunk_len
abs_silence = energy_total.squeeze(1) < self.minimal_energy
no_music = abs_silence | (rel_energy[:, 0] <= self.max_music_energy)
is_speech = abs_silence | (rel_energy[:, 1] >= self.min_speech_energy)
output.append(no_music & is_speech)

end += self.chunk_len
if end >= L:
self.fade.fade_out_len = 0
output = torch.cat(output, dim=-1)

# reset the original chunk fading
self.fade.fade_in_len = 0
self.fade.fade_out_len = self.overlap_frames
# Trim output to remove segment padding and mask invalid positions
n_frames = (
audio_lens + 2 * self.frame_samples - 1 - self.window_samples
) // self.frame_samples
output = output[..., : n_frames.max()]
mask = sequence_mask(n_frames)
output[~mask] = False

return output

@torch.no_grad()
def forward_segment(
self,
x: torch.Tensor,
):
# The model expects stereo inputs
x = x.unsqueeze(1).expand(-1, 2, -1)
B, C, L = x.shape
assert L == self.segment_samples
x = self.model(x).mean(dim=-2)
return x
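Read end to end, the new forward runs the separator per fixed-length segment with fade-based overlap-add, then slides a two-frame window over the stems and compares accompaniment vs. vocal energy. A minimal sketch of the windowed energy test with dummy stems and the same thresholds as above (the drums/bass/other/vocals ordering is taken from the old comment in this file):

import torch

frame_samples = int(0.16 * 44100)           # 7056 samples per 0.16 s frame at 44.1 kHz
window_samples = frame_samples * 2          # window_frames = 2
max_music_energy, min_speech_energy, minimal_energy = 0.01, 0.99, 1e-7

stems = torch.randn(1, 4, 10 * frame_samples)                 # (B, sources, T), dummy data
frames = stems.unfold(-1, window_samples, frame_samples)      # (B, 4, n_windows, window)
energy = frames.pow(2).mean(dim=-1)                           # (B, 4, n_windows)

# Merge drums/bass/other into one accompaniment channel, keep vocals separate
energy = torch.cat((energy[:, :3].sum(dim=-2, keepdim=True), energy[:, 3:]), dim=-2)
total = energy.sum(dim=-2, keepdim=True)
rel = energy / total

silent = total.squeeze(1) < minimal_energy                    # near-silence counts as music-free
no_music = silent | (rel[:, 0] <= max_music_energy)
is_speech = silent | (rel[:, 1] >= min_speech_energy)
mask = no_music & is_speech                                   # (B, n_windows): speech without music
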
59 changes: 24 additions & 35 deletions src/tidytunes/pipeline_components/denoising.py
@@ -5,54 +5,46 @@

from tidytunes.utils import (
Audio,
chunk_list,
batched,
collate_audios,
decollate_audios,
sequence_mask,
)


@batched(batch_size=1024, batch_duration=1280.0)
def denoise(
audio: list[Audio], device: str = "cpu", batch_size: int = 32
audio: list[Audio],
device: str = "cpu",
) -> list[Audio]:
"""
Apply denoising to a list of audio samples using a pre-trained model.

Args:
audio (list[Audio]): List of audio objects to be denoised.
device (str): The device to run the denoising model on (default: "cpu").
batch_size (int): Number of audio samples to process in a batch (default: 32).

Returns:
list[Audio]: List of denoised audio objects.
"""
denoiser = load_denoiser(device)
denoised = []

for audio_batch in chunk_list(audio, batch_size):
audio_tensor, audio_lengths = collate_audios(
audio_batch, denoiser.sampling_rate
)
mask = sequence_mask(audio_lengths.to(device))
with torch.no_grad():
denoised_audio = denoiser(audio_tensor.to(device), mask)
denoised.extend(
decollate_audios(
denoised_audio,
audio_lengths,
denoiser.sampling_rate,
origin_like=audio_batch,
)
)

return denoised
audio_tensor, audio_lengths = collate_audios(audio, denoiser.sampling_rate)
mask = sequence_mask(audio_lengths.to(device))
with torch.no_grad():
denoised_audio = denoiser(audio_tensor.to(device), mask)
return decollate_audios(
denoised_audio,
audio_lengths,
denoiser.sampling_rate,
origin_like=audio,
)


def get_denoised_pesq(
audio: list[Audio],
sampling_rate: int = 16000,
device: str = "cpu",
batch_size: int = 32,
) -> torch.Tensor:
"""
Compute the Perceptual Evaluation of Speech Quality (PESQ) score between original and denoised audio.
@@ -61,23 +53,20 @@ def get_denoised_pesq(
audio (list[Audio]): List of audio objects to be denoised.
sampling_rate (int): The target sampling rate for PESQ computation (default: 16000 Hz).
device (str): The device to run the denoising model on (default: "cpu").
batch_size (int): Number of audio samples to process in a batch (default: 32).

Returns:
torch.Tensor: Tensor containing PESQ scores for each input Audio.
"""
denoised = denoise(audio, device, batch_size)
return torch.tensor(
[
pesq(
sampling_rate,
ref.resample(sampling_rate).as_tensor().cpu().numpy(),
enh.resample(sampling_rate).as_tensor().cpu().numpy(),
on_error=PesqError.RETURN_VALUES,
)
for ref, enh in zip(audio, denoised)
]
)
denoised = denoise(audio, device)
return [
pesq(
sampling_rate,
ref.resample(sampling_rate).as_tensor().cpu().numpy(),
enh.resample(sampling_rate).as_tensor().cpu().numpy(),
on_error=PesqError.RETURN_VALUES,
)
for ref, enh in zip(audio, denoised)
]


@lru_cache(maxsize=1)
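The new @batched(batch_size=1024, batch_duration=1280.0) decorator replaces the manual chunk_list loop, so denoise now sees one already-sized batch per call. The decorator's implementation is not part of this diff; a purely hypothetical sketch of what such a wrapper could look like, assuming list inputs whose items expose a .duration in seconds and wrapped functions that return one result per input item:

import functools

def batched(batch_size: int, batch_duration: float):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(items, *args, **kwargs):
            results, chunk, chunk_duration = [], [], 0.0
            for item in items:
                too_many = len(chunk) >= batch_size
                too_long = chunk_duration + item.duration > batch_duration
                if chunk and (too_many or too_long):
                    # Flush the current chunk and start a new one
                    results.extend(fn(chunk, *args, **kwargs))
                    chunk, chunk_duration = [], 0.0
                chunk.append(item)
                chunk_duration += item.duration
            if chunk:
                results.extend(fn(chunk, *args, **kwargs))
            return results
        return wrapper
    return decorator
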
17 changes: 7 additions & 10 deletions src/tidytunes/pipeline_components/dnsmos.py
@@ -2,15 +2,15 @@

import torch

from tidytunes.utils import Audio, chunk_list, collate_audios
from tidytunes.utils import Audio, batched, collate_audios


@batched(batch_size=1024, batch_duration=1280.0)
def get_dnsmos(
audio: list[Audio],
personalized: bool = True,
device: str = "cpu",
num_threads: int | None = 8,
batch_size: int = 32,
) -> torch.Tensor:
"""
Computes DNSMOS (Deep Noise Suppression Mean Opinion Score) for a batch of audio clips.
@@ -20,21 +20,18 @@ def get_dnsmos(
personalized (bool): Whether to use a personalized model (default: True).
device (str): The device to run the model on (default: "cpu").
num_threads (int | None): Number of threads to use for ONNX inference (default: 8).
batch_size (int): Batch size for processing (default: 32).

Returns:
torch.Tensor: Tensor containing DNSMOS scores for each input audio clip.
"""

model = load_dnsmos_model(device, personalized, num_threads)
mos_scores = []

for audio_batch in chunk_list(audio, batch_size):
a, al = collate_audios(audio_batch, model.sampling_rate)
with torch.no_grad():
_, _, _, mos = model(a.to(device), al.to(device))
mos_scores.append(mos)
a, al = collate_audios(audio, model.sampling_rate)
with torch.no_grad():
_, _, _, mos = model(a.to(device), al.to(device))

return torch.cat(mos_scores, dim=0)
return torch.unbind(mos)


@lru_cache(1)
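Returning torch.unbind(mos) instead of a single concatenated tensor yields one 0-dim score per input clip, which is what a per-item batching wrapper needs to stitch results from several chunks back together (an inference from the diff; the decorator's contract is not shown here). A tiny illustration of the unbind behaviour with dummy scores:

import torch

mos = torch.tensor([3.2, 4.1, 2.7])        # per-clip scores from one forward pass (dummy values)
per_clip = torch.unbind(mos)               # tuple of 0-dim tensors, one per clip
print(len(per_clip), round(per_clip[0].item(), 1))  # 3 3.2
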