# train.py -- forked from YanchaoYang/FDA

import torch.nn.functional as F
import numpy as np
from options.train_options import TrainOptions
from utils.timer import Timer
import os
from data import CreateSrcDataLoader
from data import CreateTrgDataLoader
from model import CreateModel
#import tensorboardX
import torch.backends.cudnn as cudnn
import torch
from torch.autograd import Variable
from utils import FDA_source_to_target
import scipy.io as sio
import wandb

IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
IMG_MEAN = torch.reshape(torch.from_numpy(IMG_MEAN), (1, 3, 1, 1))

# Class weights for the 19 segmentation classes, all set to 1.
CS_weights = np.array((1.0,) * 19, dtype=np.float32)
CS_weights = torch.from_numpy(CS_weights)

wandb.login()
wandb.init(project="FDA-test")

#------------------ This file is used for the first training round (T=0); pseudo-label generation is not required. ------------------#
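
# For reference only: a minimal sketch of what `utils.FDA_source_to_target` is
# assumed to do, following the FDA idea of swapping the low-frequency amplitude
# spectrum of the source image with that of the target while keeping the source
# phase. This sketch is NOT called by the training loop (which uses the real
# implementation imported from `utils`); it assumes torch.fft is available and
# uses a centered low-frequency window of relative size L, which may differ in
# detail from the actual implementation.
def _fda_source_to_target_sketch(src_img, trg_img, L=0.01):
    """Illustrative re-implementation; not used by main()."""
    fft_src = torch.fft.fft2(src_img, dim=(-2, -1))
    fft_trg = torch.fft.fft2(trg_img, dim=(-2, -1))
    amp_src, pha_src = torch.abs(fft_src), torch.angle(fft_src)
    amp_trg = torch.abs(fft_trg)

    # Centered low-frequency window with half-size b = floor(min(H, W) * L).
    _, _, h, w = src_img.shape
    b = int(np.floor(min(h, w) * L))
    amp_src = torch.fft.fftshift(amp_src, dim=(-2, -1))
    amp_trg = torch.fft.fftshift(amp_trg, dim=(-2, -1))
    ch, cw = h // 2, w // 2
    amp_src[..., ch - b:ch + b, cw - b:cw + b] = amp_trg[..., ch - b:ch + b, cw - b:cw + b]
    amp_src = torch.fft.ifftshift(amp_src, dim=(-2, -1))

    # Recombine the mixed amplitude with the source phase and invert the FFT.
    fft_mixed = torch.polar(amp_src, pha_src)
    return torch.fft.ifft2(fft_mixed, dim=(-2, -1)).real
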

def main():
    opt = TrainOptions()
    args = opt.initialize()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    # Data-loader initialization
    sourceloader, targetloader = CreateSrcDataLoader(args), CreateTrgDataLoader(args)
    sourceloader_iter, targetloader_iter = iter(sourceloader), iter(targetloader)

    # Model initialization
    model, optimizer = CreateModel(args)

    start_iter = 0
    if args.restore_from is not None:
        # Used when restarting training from a checkpoint. Snapshots are saved
        # below as '<source>_<iteration>.pth', so recover the iteration count
        # from the filename.
        ckpt_name = os.path.splitext(os.path.basename(args.restore_from))[0]
        start_iter = int(ckpt_name.rsplit('_', 1)[1])

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    wandb.watch(model, log='all', log_freq=1)

    # losses to log
    loss = ['loss_seg_src', 'loss_seg_trg']
    loss_train = 0.0
    loss_val = 0.0
    loss_train_list = []
    loss_val_list = []

    mean_img = torch.zeros(1, 1)  # placeholder; filled with IMG_MEAN on the first iteration
    # All class weights are taken to be 1.
    class_weights = Variable(CS_weights).cuda()

    _t['iter time'].tic()  # timer starts

    for i in range(start_iter, args.num_steps):
        model.adjust_learning_rate(args, optimizer, i)  # adjust learning rate
        optimizer.zero_grad()

        src_img, src_lbl, _, _ = next(sourceloader_iter)  # new source batch
        trg_img, trg_lbl, _, _ = next(targetloader_iter)  # new target batch
        src_img_copy = src_img.clone()

        # Build the mean image once, matching the shape of the input batch.
        if mean_img.shape[-1] < 2:
            B, C, H, W = src_img.shape
            mean_img = IMG_MEAN.repeat(B, 1, H, W)

        #-------------------------------------------------------------------#

        # 1. source to target, target to target
        src_in_trg = FDA_source_to_target(src_img, trg_img, L=args.LB)  # src_lbl
        trg_in_trg = trg_img

        # 2. subtract mean
        src_img = src_in_trg.clone() - mean_img  # src, src_lbl
        trg_img = trg_in_trg.clone() - mean_img  # trg, trg_lbl

        #-------------------------------------------------------------------#

        # evaluate and update params
        src_img, src_lbl = Variable(src_img).cuda(), Variable(src_lbl.long()).cuda()  # to GPU
        src_seg_score = model(src_img, lbl=src_lbl, weight=class_weights, ita=args.ita)  # forward pass

        # source losses
        loss_seg_src = model.loss_seg
        loss_ent_src = model.loss_ent

        # target forward pass; only the entropy term is back-propagated
        trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(trg_lbl.long()).cuda()  # to GPU
        trg_seg_score = model(trg_img, lbl=trg_lbl, weight=class_weights, ita=args.ita)  # forward pass

        # target losses
        loss_seg_trg = model.loss_seg
        loss_ent_trg = model.loss_ent

        # Entropy minimization on the target is enabled after args.switch2entropy iterations.
        trigger_ent = 0.0
        if i > args.switch2entropy:
            trigger_ent = 1.0

        # segmentation loss on the source + weighted entropy loss on the target
        loss_all = loss_seg_src + trigger_ent * args.entW * loss_ent_trg
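
        # What is optimized here (the exact form of model.loss_ent is defined
        # inside CreateModel; per the FDA setup, args.ita presumably shapes a
        # robust entropy penalty, but that is an assumption):
        #   loss_all = CE(source segmentation)
        #              + entW * entropy(target predictions), once i > switch2entropy
        # loss_seg_trg is tracked only for logging below; loss_ent_src is unused.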

        loss_all.backward()
        optimizer.step()

        loss_train += loss_seg_src.detach().cpu().numpy()
        loss_val += loss_seg_trg.detach().cpu().numpy()

        wandb.log({"src seg loss": loss_seg_src.data,
                   "trg seg loss": loss_seg_trg.data,
                   "learning rate": optimizer.param_groups[0]['lr'] * 10000})  # lr scaled by 1e4 for readability

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       os.path.join(args.snapshot_dir, '%s_' % args.source + str(i + 1) + '.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][trg seg loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, loss_seg_trg.data,
                   optimizer.param_groups[0]['lr'] * 10000, _t['iter time'].diff))

            sio.savemat(args.tempdata, {'src_img': src_img.cpu().numpy(),
                                        'trg_img': trg_img.cpu().numpy()})

            loss_train /= args.print_freq
            loss_val /= args.print_freq
            loss_train_list.append(loss_train)
            loss_val_list.append(loss_val)
            sio.savemat(args.matname, {'loss_train': loss_train_list,
                                       'loss_val': loss_val_list})
            loss_train = 0.0
            loss_val = 0.0

        if i + 1 > args.num_steps_stop:
            print('finish training')
            break

        _t['iter time'].tic()


if __name__ == '__main__':
    main()