dataloader.py
import sys
# Make the vendored AutoAugment implementation importable.
sys.path.append('./AutoAugment/')

import torch
import torchvision
import torchvision.transforms as transforms

from AutoAugment.autoaugment import CIFAR10Policy, SVHNPolicy
def get_dataloaders(args):
    """Build train/test DataLoaders for the dataset named by args.dataset."""
    train_transform, test_transform = get_transform(args)

    if args.dataset == "c10":
        train_ds = torchvision.datasets.CIFAR10('./datasets', train=True, transform=train_transform, download=True)
        test_ds = torchvision.datasets.CIFAR10('./datasets', train=False, transform=test_transform, download=True)
        args.num_classes = 10
    elif args.dataset == "c100":
        train_ds = torchvision.datasets.CIFAR100('./datasets', train=True, transform=train_transform, download=True)
        test_ds = torchvision.datasets.CIFAR100('./datasets', train=False, transform=test_transform, download=True)
        args.num_classes = 100
    elif args.dataset == "svhn":
        train_ds = torchvision.datasets.SVHN('./datasets', split='train', transform=train_transform, download=True)
        test_ds = torchvision.datasets.SVHN('./datasets', split='test', transform=test_transform, download=True)
        args.num_classes = 10
    else:
        raise ValueError(f"No such dataset: {args.dataset}")

    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    return train_dl, test_dl

def get_transform(args):
    """Build train/test transforms; also records image size, padding, mean, and std on args."""
    if args.dataset in ["c10", "c100", "svhn"]:
        args.padding = 4
        args.size = 32
        if args.dataset == "c10":
            args.mean, args.std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
        elif args.dataset == "c100":
            args.mean, args.std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
        elif args.dataset == "svhn":
            args.mean, args.std = [0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970]
    else:
        args.padding = 28
        args.size = 224
        args.mean, args.std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    train_transform_list = [transforms.RandomCrop(size=(args.size, args.size), padding=args.padding)]
    # Horizontal flip would change digit identity, so it is only applied to non-SVHN datasets.
    if args.dataset != "svhn":
        train_transform_list.append(transforms.RandomHorizontalFlip())
    if args.autoaugment:
        if args.dataset == "c10" or args.dataset == "c100":
            train_transform_list.append(CIFAR10Policy())
        elif args.dataset == "svhn":
            train_transform_list.append(SVHNPolicy())
        else:
            print(f"No AutoAugment for {args.dataset}")

    train_transform = transforms.Compose(
        train_transform_list + [
            transforms.ToTensor(),
            transforms.Normalize(mean=args.mean, std=args.std)
        ]
    )
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=args.mean, std=args.std)
    ])

    return train_transform, test_transform
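

# Minimal usage sketch: builds the loaders from a plain argparse-style namespace.
# The field names are the ones read by get_dataloaders/get_transform above; the
# values chosen here are illustrative assumptions, not settings from this repo.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dataset="c10",        # one of "c10", "c100", "svhn"
        batch_size=128,
        eval_batch_size=256,
        num_workers=4,
        autoaugment=False,
    )
    train_dl, test_dl = get_dataloaders(args)
    images, labels = next(iter(train_dl))
    print(images.shape, labels.shape, args.num_classes)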