Skip to content

Commit

Permalink
Added batching to gender classification
Browse files (browse the repository at this point in the history)
  • Loading branch information
Tomiinek committed Feb 28, 2025
1 parent 358f90f commit d159111
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions src/tidytunes/pipeline_components/gender_classification.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,25 +3,41 @@
import torch

from tidytunes.pipeline_components.speaker_segmentation import load_speaker_encoder
from tidytunes.utils import Audio, collate_audios
from tidytunes.utils import Audio, collate_audios, to_batches


def is_male(audios: list[Audio], device="cpu"):
def is_male(
audio: list[Audio],
device: str = "cpu",
batch_size: int = 64,
batch_duration: float = 1280.0,
):
"""
Classifies gender of the speaker in the input audios
Args:
audio (list[Audio]): List of audio objects.
device (str): Device to run the model on (default: "cpu").
batch_size (int): Maximal number of audio samples to process in a batch (default: 64).
batch_duration (float): Maximal duration of audio samples to process in a batch (default: 1280.0)
Returns:
list[bool]: List of booleans for each input audio, True for males, False for females.
"""

speaker_encoder = load_speaker_encoder(device=device)
model = load_gender_classification_model()
embeddings = []

for audio_batch in to_batches(audio, batch_size, batch_duration):

with torch.no_grad():
audio, audio_lens = collate_audios(
audios, sampling_rate=speaker_encoder.sampling_rate
)
audio = audio.to(device)
audio_lens = audio_lens.to(device)
embeddings = speaker_encoder(audio, audio_lens)
embeddings_flattened = [e.mean(dim=0) for e in embeddings]
a, al = collate_audios(audio_batch, sampling_rate=speaker_encoder.sampling_rate)
with torch.no_grad():
be = speaker_encoder(a.to(device), al.to(device))
embeddings.extend([e.mean(dim=0) for e in be])

classifications = [
model.predict(e.cpu().numpy().reshape(1, -1)) for e in embeddings_flattened
model.predict(e.cpu().numpy().reshape(1, -1)) for e in embeddings
]

return [c == 1 for c in classifications]
Expand Down

0 comments on commit d159111

Please sign in to comment.