-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtools_golden_subject.py
48 lines (29 loc) · 1.11 KB
/
tools_golden_subject.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
from torch.utils.data import DataLoader
import random
import os
import glob
import torchvision.transforms as transforms
class cwtDataset(torch.utils.data.Dataset):
    """Dataset that serves aligned items from three sources, wrapping shorter ones.

    Item ``index`` is the dict ``{"A": ..., "B": ..., "C": ...}``, where each
    source is read at ``index % len(source)``. A source shorter than the
    dataset therefore repeats from its start.

    Args:
        root1: Sized, integer-indexable collection served under key ``"A"``
            (stored as-is).
        root2: NumPy ndarray served under key ``"B"``. It is wrapped with
            ``torch.from_numpy``, so the resulting tensor shares memory with
            the array.
        root3: Sized, integer-indexable collection served under key ``"C"``
            (stored as-is).

    Raises:
        ValueError: If some sources are empty while others are not. Such a
            dataset would report a non-zero length, yet no item could ever be
            built (the modulo would divide by zero).
    """

    def __init__(self, root1, root2, root3):
        self.files_A = root1
        self.files_B = torch.from_numpy(root2)
        self.files_C = root3
        lengths = (len(self.files_A), len(self.files_B), len(self.files_C))
        # All-empty is harmless: length 0, so it is never indexed. A partially
        # empty dataset would instead fail on first access with ZeroDivisionError.
        if any(lengths) and not all(lengths):
            raise ValueError(
                "cwtDataset sources must be all empty or all non-empty, "
                f"got lengths {lengths}"
            )

    def __getitem__(self, index):
        item_A = self.files_A[index % len(self.files_A)]
        item_B = self.files_B[index % len(self.files_B)]
        item_C = self.files_C[index % len(self.files_C)]
        return {"A": item_A, "B": item_B, "C": item_C}

    def __len__(self):
        # Span the longest source, including C (which __getitem__ also wraps).
        # Requesting indices 0..len-1 then visits every element of every
        # source at least once.
        return max(len(self.files_A), len(self.files_B), len(self.files_C))
class val_cwtDataset(torch.utils.data.Dataset):
    """Validation dataset that serves entries of one NumPy array as tensors.

    Args:
        root1: NumPy ndarray. It is wrapped with ``torch.from_numpy``, so the
            stored tensor shares memory with the array. Item ``index`` is
            ``files_A[index]``, i.e. the entry along the first dimension.
    """

    def __init__(self, root1):
        tensor_view = torch.from_numpy(root1)
        self.files_A = tensor_view

    def __len__(self):
        return len(self.files_A)

    def __getitem__(self, index):
        samples = self.files_A
        return samples[index]