datasets.py
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import json
import os
from glob import glob
from utils import load_img, modcrop

DEG_MAP = {
    "noise": 0,
    "blur":  1,
    "rain":  2,
    "haze":  3,
    "lol":   4,
    "sr":    5,
    "en":    6,
}

DEG2TASK = {
    "noise": "denoising",
    "blur":  "deblurring",
    "rain":  "deraining",
    "haze":  "dehazing",
    "lol":   "lol",
    "sr":    "sr",
    "en":    "enhancement",
}
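
# Illustrative note (not in the original file): the two dicts share the same keys,
# so a degradation keyword maps both to an integer class id and to a task name, e.g.
#   DEG_MAP["rain"]  -> 2
#   DEG2TASK["rain"] -> "deraining"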

def augment_prompt(prompt):
    ### special prompts
    lol_prompts = ["fix the illumination", "increase the exposure of the photo", "the image is too dark to see anything, correct the photo", "poor illumination, improve the shot", "brighten dark regions", "make it HDR", "improve the light of the image", "Can you make the image brighter?"]
    sr_prompts = ["I need to enhance the size and quality of this image.", "My photo is lacking size and clarity; can you improve it?", "I'd appreciate it if you could upscale this photo.", "My picture is too little, enlarge it.", "upsample this image", "increase the resolution of this photo", "increase the number of pixels", "upsample this photo", "Add details to this image", "improve the quality of this photo"]
    en_prompts = ["make my image look like DSLR", "improve the colors of my image", "improve the contrast of this photo", "apply tonemapping", "enhance the colors of the image", "retouch the photo like a photographer"]

    init = np.random.choice(["Remove the", "Reduce the", "Clean the", "Fix the", "Remove", "Improve the", "Correct the"])
    end = np.random.choice(["please", "fast", "now", "in the photo", "in the picture", "in the image", ""])
    newp = f"{init} {prompt} {end}"

    if "lol" in prompt:
        newp = np.random.choice(lol_prompts)
    elif "sr" in prompt:
        newp = np.random.choice(sr_prompts)
    elif "en" in prompt:
        newp = np.random.choice(en_prompts)

    # Collapse double spaces (e.g. when `end` is empty) and strip stray newlines.
    newp = newp.strip().replace("  ", " ").replace("\n", "")
    return newp
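
# Example (illustrative, not from the original file): augment_prompt() turns a
# degradation keyword into a natural-language instruction. The function samples
# randomly, so the outputs below are just two possible draws:
#   augment_prompt("rain") -> e.g. "Remove the rain in the photo"
#   augment_prompt("lol")  -> e.g. "brighten dark regions"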

def get_deg_name(path):
    """
    Get the degradation name from the path.
    """
    if ("gopro" in path) or ("GoPro" in path) or ("blur" in path) or ("Blur" in path) or ("RealBlur" in path):
        return "blur"
    elif ("SOTS" in path) or ("haze" in path) or ("sots" in path) or ("RESIDE" in path):
        return "haze"
    elif "LOL" in path:
        return "lol"
    elif "fiveK" in path:
        return "en"
    elif ("super" in path) or ("classicalSR" in path):
        return "sr"
    elif ("Rain100" in path) or ("rain13k" in path) or ("Rain13k" in path):
        return "rain"
    else:
        return "noise"

def crop_img(image, base=16):
    """
    Mod-crop the image so both spatial dimensions are divisible by `base`.
    Also done by SwinIR, Restormer and others.
    """
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
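
# Example (illustrative shapes, assuming an H x W x C numpy array as returned by
# load_img): with the default base=16, a (321, 189, 3) image is center-cropped to
# (320, 176, 3), the largest multiples of 16 that fit in each dimension.
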
################# DATASETS

class RefDegImage(Dataset):
    """
    Dataset for image restoration with paired low-quality (LQ) and reference high-quality (HQ) images.
    Tasks: synthetic denoising, deblurring, super-resolution, etc.
    """

    def __init__(self, hq_img_paths, lq_img_paths, augmentations=None, val=False, name="test", deg_name="noise", deg_class=0):
        assert len(hq_img_paths) == len(lq_img_paths)

        self.hq_paths = hq_img_paths
        self.lq_paths = lq_img_paths
        self.totensor = torchvision.transforms.ToTensor()
        self.val = val
        self.augs = augmentations
        self.name = name
        self.degradation = deg_name
        self.deg_class = deg_class

        if self.val:
            self.augs = None  # No augmentations during validation/test

    def __len__(self):
        return len(self.hq_paths)

    def __getitem__(self, idx):
        hq_path = self.hq_paths[idx]
        lq_path = self.lq_paths[idx]

        hq_image = load_img(hq_path)
        lq_image = load_img(lq_path)

        if self.val:
            # Crop so both spatial dimensions are divisible by 16,
            # e.g. a [321, 189] image becomes [320, 176].
            hq_image = crop_img(hq_image)
            lq_image = crop_img(lq_image)

        hq_image = self.totensor(hq_image.astype(np.float32))
        lq_image = self.totensor(lq_image.astype(np.float32))

        return hq_image, lq_image, hq_path
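
# Example (illustrative, with hypothetical paths): wrapping a RefDegImage test split
# in a DataLoader. batch_size=1 avoids collating validation images of different sizes.
#
#   hq = sorted(glob("data/Rain100L/target/*"))   # hypothetical directory layout
#   lq = sorted(glob("data/Rain100L/input/*"))
#   dt = RefDegImage(hq, lq, val=True, name="Rain100L", deg_name="rain", deg_class=DEG_MAP["rain"])
#   loader = DataLoader(dt, batch_size=1, shuffle=False, num_workers=2)
#   hq_img, lq_img, path = next(iter(loader))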

def create_testsets(testsets, debug=False):
    """
    Given a list of testsets, create a PyTorch dataset for each.
    Each entry must provide the paths to the reference (HQ) and degraded (LQ) images.
    """
    assert len(testsets) > 0

    if debug:
        print(20 * '****')
        print("Creating Testsets", len(testsets))

    datasets = []
    for testdt in testsets:
        path_hq, path_lq = testdt[0], testdt[1]
        if debug:
            print(path_hq, path_lq)

        if ("denoising" in path_hq) or ("jpeg" in path_hq):
            dataset_name = path_hq.split("/")[-1]
            dataset_sigma = path_lq.split("/")[-1].split("_")[-1].split(".")[0]
            dataset_name = dataset_name + f"_{dataset_sigma}"
        elif "Rain" in path_hq:
            if "Rain100L" in path_hq:
                dataset_name = "Rain100L"
            else:
                dataset_name = path_hq.split("/")[3]
        elif ("gopro" in path_hq) or ("GoPro" in path_hq):
            dataset_name = "GoPro"
        elif "LOL" in path_hq:
            dataset_name = "LOL"
        elif "SOTS" in path_hq:
            dataset_name = "SOTS"
        elif "fiveK" in path_hq:
            dataset_name = "MIT5K"
        else:
            assert False, f"{path_hq} - unknown dataset"

        hq_img_paths = sorted(glob(os.path.join(path_hq, "*")))
        lq_img_paths = sorted(glob(os.path.join(path_lq, "*")))
if "SOTS" in path_hq:
# Haze removal SOTS test dataset
dataset_name = "SOTS"
hq_img_paths = sorted(glob(os.path.join(path_hq, "*.jpg")))
assert len(hq_img_paths) == 500
lq_img_paths = [file.replace("GT", "IN") for file in hq_img_paths]
if "fiveK" in path_hq:
dataset_name = "MIT5K"
testf = "test-data/mit5k/test.txt"
f = open(testf, "r")
test_ids = f.readlines()
test_ids = [x.strip() for x in test_ids]
f.close()
hq_img_paths = [os.path.join(path_hq, f"{x}.jpg") for x in test_ids]
lq_img_paths = [x.replace("expertC", "input") for x in hq_img_paths]
assert len(hq_img_paths) == 498
if "gopro" in path_hq:
assert len(hq_img_paths) == 1111
if "LOL" in path_hq:
assert len(hq_img_paths) == 15
assert len(hq_img_paths) == len(lq_img_paths)
deg_name = get_deg_name(path_hq)
deg_class = DEG_MAP[deg_name]
valdts = RefDegImage(hq_img_paths = hq_img_paths,
lq_img_paths = lq_img_paths,
val = True, name= dataset_name, deg_name=deg_name, deg_class=deg_class)
datasets.append(valdts)
assert len(datasets) == len(testsets)
print (20*'****')
return datasets
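

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). The directory names below
    # are hypothetical placeholders and must match your local data layout; each testset
    # is an (HQ path, LQ path) pair.
    example_testsets = [
        ("data/Rain100L/target", "data/Rain100L/input"),
    ]
    for dt in create_testsets(example_testsets, debug=True):
        print(dt.name, dt.degradation, len(dt))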