-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathETRI_Dataset.py
79 lines (63 loc) · 2.32 KB
/
ETRI_Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import gc
import shutil
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
import torchaudio
from torch.utils.data import Dataset
from typing import Callable
from knusl import KnuSL
import math
import numpy as np
class ETRI_Corpus_Dataset(Dataset):
    """Map-style dataset over the ETRI backchannel annotation table.

    Each sample is a dict with:

    * ``audio``     -- one channel of ``<index>.wav``, resampled to ``self.sr``
      and cropped / zero-padded to exactly ``int(length * sr)`` samples.
    * ``label``     -- the raw value of the ``BC`` column.
    * ``BClabel``   -- ``0`` when ``label == 0``, otherwise ``1``.
    * ``text``      -- tokenizer ``input_ids`` of the transcript
      (padded / truncated to ``TEXT_MAX_LENGTH`` tokens).
    * ``sentiment`` -- normalised 5-bin histogram of per-word ``KnuSL`` scores.
    """

    # NOTE: the data locations are hard-coded; the ``path`` constructor
    # argument is accepted for interface compatibility but is not used.
    AUDIO_DIR = "/local_datasets/BC/etri_last/"
    ANNOTATION_PATH = '/data/minjae/BC/etri_last.tsv'
    TEXT_MAX_LENGTH = 10

    def __init__(self, path, tokenizer, transform: Callable = None, length: float = 1.5) -> None:
        """Load the annotation table and remember the sample configuration.

        Args:
            path: Unused (see the class-level note on hard-coded locations).
            tokenizer: Callable invoked as
                ``tokenizer(text, padding=..., max_length=..., truncation=...,
                return_tensors="pt")`` whose result is indexed with
                ``'input_ids'``.
            transform: Unused; kept so existing call sites keep working.
            length: Duration, in seconds, of the audio window per sample.
        """
        super().__init__()
        print("Load ETRI_Corpus_Dataset...")
        self.tokenizer = tokenizer
        self.path = self.AUDIO_DIR
        self.length = length
        self.annotation = pd.read_csv(self.ANNOTATION_PATH, delimiter='\t', encoding='utf-8')
        self.sr = 16000

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotation)

    def _fit_to_window(self, audio):
        """Return ``audio`` with exactly ``int(self.length * self.sr)`` samples.

        The last samples of the waveform are kept; shorter waveforms are
        right-padded with zeros. ``audio`` is indexed as (channels, samples).
        """
        window = int(self.length * self.sr)
        # Explicit start index instead of ``-window:`` so that a zero-sized
        # window yields an empty slice rather than the whole waveform.
        first = max(audio.size(1) - window, 0)
        audio = audio[:, first:]
        missing = window - audio.size(1)
        if missing > 0:
            audio = F.pad(audio, (0, missing), "constant", 0)
        return audio

    @staticmethod
    def _sentiment_histogram(transcript):
        """Return a normalised 5-bin histogram of per-word sentiment scores.

        Words for which ``KnuSL.data_list`` reports the string ``'None'`` are
        counted in bin 0. Other scores are used directly as the bin index, so
        negative scores wrap around to the trailing bins.
        NOTE(review): presumably scores lie in [-2, 2] -- confirm against KnuSL.
        """
        sentiment = torch.zeros(5)
        for word in transcript.split(' '):
            _root, score = KnuSL.data_list(word)
            if score != 'None':
                sentiment[int(score)] += 1
            else:
                sentiment[0] += 1
        # ``str.split(' ')`` always yields at least one item, so the sum is >= 1.
        return sentiment / sentiment.sum()

    def __getitem__(self, index):
        """Build the sample dict for annotation row ``index``.

        The audio is read from ``<self.path>/<index>.wav``; channel 1 is used
        when the row's ``role`` equals 1, channel 0 otherwise.
        """
        item = self.annotation.iloc[index]
        transcript = item['transcript']
        label = item['BC']
        # Convert to a plain int: the comparison yields a NumPy bool, which is
        # not a reliable slice bound.
        channel = int(item['role'] == 1)

        wav_path = os.path.join(self.path, f"{index}.wav")
        audio, sr = torchaudio.load(wav_path)
        if audio.size(1) > 0:
            # Empty recordings are skipped here and end up as pure zero padding.
            audio = torchaudio.transforms.Resample(sr, self.sr)(audio)
        audio = self._fit_to_window(audio[channel:channel + 1])

        sentiment = self._sentiment_histogram(transcript)
        tokens = self.tokenizer(
            transcript,
            padding='max_length',
            max_length=self.TEXT_MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )['input_ids'].squeeze()

        return {
            'audio': audio,
            'label': label,
            'BClabel': 0 if label == 0 else 1,
            'text': tokens,
            'sentiment': sentiment,
        }