-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathload_dataset.py
83 lines (65 loc) · 2.69 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# The following was adapted from https://github.com/macaodha/inat_comp_2018/blob/master/inat2018_loader.py
# Repurposed for the iNat19 dataset, with modified transformations.
import torch.utils.data as data
from PIL import Image
import os
import json
from torchvision import transforms
import random
import numpy as np
from autoaugment import ImageNetPolicy
def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The source image is opened in a ``with`` block so its file handle is
    released as soon as the RGB copy has been produced, instead of being
    left for the garbage collector (which matters for multi-frame files
    and when many images are loaded by worker processes).

    Args:
        path: Filesystem path of the image to read.

    Returns:
        A new ``PIL.Image.Image`` in mode ``'RGB'``.
    """
    with Image.open(path) as img:
        # convert() loads the pixel data and returns a separate image, so
        # the result stays valid after the source is closed.
        return img.convert('RGB')
class INAT(data.Dataset):
    """Image dataset driven by a JSON annotation file.

    The annotation file must contain an ``'images'`` list whose entries
    carry ``'file_name'`` and ``'id'``. If it also contains an
    ``'annotations'`` list, each entry's ``'category_id'`` is used as the
    class label; otherwise every image is given label ``0``.

    Each item is a tuple ``(img, im_id, species_id)`` where ``img`` is the
    transformed and normalised image, ``im_id`` the image id from the
    annotation file and ``species_id`` its class label.
    """

    def __init__(self, root, ann_file, is_train=True, size=299):
        """Load the annotation file and build the transform pipeline.

        Args:
            root: Directory that the ``file_name`` entries are relative to.
            ann_file: Path to the JSON annotation file.
            is_train: If true, ``__getitem__`` applies the training
                augmentations; otherwise only a center crop.
            size: Side length (pixels) of the square crop.
        """
        # load annotations
        print('Loading annotations from: ' + os.path.basename(ann_file))
        with open(ann_file) as data_file:
            ann_data = json.load(data_file)

        # set up the filenames and annotations
        images = ann_data['images']
        self.imgs = [entry['file_name'] for entry in images]
        self.ids = [entry['id'] for entry in images]

        # if we dont have class labels set them to '0'
        # NOTE(review): labels are taken in file order, so this assumes the
        # 'annotations' list is aligned one-to-one with 'images' - confirm.
        if 'annotations' in ann_data:
            self.classes = [entry['category_id']
                            for entry in ann_data['annotations']]
        else:
            self.classes = [0] * len(self.imgs)

        # print out some stats
        print('\t' + str(len(self.imgs)) + ' images')
        print('\t' + str(len(set(self.classes))) + ' classes')

        self.root = root
        self.is_train = is_train
        self.loader = default_loader

        # augmentation params
        self.im_size = [size, size]
        self.mu_data = [0.485, 0.456, 0.406]   # per-channel mean for Normalize
        self.std_data = [0.229, 0.224, 0.225]  # per-channel std for Normalize
        self.brightness = 0.4
        self.contrast = 0.4
        self.saturation = 0.4
        self.hue = 0.25

        # augmentations
        self.center_crop = transforms.CenterCrop(
            (self.im_size[0], self.im_size[1]))
        self.scale_aug = transforms.RandomResizedCrop(size=self.im_size[0])
        self.flip_aug = transforms.RandomHorizontalFlip()
        # color_aug is constructed but not applied in __getitem__.
        self.color_aug = transforms.ColorJitter(
            self.brightness, self.contrast, self.saturation, self.hue)
        self.tensor_aug = transforms.ToTensor()
        self.norm_aug = transforms.Normalize(
            mean=self.mu_data, std=self.std_data)
        self.autoaugment = ImageNetPolicy()

    def __getitem__(self, index):
        """Return ``(img, im_id, species_id)`` for the sample at ``index``.

        Training mode applies random-resized-crop, horizontal flip and the
        auto-augment policy; evaluation mode applies a center crop only.
        Both modes finish with tensor conversion and normalisation.
        """
        # Join with the platform separator so ``root`` works with or
        # without a trailing slash (plain concatenation needed the slash).
        path = os.path.join(self.root, self.imgs[index])
        im_id = self.ids[index]
        img = self.loader(path)
        species_id = self.classes[index]

        if self.is_train:
            img = self.scale_aug(img)
            img = self.flip_aug(img)
            # The auto-augment policy is used here; the ColorJitter
            # transform (self.color_aug) is intentionally not applied.
            img = self.autoaugment(img)
        else:
            img = self.center_crop(img)

        img = self.tensor_aug(img)
        img = self.norm_aug(img)

        return img, im_id, species_id

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.imgs)