-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathData_Retriever.py
87 lines (68 loc) · 2.52 KB
/
Data_Retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import math
import pandas as pd
import os
import cv2
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from albumentations import (HorizontalFlip, Normalize, Compose, Resize)
from albumentations.pytorch import ToTensor
# In[2]:
class ImageData(Dataset):
    """Dataset yielding ``(image, mask)`` pairs built from a label dataframe.

    Each row of ``df`` describes one image: the row's index label is the image
    file name and its first ``NUM_CLASSES`` columns hold one run-length string
    per class (any non-string value, e.g. NaN, means "no mask for that class").
    Images are read from ``<data_folder>/train_images/<file name>``.

    Args:
        df: Dataframe indexed by image file name (see above).
        data_folder: Directory that contains the ``train_images`` folder.
        mean: Per-channel mean forwarded to ``Normalize``.
        std: Per-channel standard deviation forwarded to ``Normalize``.
        phase: ``"train"`` enables random horizontal flipping; any other value
            applies only the deterministic transforms.
    """

    # Geometry of the full-resolution masks decoded by ``make_mask``.
    MASK_HEIGHT = 256
    MASK_WIDTH = 1600
    NUM_CLASSES = 4

    def __init__(self, df, data_folder, mean, std, phase):
        self.df = df
        self.root = data_folder
        self.mean = mean
        self.std = std
        self.phase = phase
        # Built after mean/std/phase are set because get_transforms reads them.
        self.transforms = self.get_transforms()
        self.fnames = self.df.index.tolist()

    def __len__(self):
        """Return the number of rows (images) in the dataframe."""
        return len(self.fnames)

    def __getitem__(self, idx):
        """Load, augment and return the ``(image, mask)`` pair for row ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded by OpenCV.
        """
        image_id, mask = self.make_mask(idx)
        image_path = os.path.join(self.root, "train_images", image_id)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than deep inside the transforms.
            raise FileNotFoundError(
                "Could not read image: {}".format(image_path)
            )
        # NOTE(review): cv2.imread returns channels in BGR order - confirm
        # that `mean`/`std` are supplied in the same channel order.
        augmented = self.transforms(image=img, mask=mask)
        img = augmented['image']
        # The indexing below expects the transformed mask to carry a leading
        # singleton axis followed by H x W x C. Spatial size is whatever
        # get_transforms resizes to (224 x 224), not the raw 256 x 1600.
        mask = augmented['mask']
        mask = mask[0].permute(2, 0, 1)  # -> C x H x W
        return img, mask

    def make_mask(self, row_id):
        """Decode the run-length labels of one dataframe row.

        Args:
            row_id: Positional index of the row in ``self.df``.

        Returns:
            A tuple ``(fname, masks)`` where ``fname`` is the row's index label
            and ``masks`` is a float32 array of shape
            ``(MASK_HEIGHT, MASK_WIDTH, NUM_CLASSES)`` with 1.0 on labelled
            pixels and 0.0 elsewhere.
        """
        row = self.df.iloc[row_id]
        fname = row.name
        # Explicitly positional: the first NUM_CLASSES columns are the labels.
        labels = row.iloc[:self.NUM_CLASSES]
        height, width = self.MASK_HEIGHT, self.MASK_WIDTH
        # float32 matters: the masks are later turned into float tensors.
        masks = np.zeros((height, width, self.NUM_CLASSES), dtype=np.float32)

        for channel, label in enumerate(labels.values):
            if not isinstance(label, str):
                # Missing labels (e.g. NaN) leave the channel all zeros.
                continue
            # Tokens alternate "start length start length ...".
            tokens = label.split()
            starts = map(int, tokens[0::2])
            lengths = map(int, tokens[1::2])
            flat = np.zeros(height * width, dtype=np.uint8)
            for start, run in zip(starts, lengths):
                # NOTE(review): starts are used as 0-based offsets here. If the
                # label source numbers pixels from 1, this is shifted by one
                # pixel and should be flat[start - 1:start - 1 + run] - confirm.
                flat[start:start + run] = 1
            # Runs are laid out column-major (top-to-bottom, then left-to-right).
            masks[:, :, channel] = flat.reshape(height, width, order='F')
        return fname, masks

    def get_transforms(self):
        """Build the augmentation pipeline for the configured phase.

        Training adds a random horizontal flip; every phase then normalizes,
        resizes to 224 x 224 and converts to tensors.
        """
        list_transforms = []
        if self.phase == "train":
            # Only horizontal flip as of now.
            list_transforms.append(HorizontalFlip(p=0.5))
        list_transforms.extend(
            [
                Normalize(mean=self.mean, std=self.std, p=1),
                Resize(224, 224),
                ToTensor(),
            ]
        )
        return Compose(list_transforms)
# In[ ]: