transform.py
import torch
import torchaudio

from config import config


class OnsetTransform(torch.nn.Module):
    """Converts raw audio into stacked log-mel spectrogram features for onset detection."""

    def __init__(self):
        super().__init__()

    def forward(self, audio, sr):
        """Return features of shape (T, F, C): time frames, mel bins, one channel per n_fft."""
        # resample if needed
        if sr != config.audio.sample_rate:
            audio = torchaudio.transforms.Resample(sr, config.audio.sample_rate)(audio)

        # convert to mono: average multiple channels, or drop a singleton channel dim
        if len(audio.shape) > 1:
            if audio.shape[0] > 1:
                audio = torch.mean(audio, dim=0, keepdim=False)
            else:
                audio = audio.squeeze(0)

        # compute one mel spectrogram per FFT size
        features = []
        for n_fft in config.audio.n_ffts:
            mel = torchaudio.transforms.MelSpectrogram(
                sample_rate=config.audio.sample_rate,
                n_fft=n_fft,
                hop_length=config.audio.hop_length,
                n_mels=config.audio.n_bins,
                f_min=config.audio.fmin,
                f_max=config.audio.fmax,
            )(audio)

            if config.audio.log:
                mel = torch.log(mel + config.audio.log_eps)

            features.append(mel)

        # stack features to (C, F, T) (channels, mel bins, time)
        features = torch.stack(features, dim=0)

        # transpose to (T, F, C)
        features = features.permute(2, 1, 0)

        return features
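

# Usage sketch: a minimal, hypothetical example of applying the transform.
# It assumes the config fields referenced above are populated and that
# "song.ogg" exists locally; torchaudio.load returns a (channels, samples)
# tensor plus its sample rate, which forward() resamples and mixes down.
if __name__ == "__main__":
    transform = OnsetTransform()
    audio, sr = torchaudio.load("song.ogg")  # hypothetical input file
    features = transform(audio, sr)
    print(features.shape)  # (T, F, C): frames x config.audio.n_bins x len(config.audio.n_ffts)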