diff --git a/src/tidytunes/bin/process_audio.py b/src/tidytunes/bin/process_audio.py
index 790b825..349e550 100644
--- a/src/tidytunes/bin/process_audio.py
+++ b/src/tidytunes/bin/process_audio.py
@@ -36,9 +36,16 @@ def process_audio(audios, device, pipeline_components):
 
     for name, func, kwargs, filter_fn in pipeline_components:
+        print(name)
+
+        # to improve batching efficiency
+        audio_segments = sorted(audio_segments, key=lambda x: x.duration)
+
         values = func(audio_segments, device=device, **kwargs)
         if filter_fn:
-            audio_segments, _ = partition(audio_segments, by=filter_fn(values))
+            audio_segments, _ = partition(
+                audio_segments, by=[filter_fn(v) for v in values]
+            )
         else:
             audio_segments = trim_audios(audio_segments, values)
diff --git a/src/tidytunes/models/external/silerovad.py b/src/tidytunes/models/external/silerovad.py
index a2ba48b..d0ea6f8 100644
--- a/src/tidytunes/models/external/silerovad.py
+++ b/src/tidytunes/models/external/silerovad.py
@@ -103,7 +103,11 @@ def forward(self, audio_chunk_16khz: torch.Tensor, state: list[torch.Tensor]):
 
         x = audio_chunk_16khz.unsqueeze(1)
-        x = F.pad(x, (self.audio_padding_size, self.audio_padding_size), mode="reflect")
+        x = F.pad(
+            x.float().contiguous(),
+            (self.audio_padding_size, self.audio_padding_size),
+            mode="reflect",
+        )
         x = self.input_conv(x)
         a, b = torch.pow(x, 2).chunk(2, dim=1)
diff --git a/src/tidytunes/models/source_separation.py b/src/tidytunes/models/source_separation.py
index 764b412..bc48148 100644
--- a/src/tidytunes/models/source_separation.py
+++ b/src/tidytunes/models/source_separation.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torchaudio.transforms import Fade
 
 from tidytunes.utils import masked_mean, masked_std, sequence_mask
@@ -9,104 +10,126 @@ class SourceSeparator(nn.Module):
     def __init__(
         self,
         model,
-        segment: float = 10.0,
-        overlap: float = 0.1,
+        frame_shift: float = 0.16,
+        segment_frames: int = 63,
+        overlap_frames: int = 5,
         sampling_rate: int = 44100,
+        max_music_energy: float = 0.01,
+        min_speech_energy: float = 0.99,
+        window_frames: int = 2,
+        minimal_energy: float = 1e-7,
     ):
         super().__init__()
         self.model = model
+        self.frame_shift = frame_shift
         self.sampling_rate = sampling_rate
-        self.chunk_len = int(sampling_rate * segment * (1 + overlap))
-        self.overlap_frames = int(overlap * sampling_rate)
-        self.fade = Fade(
-            fade_in_len=0, fade_out_len=self.overlap_frames, fade_shape="linear"
-        )
+        self.frame_samples = int(frame_shift * sampling_rate)
+        self.segment_samples = self.frame_samples * segment_frames
+        self.overlap_samples = self.frame_samples * overlap_frames
+        self.fade_in = Fade(fade_in_len=self.overlap_samples, fade_out_len=0)
+        self.fade_out = Fade(fade_in_len=0, fade_out_len=self.overlap_samples)
+        self.max_music_energy = max_music_energy
+        self.min_speech_energy = min_speech_energy
+        self.window_samples = self.frame_samples * window_frames
+        self.minimal_energy = minimal_energy
 
     def forward(
         self,
-        x: torch.Tensor,
-        x_lens: torch.Tensor,
+        audio: torch.Tensor,
+        audio_lens: torch.Tensor,
    ):
         """
-        Normalizes and processes input audio to separate sources.
+        Normalizes and processes input audio to separate sources and computes
+        the energy of vocals and of the remaining sources within a sliding
+        window to decide whether background music is present together with speech.
 
         Args:
-            x (B, T): Input audio tensor.
-            x_lens (B,): Lengths of each sequence in the batch.
+            audio (B, T): Input audio tensor.
+            audio_lens (B,): Lengths of each sequence in the batch.
 
         Returns:
-            (B, sources, L): Separated audio sources.
+            (B, L): Mask indicating frames without music.
         """
-        mask = sequence_mask(x_lens)
-        mean = masked_mean(x, mask)
-        std = masked_std(x, mask, mean=mean)
+        B, T = audio.shape
 
-        x = (x - mean.unsqueeze(-1)) / std.unsqueeze(-1)
-        x[~mask] = 0.0
+        # Pad to nearest multiple of segment and add overlap
+        padded_lens = (
+            (T + self.segment_samples - 1) // self.segment_samples
+        ) * self.segment_samples
+        pad_size = padded_lens - T + self.overlap_samples
 
-        y = self.get_sources(x)
-        y = y * std[:, None, None] + mean[:, None, None]
-        mask = mask.unsqueeze(1).repeat(1, 4, 1)
-        y[~mask] = 0.0
+        audio = F.pad(audio, (0, pad_size))
 
-        # (B, sources, T), sources are: drums, bass, other, vocals
-        return y
+        mask = sequence_mask(audio_lens, max_length=audio.shape[-1])
+        mean = masked_mean(audio, mask)
+        std = masked_std(audio, mask, mean=mean)
 
-    @torch.no_grad()
-    def get_sources(
-        self,
-        audio: torch.Tensor,
-    ):
-        """
-        Splits audio into overlapping chunks, processes with the model,
-        and applies fade-in/fade-out to smooth transitions.
+        x = (audio - mean.unsqueeze(-1)) / std.unsqueeze(-1)
+        x[~mask] = 0.0
 
-        Args:
-            audio (B, T): Normalized input audio.
+        audio_buffer = torch.zeros(x.shape[0], self.overlap_samples, device=x.device)
+        window_buffer = None
+        output = []
 
-        Returns:
-            (B, sources, T): Separated sources.
-        """
+        for i in range((x.shape[-1] - self.overlap_samples) // self.segment_samples):
 
-        # The model expects stereo inputs
-        audio = audio.unsqueeze(1).expand(-1, 2, -1)
-        B, C, L = audio.shape
+            s = i * self.segment_samples
+            e = (i + 1) * self.segment_samples + self.overlap_samples
+            segment = x[..., s:e]
 
-        if L <= self.chunk_len:
-            return self.model(audio).mean(dim=-2)
+            assert segment.shape[-1] == self.segment_samples + self.overlap_samples
 
-        output = torch.zeros(B, len(self.model.sources), L, device=audio.device)
-        buffer = None
+            if i > 0:
+                segment = self.fade_in(segment)
+                segment[..., : self.overlap_samples] += audio_buffer
+            segment = self.fade_out(segment)
+            audio_buffer = segment[..., -self.overlap_samples :]
+            segment = segment[..., : -self.overlap_samples]
 
-        start, end = 0, self.chunk_len
-        while start < L - self.overlap_frames:
-            chunk = audio[:, :, start:end]
-            x = self.model(chunk)
-            x = self.fade(x)
+            y = self.forward_segment(segment)
+            y = y * std[:, None, None] + mean[:, None, None]
 
-            chunk_output = x[..., : x.shape[-1] - self.fade.fade_out_len]
+            if window_buffer is not None:
+                y = torch.cat([window_buffer, y], dim=-1)
+            window_buffer = y[..., -self.frame_samples :]
 
-            if self.fade.fade_in_len > 0:
-                chunk_output[..., : self.fade.fade_in_len] += buffer
-            buffer = x[..., x.shape[-1] - self.fade.fade_out_len :]
+            frames = y.unfold(-1, self.window_samples, self.frame_samples)
 
-            output[..., start : start + chunk_output.shape[-1]] = chunk_output.mean(
-                dim=-2
+            energy = (frames**2).mean(dim=-1)  # b c t w -> b c t
+
+            # Merge all non-vocal channels into a single one
+            energy = torch.cat(
+                (energy[:, :3].sum(dim=-2, keepdim=True), energy[:, 3:]), dim=-2
             )
+            energy_total = energy.sum(dim=-2, keepdim=True)
+            rel_energy = energy / energy_total
 
-            if start == 0:
-                self.fade.fade_in_len = self.overlap_frames
-                start += self.chunk_len - self.overlap_frames
-            else:
-                start += self.chunk_len
+            abs_silence = energy_total.squeeze(1) < self.minimal_energy
+            no_music = abs_silence | (rel_energy[:, 0] <= self.max_music_energy)
+            is_speech = abs_silence | (rel_energy[:, 1] >= self.min_speech_energy)
+            output.append(no_music & is_speech)
 
-            end += self.chunk_len
-            if end >= L:
-                self.fade.fade_out_len = 0
+        output = torch.cat(output, dim=-1)
 
-        # reset the original chunk fading
-        self.fade.fade_in_len = 0
-        self.fade.fade_out_len = self.overlap_frames
+        # Trim output to remove segment padding and mask invalid positions
+        n_frames = (
+            audio_lens + 2 * self.frame_samples - 1 - self.window_samples
+        ) // self.frame_samples
+        output = output[..., : n_frames.max()]
+        mask = sequence_mask(n_frames)
+        output[~mask] = False
 
         return output
+
+    @torch.no_grad()
+    def forward_segment(
+        self,
+        x: torch.Tensor,
+    ):
+        # The model expects stereo inputs
+        x = x.unsqueeze(1).expand(-1, 2, -1)
+        B, C, L = x.shape
+        assert L == self.segment_samples
+        x = self.model(x).mean(dim=-2)
+        return x
diff --git a/src/tidytunes/pipeline_components/denoising.py b/src/tidytunes/pipeline_components/denoising.py
index 01f7ad1..679d479 100644
--- a/src/tidytunes/pipeline_components/denoising.py
+++ b/src/tidytunes/pipeline_components/denoising.py
@@ -5,15 +5,17 @@
 from tidytunes.utils import (
     Audio,
-    chunk_list,
+    batched,
     collate_audios,
     decollate_audios,
     sequence_mask,
 )
 
 
+@batched(batch_size=1024, batch_duration=1280.0)
 def denoise(
-    audio: list[Audio], device: str = "cpu", batch_size: int = 32
+    audio: list[Audio],
+    device: str = "cpu",
 ) -> list[Audio]:
     """
     Apply denoising to a list of audio samples using a pre-trained model.
 
@@ -21,38 +23,28 @@
     Args:
         audio (list[Audio]): List of audio objects to be denoised.
         device (str): The device to run the denoising model on (default: "cpu").
-        batch_size (int): Number of audio samples to process in a batch (default: 32).
 
     Returns:
         list[Audio]: List of denoised audio objects.
     """
     denoiser = load_denoiser(device)
-    denoised = []
-
-    for audio_batch in chunk_list(audio, batch_size):
-        audio_tensor, audio_lengths = collate_audios(
-            audio_batch, denoiser.sampling_rate
-        )
-        mask = sequence_mask(audio_lengths.to(device))
-        with torch.no_grad():
-            denoised_audio = denoiser(audio_tensor.to(device), mask)
-        denoised.extend(
-            decollate_audios(
-                denoised_audio,
-                audio_lengths,
-                denoiser.sampling_rate,
-                origin_like=audio_batch,
-            )
-        )
-    return denoised
+    audio_tensor, audio_lengths = collate_audios(audio, denoiser.sampling_rate)
+    mask = sequence_mask(audio_lengths.to(device))
+    with torch.no_grad():
+        denoised_audio = denoiser(audio_tensor.to(device), mask)
+    return decollate_audios(
+        denoised_audio,
+        audio_lengths,
+        denoiser.sampling_rate,
+        origin_like=audio,
+    )
 
 
 def get_denoised_pesq(
     audio: list[Audio],
     sampling_rate: int = 16000,
     device: str = "cpu",
-    batch_size: int = 32,
 ) -> torch.Tensor:
     """
     Compute the Perceptual Evaluation of Speech Quality (PESQ) score between original and denoised audio.
@@ -61,23 +53,20 @@
         audio (list[Audio]): List of audio objects to be denoised.
         sampling_rate (int): The target sampling rate for PESQ computation (default: 16000 Hz).
         device (str): The device to run the denoising model on (default: "cpu").
-        batch_size (int): Number of audio samples to process in a batch (default: 32).
 
     Returns:
         torch.Tensor: Tensor containing PESQ scores for each input Audio.
""" - denoised = denoise(audio, device, batch_size) - return torch.tensor( - [ - pesq( - sampling_rate, - ref.resample(sampling_rate).as_tensor().cpu().numpy(), - enh.resample(sampling_rate).as_tensor().cpu().numpy(), - on_error=PesqError.RETURN_VALUES, - ) - for ref, enh in zip(audio, denoised) - ] - ) + denoised = denoise(audio, device) + return [ + pesq( + sampling_rate, + ref.resample(sampling_rate).as_tensor().cpu().numpy(), + enh.resample(sampling_rate).as_tensor().cpu().numpy(), + on_error=PesqError.RETURN_VALUES, + ) + for ref, enh in zip(audio, denoised) + ] @lru_cache(maxsize=1) diff --git a/src/tidytunes/pipeline_components/dnsmos.py b/src/tidytunes/pipeline_components/dnsmos.py index 836b2b2..235e784 100644 --- a/src/tidytunes/pipeline_components/dnsmos.py +++ b/src/tidytunes/pipeline_components/dnsmos.py @@ -2,15 +2,15 @@ import torch -from tidytunes.utils import Audio, chunk_list, collate_audios +from tidytunes.utils import Audio, batched, collate_audios +@batched(batch_size=1024, batch_duration=1280.0) def get_dnsmos( audio: list[Audio], personalized: bool = True, device: str = "cpu", num_threads: int | None = 8, - batch_size: int = 32, ) -> torch.Tensor: """ Computes DNSMOS (Deep Noise Suppression Mean Opinion Score) for a batch of audio clips. @@ -20,21 +20,18 @@ def get_dnsmos( personalized (bool): Whether to use a personalized model (default: True). device (str): The device to run the model on (default: "cpu"). num_threads (int | None): Number of threads to use for ONNX inference (default: 8). - batch_size (int): Batch size for processing (default: 32). Returns: torch.Tensor: Tensor containing DNSMOS scores for each input audio clip. """ + model = load_dnsmos_model(device, personalized, num_threads) - mos_scores = [] - for audio_batch in chunk_list(audio, batch_size): - a, al = collate_audios(audio_batch, model.sampling_rate) - with torch.no_grad(): - _, _, _, mos = model(a.to(device), al.to(device)) - mos_scores.append(mos) + a, al = collate_audios(audio, model.sampling_rate) + with torch.no_grad(): + _, _, _, mos = model(a.to(device), al.to(device)) - return torch.cat(mos_scores, dim=0) + return torch.unbind(mos) @lru_cache(1) diff --git a/src/tidytunes/pipeline_components/gender_classification.py b/src/tidytunes/pipeline_components/gender_classification.py index 28d5dbd..4287538 100644 --- a/src/tidytunes/pipeline_components/gender_classification.py +++ b/src/tidytunes/pipeline_components/gender_classification.py @@ -3,27 +3,34 @@ import torch from tidytunes.pipeline_components.speaker_segmentation import load_speaker_encoder -from tidytunes.utils import Audio, collate_audios +from tidytunes.utils import Audio, batched, collate_audios -def is_male(audios: list[Audio], device="cpu"): +@batched(batch_size=1024, batch_duration=1280.0) +def is_male( + audio: list[Audio], + device: str = "cpu", +): + """ + Classifies gender of the speaker in the input audios + + Args: + audio (list[Audio]): List of audio objects. + device (str): Device to run the model on (default: "cpu"). + + Returns: + list[bool]: List of booleans for each input audio, True for males, False for females. 
+    """
     speaker_encoder = load_speaker_encoder(device=device)
     model = load_gender_classification_model()
 
+    a, al = collate_audios(audio, sampling_rate=speaker_encoder.sampling_rate)
     with torch.no_grad():
-        audio, audio_lens = collate_audios(
-            audios, sampling_rate=speaker_encoder.sampling_rate
-        )
-        audio = audio.to(device)
-        audio_lens = audio_lens.to(device)
-        embeddings = speaker_encoder(audio, audio_lens)
-        embeddings_flattened = [e.mean(dim=0) for e in embeddings]
-
+        embeddings = speaker_encoder(a.to(device), al.to(device))
     classifications = [
-        model.predict(e.cpu().numpy().reshape(1, -1)) for e in embeddings_flattened
+        model.predict(e.mean(dim=0).cpu().numpy().reshape(1, -1)) for e in embeddings
     ]
-
     return [c == 1 for c in classifications]
diff --git a/src/tidytunes/pipeline_components/language_id.py b/src/tidytunes/pipeline_components/language_id.py
index 8b684ae..327029d 100644
--- a/src/tidytunes/pipeline_components/language_id.py
+++ b/src/tidytunes/pipeline_components/language_id.py
@@ -2,11 +2,14 @@
 
 import torch
 
-from tidytunes.utils import Audio, collate_audios
+from tidytunes.utils import Audio, batched, collate_audios
 
 
+@batched(batch_size=1024, batch_duration=1280.0)
 def get_language_probabilities(
-    audio: list[Audio], language_code: str, device: str = "cpu"
+    audio: list[Audio],
+    language_code: str,
+    device: str = "cpu",
 ) -> torch.Tensor:
     """
     Compute the probability of a given language being spoken in the audio.
@@ -24,9 +27,8 @@
     audio_16khz, audio_16khz_lens = collate_audios(audio, 16000)
     with torch.no_grad():
         out_prob, _, _ = model(audio_16khz.to(device), audio_16khz_lens.to(device))
-        lang_prob = torch.stack([p[lab2ind[language_code]] for p in out_prob])
-
-    return lang_prob
+    return [p[lab2ind[language_code]] for p in out_prob]
 
 
 @lru_cache(1)
diff --git a/src/tidytunes/pipeline_components/source_separation.py b/src/tidytunes/pipeline_components/source_separation.py
index 10fd644..0b2dc6f 100644
--- a/src/tidytunes/pipeline_components/source_separation.py
+++ b/src/tidytunes/pipeline_components/source_separation.py
@@ -4,18 +4,15 @@
 
 from tidytunes.utils import (
     Audio,
+    batched,
     collate_audios,
     frame_labels_to_time_segments,
-    masked_mean,
-    sequence_mask,
 )
 
 
+@batched(batch_size=1024, batch_duration=1280.0)
 def find_segments_without_music(
-    audios: list[Audio],
-    frame_shift: float = 0.16,
-    max_music_energy: float = 0.01,
-    min_speech_energy: float = 0.99,
+    audio: list[Audio],
     min_duration: float = 6.4,
     device: str = "cpu",
 ):
@@ -35,49 +32,17 @@
     """
     demucs = load_demucs(device)
 
-    audio, audio_lens = collate_audios(audios, demucs.sampling_rate)
-    audio, audio_lens = audio.to(device), audio_lens.to(device)
-
+    a, al = collate_audios(audio, demucs.sampling_rate)
     with torch.no_grad():
-        sources = demucs(audio, audio_lens)
-
-    B, C, T = sources.shape
-    hop_length = int(frame_shift * demucs.sampling_rate)
-    window_size = hop_length * 2
-
-    is_speech_without_music = []
-    n_frames = (audio_lens - window_size) // hop_length + 1
-
-    # NOTE: Simply unfolding the whole sources can easily cause OOMs for longer
-    # inputs, so we rather go slowly frame by frame to be memory efficient
-    # TODO: Implement chunk-wise inference instead of frame-wise
-    for i in range(T // hop_length):
-        start, end = i * hop_length, i * hop_length + window_size
-        energy = masked_mean(sources[..., start:end] ** 2)
-
-        mask = sequence_mask(n_frames.clamp(max=1, min=0), max_length=1)
-        energy[~mask.expand_as(energy)] = 0.0
-        n_frames -= 1
-
-        # Merge all non-vocal channels into a single one
-        energy = torch.cat(
-            (energy[..., :3].sum(dim=-1, keepdim=True), energy[..., 3:]), dim=-1
-        )
-        energy /= energy.sum(dim=-1, keepdim=True)
-
-        no_music = energy[..., 0] <= max_music_energy
-        is_speech = energy[..., 1] >= min_speech_energy
-        is_speech_without_music.append(no_music & is_speech)
-
-    is_speech_without_music = torch.stack(is_speech_without_music, dim=1)
+        speech_without_music_mask = demucs(a.to(device), al.to(device))
 
     return [
         frame_labels_to_time_segments(
-            frames,
-            frame_shift,
+            m,
+            demucs.frame_shift,
             filter_with=lambda x: (x.symbol is True) & (x.duration >= min_duration),
         )
-        for frames in is_speech_without_music
+        for m in speech_without_music_mask
     ]
diff --git a/src/tidytunes/pipeline_components/speaker_segmentation.py b/src/tidytunes/pipeline_components/speaker_segmentation.py
index a9ca858..f30cb8e 100644
--- a/src/tidytunes/pipeline_components/speaker_segmentation.py
+++ b/src/tidytunes/pipeline_components/speaker_segmentation.py
@@ -4,7 +4,12 @@
 import torch.nn.functional as F
 from sklearn.cluster import AgglomerativeClustering, KMeans
 
-from tidytunes.utils import Audio, collate_audios, frame_labels_to_time_segments
+from tidytunes.utils import (
+    Audio,
+    batched,
+    collate_audios,
+    frame_labels_to_time_segments,
+)
 
 
 def find_segments_with_single_speaker(
@@ -18,6 +23,7 @@
 ):
     """
     Identifies segments in the audio where only a single speaker is present.
+    *** All input segments are supposed to come from a single source. ***
 
     Args:
         audio (list[Audio]): List of audio objects.
@@ -31,19 +37,11 @@
     Returns:
         list[list[Segment]]: List of speaker segments for each input audio.
     """
-    speaker_encoder = load_speaker_encoder(num_frames=frame_shift, device=device)
-
-    audio, audio_lens = collate_audios(
-        audio, sampling_rate=speaker_encoder.sampling_rate
-    )
-    audio, audio_lens = audio.to(device), audio_lens.to(device)
-    with torch.no_grad():
-        embeddings = speaker_encoder(audio, audio_lens)
-        embeddings_all = torch.cat(embeddings, dim=0)
+    embeddings = get_speaker_embeddings(audio, frame_shift, device)
+    embeddings_all = torch.cat(embeddings, dim=0)
 
     centroids = find_cluster_centers(embeddings_all, num_clusters)
-
     labels = [
         F.cosine_similarity(e.unsqueeze(1), centroids.unsqueeze(0), dim=-1).argmax(
             dim=-1
         )
         for e in embeddings
     ]
 
+    speaker_encoder = load_speaker_encoder(num_frames=frame_shift, device=device)
     frame_shift_seconds = (
         frame_shift * speaker_encoder.hop_length / speaker_encoder.sampling_rate
     )
@@ -73,6 +72,17 @@
     return time_segments
 
 
+@batched(batch_size=1024, batch_duration=1280.0)
+def get_speaker_embeddings(
+    audio: list[Audio], frame_shift: int = 64, device: str = "cpu"
+):
+    speaker_encoder = load_speaker_encoder(num_frames=frame_shift, device=device)
+    a, al = collate_audios(audio, sampling_rate=speaker_encoder.sampling_rate)
+    with torch.no_grad():
+        e = speaker_encoder(a.to(device), al.to(device))
+    return e
+
+
 def find_cluster_centers(embeddings: torch.Tensor, num_clusters):
     """
     Clusters speaker embeddings and refines cluster centers.
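The @batched(batch_size=1024, batch_duration=1280.0) decorator applied to the pipeline components above re-binds their `audio` argument to one duration-capped batch at a time and concatenates the per-item results, so callers keep passing the full list of Audio objects. A minimal sketch of that contract (illustration only, not part of the diff; FakeAudio is a hypothetical stand-in for tidytunes.utils.Audio, of which the batching utilities only read `duration`):

from dataclasses import dataclass

from tidytunes.utils import batched


@dataclass
class FakeAudio:
    # hypothetical stand-in: the batcher only needs a `duration` attribute
    duration: float


@batched(batch_size=4, batch_duration=10.0)
def label_long_clips(audio: list[FakeAudio], threshold: float = 3.0) -> list[bool]:
    # The wrapper re-binds `audio` to a single batch per call, so this body only
    # ever sees a few clips; the returned lists are concatenated across batches.
    return [a.duration >= threshold for a in audio]


clips = [FakeAudio(d) for d in (1.0, 2.5, 4.0, 8.0, 2.0)]
print(label_long_clips(clips))  # -> [False, False, True, True, False]

Sorting inputs by duration, as process_audio.py now does, keeps the padded batches produced by to_batches dense.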
diff --git a/src/tidytunes/pipeline_components/vad.py b/src/tidytunes/pipeline_components/vad.py
index 7080a9d..270acfa 100644
--- a/src/tidytunes/pipeline_components/vad.py
+++ b/src/tidytunes/pipeline_components/vad.py
@@ -2,9 +2,15 @@
 
 import torch
 
-from tidytunes.utils import Audio, collate_audios, frame_labels_to_time_segments
+from tidytunes.utils import (
+    Audio,
+    batched,
+    collate_audios,
+    frame_labels_to_time_segments,
+)
 
 
+@batched(batch_size=1024, batch_duration=1280.0)
 def find_segments_with_speech(
     audio: list[Audio],
     min_duration: float = 3.2,
@@ -24,12 +30,13 @@
         list[list[Segment]]: Time segments containing speech for each input Audio.
     """
     vad = load_vad(device)
+
     audio_tensor, _ = collate_audios(audio, vad.sampling_rate)
     with torch.no_grad():
         speech_mask = vad(audio_tensor.to(device))
         speech_mask[..., :-1] += speech_mask[..., 1:].clone()  # Pre-bounce speech starts
 
-    time_segments = [
+    return [
         frame_labels_to_time_segments(
             m,
             vad.frame_shift,
@@ -38,7 +45,6 @@
         )
         for m in speech_mask
     ]
-    return time_segments
 
 
 @lru_cache(maxsize=1)
diff --git a/src/tidytunes/utils/__init__.py b/src/tidytunes/utils/__init__.py
index 910cee5..18f6edb 100644
--- a/src/tidytunes/utils/__init__.py
+++ b/src/tidytunes/utils/__init__.py
@@ -1,6 +1,12 @@
 from .audio import Audio, collate_audios, decollate_audios, trim_audios
 from .download import download_github
-from .etc import chunk_list, frame_labels_to_time_segments, partition
+from .etc import (
+    SpeculativeBatcher,
+    batched,
+    frame_labels_to_time_segments,
+    partition,
+    to_batches,
+)
 from .logging import setup_logger
 from .tensors import masked_mean, masked_std, sequence_mask
 from .trace import TraceMixin
diff --git a/src/tidytunes/utils/etc.py b/src/tidytunes/utils/etc.py
index 0c973a5..b85fd78 100644
--- a/src/tidytunes/utils/etc.py
+++ b/src/tidytunes/utils/etc.py
@@ -1,6 +1,11 @@
+import inspect
+from functools import wraps
+from typing import Any
+
 import torch
 
-from .audio import Segment
+from .audio import Audio, Segment
+from .memory import is_oom_error
 
 
 def partition(lst: list, by: list, other: list | None = None) -> tuple[list, list]:
@@ -59,9 +64,107 @@
     return segments
 
 
-def chunk_list(lst: list, chunk_size: int) -> list[list]:
+def to_batches(audios: list[Audio], max_size: int, max_duration: float) -> list[list]:
     """
-    Split input list `lst` into lists of length `chunk_size`. The last item can be
-    incomplete if the length of input is not divisible by chunk size.
+    Split the input list `audios` into batches of at least 1 and at most `max_size`
+    Audio objects whose total padded duration is at most `max_duration` (the duration
+    limit may be exceeded when a batch would otherwise contain only a single element).
""" - return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + assert max_size >= 1 + assert max_duration > 0.0 + + batches, batch = [], [] + for audio in audios: + + if len(batch) == 0: + batch.append(audio) + continue + + total_duration = max(a.duration for a in (batch + [audio])) * (len(batch) + 1) + if (len(batch) == max_size) or (total_duration > max_duration): + batches.append(batch) + batch = [] + + batch.append(audio) + + if len(batch) > 0: + batches.append(batch) + + return batches + + +class SpeculativeBatcher: + def __init__( + self, + max_size: int, + init_max_duration: float, + growth_factor: float = 1.1, + backoff_factor: float = 0.95, + growth_interval: int = 100, + growth_interval_factor: float = 2.0, + ): + self.max_size = max_size + self.max_duration = init_max_duration + self.growth_factor = growth_factor + self.backoff_factor = backoff_factor + self.growth_interval = growth_interval + self.growth_interval_factor = growth_interval_factor + self._reset_counter() + + def _reset_counter(self): + self.counter = self.growth_interval + + def _increase(self): + self.max_duration *= self.growth_factor + self.growth_interval *= self.growth_factor + + def _decrease(self): + self.max_duration *= self.backoff_factor + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if is_oom_error(exc_value): + self._decrease() + self._reset_counter() + return True + self.counter -= 1 + if self.counter <= 0: + self._increase() + self._reset_counter() + return False + + def __call__(self, audios: list[Audio]): + return to_batches(audios, self.max_size, self.max_duration) + + +def batched(batch_size, batch_duration): + + num_retries = 100 + + def decorator(func): + batcher = SpeculativeBatcher(batch_size, batch_duration) + + @wraps(func) + def wrapper(*args, **kwargs): + + sig = inspect.signature(func) + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + audio = bound_args.arguments["audio"] + + for _ in range(num_retries): + with batcher: + outputs = [] + for batch in batcher(audio): + bound_args.arguments["audio"] = batch + o = func(*bound_args.args, **bound_args.kwargs) + outputs.extend(o) + return outputs + else: + raise RuntimeError("OOM, failed to find a suitable batch size!") + + return wrapper + + return decorator diff --git a/src/tidytunes/utils/memory.py b/src/tidytunes/utils/memory.py new file mode 100644 index 0000000..468a13e --- /dev/null +++ b/src/tidytunes/utils/memory.py @@ -0,0 +1,45 @@ +import gc + +import torch + + +def is_oom_error(exception: BaseException) -> bool: + return ( + is_cuda_out_of_memory(exception) + or is_cudnn_snafu(exception) + or is_out_of_cpu_memory(exception) + ) + + +def is_cuda_out_of_memory(exception: BaseException) -> bool: + return ( + isinstance(exception, RuntimeError) + and len(exception.args) == 1 + and "CUDA" in exception.args[0] + and "out of memory" in exception.args[0] + ) + + +def is_cudnn_snafu(exception: BaseException) -> bool: + return ( + isinstance(exception, RuntimeError) + and len(exception.args) == 1 + and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." 
in exception.args[0] + ) + + +def is_out_of_cpu_memory(exception: BaseException) -> bool: + return ( + isinstance(exception, RuntimeError) + and len(exception.args) == 1 + and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] + ) + + +def garbage_collection_cuda() -> None: + gc.collect() + try: + torch.cuda.empty_cache() + except RuntimeError as exception: + if not is_oom_error(exception): + raise
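A short sketch of the speculative batching behaviour that ties utils/etc.py and utils/memory.py together, assuming the SpeculativeBatcher added above (illustration only, not part of the diff): an OOM raised inside the context multiplies the duration budget by backoff_factor, while growth_interval successful iterations grow it again by growth_factor.

from tidytunes.utils import SpeculativeBatcher

batcher = SpeculativeBatcher(max_size=8, init_max_duration=100.0, growth_interval=2)

# A CUDA OOM is swallowed by __exit__ and the duration budget backs off by 5 %.
with batcher:
    raise RuntimeError("CUDA error: out of memory")
print(round(batcher.max_duration, 2))  # 95.0

# After growth_interval successful iterations the budget grows again by 10 %.
for _ in range(2):
    with batcher:
        pass
print(round(batcher.max_duration, 2))  # 104.5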