evaluate.py
import argparse
import glob
import os
import time

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from network.Math_Module import P, Q
from network.decom import Decom
from utils import *
"""
As different illumination adjustment ratio will cause
different enhanced results. Certainly you can tune the ratio youself
to get the best results.
To get better result, we use the illumination of normal light image
to adaptively generate ratio.
Noted that KinD and KinD++ also use ratio to guide the illumination adjustment,
for fair comparison, the ratio of their methods also generate by the illumination
of normal light image.
"""


def one2three(x):
    # Repeat a single-channel map along the channel dimension to get 3 channels.
    return torch.cat([x, x, x], dim=1).to(x)


class Inference(nn.Module):
    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # load the decomposition models for low- and normal-light images
        self.model_Decom_low = Decom()
        self.model_Decom_high = Decom()
        self.model_Decom_low = load_initialize(self.model_Decom_low, self.opts.Decom_model_low_path)
        self.model_Decom_high = load_initialize(self.model_Decom_high, self.opts.Decom_model_high_path)
        # load the unfolding options and the R and L models
        self.unfolding_opts, self.model_R, self.model_L = load_unfolding(self.opts.unfolding_model_path)
        # load the illumination adjustment model
        self.adjust_model = load_adjustment(self.opts.adjust_model_path)
        self.P = P()
        self.Q = Q()
        self.transform = transforms.Compose([transforms.ToTensor()])
        print(self.model_Decom_low)
        print(self.model_R)
        print(self.model_L)
        print(self.adjust_model)
    def get_ratio(self, high_l, low_l):
        # Mean ratio of low-light to normal-light illumination; its inverse is
        # the brightening factor applied to the low-light illumination.
        ratio = (low_l / (high_l + 0.0001)).mean()
        low_ratio = torch.ones_like(high_l) * (1 / (ratio + 0.0001))
        return low_ratio
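    # Illustrative example (hypothetical numbers, not from the paper): if the
    # low-light illumination averages 0.1 and the normal-light illumination
    # averages 0.5, then mean(low / high) ≈ 0.2 and the adjustment ratio is
    # ≈ 5, i.e. the illumination is brightened roughly 5x.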
    def unfolding(self, input_low_img):
        for t in range(self.unfolding_opts.round):
            if t == 0:  # initialize R0 and L0 from the low-light decomposition
                P, Q = self.model_Decom_low(input_low_img)
            else:  # update P and Q with round-dependent weights
                w_p = self.unfolding_opts.gamma + self.unfolding_opts.Roffset * t
                w_q = self.unfolding_opts.lamda + self.unfolding_opts.Loffset * t
                P = self.P(I=input_low_img, Q=Q, R=R, gamma=w_p)
                Q = self.Q(I=input_low_img, P=P, L=L, lamda=w_q)
            R = self.model_R(r=P, l=Q)
            L = self.model_L(l=Q)
        return R, L
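    # The loop above unfolds an alternating optimization of the Retinex model
    # I ≈ R * L: P and Q are intermediate estimates produced by the closed-form
    # modules from network.Math_Module, while model_R and model_L act as learned
    # priors that refine the reflectance and illumination in each round. (This
    # reading is inferred from the module names; the exact objective lives in
    # network/Math_Module.)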
    def illumination_adjust(self, L, ratio):
        ratio = torch.ones_like(L) * ratio
        return self.adjust_model(l=L, alpha=ratio)
    def forward(self, input_low_img, input_high_img):
        if torch.cuda.is_available():
            input_low_img = input_low_img.cuda()
            input_high_img = input_high_img.cuda()
        with torch.no_grad():
            start = time.time()
            R, L = self.unfolding(input_low_img)
            # the ratio is calculated from the decomposed normal-light illumination
            _, high_L = self.model_Decom_high(input_high_img)
            ratio = self.get_ratio(high_L, L)
            High_L = self.illumination_adjust(L, ratio)
            I_enhance = High_L * R
            p_time = time.time() - start
        return I_enhance, p_time
    def evaluate(self):
        low_files = glob.glob(os.path.join(self.opts.low_dir, "*.png"))
        os.makedirs(self.opts.output, exist_ok=True)
        for file in low_files:
            file_name = os.path.basename(file)
            name = os.path.splitext(file_name)[0]
            high_file = os.path.join(self.opts.high_dir, file_name)
            low_img = self.transform(Image.open(file)).unsqueeze(0)
            high_img = self.transform(Image.open(high_file)).unsqueeze(0)
            enhance, p_time = self.forward(low_img, high_img)
            save_path = os.path.join(self.opts.output, "%s_URetinexNet.png" % name)
            np_save_TensorImg(enhance, save_path)
            print("================ time for %s: %f ================" % (file_name, p_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Configure')
    # specify your data paths here!
    parser.add_argument('--low_dir', type=str, default="./test_data/LOLdataset/eval15/low")
    parser.add_argument('--high_dir', type=str, default="./test_data/LOLdataset/eval15/high")
    parser.add_argument('--output', type=str, default="./demo/output/LOL")
    # ratios of 3-5 are recommended; a larger ratio can lead to over-exposure
    # model paths
    parser.add_argument('--Decom_model_low_path', type=str, default="./ckpt/init_low.pth")
    parser.add_argument('--Decom_model_high_path', type=str, default="./ckpt/init_high.pth")
    parser.add_argument('--unfolding_model_path', type=str, default="./ckpt/unfolding.pth")
    parser.add_argument('--adjust_model_path', type=str, default="./ckpt/L_adjust.pth")
    parser.add_argument('--gpu_id', type=int, default=0)
    opts = parser.parse_args()
    for k, v in vars(opts).items():
        print(k, v)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
    model = Inference(opts)
    if torch.cuda.is_available():
        model = model.cuda()
    model.evaluate()
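# Example invocation (assuming the LOL eval15 layout and the default checkpoint
# paths above):
#   python evaluate.py --low_dir ./test_data/LOLdataset/eval15/low \
#                      --high_dir ./test_data/LOLdataset/eval15/high \
#                      --output ./demo/output/LOL --gpu_id 0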