Update trainer_hcl.py
jxhuang0508 authored Dec 23, 2021
1 parent dd0cd63 commit c987026
Showing 1 changed file with 1 addition and 30 deletions.
hcl_target/trainer_hcl.py (31 changes: 1 addition & 30 deletions)
@@ -94,8 +94,7 @@ def __init__(self, args):
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
            print('%s is loaded from pre-trained weight.\n'%i_parts[0:])
        self.G.load_state_dict(new_params)

-       # hcl
        self.G_ma = get_deeplab_v2(num_classes=self.num_classes, multi_level=True)
        self.G_ma.load_state_dict(self.G.state_dict().copy())
        self.G_source = get_deeplab_v2(num_classes=self.num_classes, multi_level=True)
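
Aside: G_ma and G_source above follow the mean-teacher pattern. G_ma is a momentum (EMA) copy of the segmentation network G, and G_source holds a frozen historical snapshot; both start from G's weights, as the load_state_dict calls show. A minimal sketch of that kind of initialization, assuming a generic torch.nn.Module in place of get_deeplab_v2 (the requires_grad freeze is an assumption; the diff only shows the weight copy):

    import copy
    import torch.nn as nn

    def make_frozen_copy(model: nn.Module) -> nn.Module:
        # Same architecture and weights as `model`, but detached from the optimizer:
        # the copy is only updated manually (EMA step or snapshot), never by backprop.
        twin = copy.deepcopy(model)
        for p in twin.parameters():
            p.requires_grad = False
        return twin
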
@@ -109,7 +108,6 @@ def __init__(self, args):
        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)
-           # hcl
            self.G_ma = apex.parallel.convert_syncbn_model(self.G_ma)
            self.G_source = apex.parallel.convert_syncbn_model(self.G_source)

@@ -128,7 +126,6 @@ def __init__(self, args):
        self.log_sm = torch.nn.LogSoftmax(dim = 1)
        self.sm = torch.nn.Softmax(dim = 1)
        self.G = self.G.cuda()
-       # hcl
        self.G_ma = self.G_ma.cuda()
        self.G_source = self.G_source.cuda()

@@ -147,7 +144,6 @@ def __init__(self, args):
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
-           # hcl
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

@@ -175,32 +171,18 @@ def update_label(self, labels, prediction):
        criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none')
        #criterion = self.seg_loss
        loss = criterion(prediction, labels)
-       print('original loss: %f'% self.seg_loss(prediction, labels) )
-       #mm = torch.median(loss)
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
-       #print(m.data.cpu(), mm)
        labels[loss < mm] = 255
        return labels

    def update_variance(self, labels, pred1, pred2):
        criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none')
        kl_distance = nn.KLDivLoss( reduction = 'none')
        loss = criterion(pred1, labels)

-       #n, h, w = labels.shape
-       #labels_onehot = torch.zeros(n, self.num_classes, h, w)
-       #labels_onehot = labels_onehot.cuda()
-       #labels_onehot.scatter_(1, labels.view(n,1,h,w), 1)
        variance = torch.sum(kl_distance(self.log_sm(pred1),self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)
-       #variance = torch.log( 1 + (torch.mean((pred1-pred2)**2, dim=1)))
-       #torch.mean( kl_distance(self.log_sm(pred1),pred2), dim=1) + 1e-6
-       print(variance.shape)
-       print('variance mean: %.4f'%torch.mean(exp_variance[:]))
-       print('variance min: %.4f'%torch.min(exp_variance[:]))
-       print('variance max: %.4f'%torch.max(exp_variance[:]))
-       #loss = torch.mean(loss/variance) + torch.mean(variance)
        loss = torch.mean(loss*exp_variance) + torch.mean(variance)
        return loss
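
For reference, the weighting scheme in update_variance can be exercised in isolation: the per-pixel KL divergence between the two classifier heads serves as an uncertainty estimate, and exp(-variance) turns it into a confidence weight in (0, 1], so confidently predicted pixels dominate the segmentation loss while torch.mean(variance) keeps the two heads from drifting apart. A self-contained sketch with dummy tensors (the 19-class, 8x8 shapes are illustrative assumptions):

    import torch
    import torch.nn as nn

    log_sm = nn.LogSoftmax(dim=1)
    sm = nn.Softmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')
    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='none')

    pred1 = torch.randn(2, 19, 8, 8)           # logits from head 1: (N, C, H, W)
    pred2 = torch.randn(2, 19, 8, 8)           # logits from head 2
    labels = torch.randint(0, 19, (2, 8, 8))   # per-pixel class indices

    variance = torch.sum(kl_distance(log_sm(pred1), sm(pred2)), dim=1)  # (N, H, W)
    exp_variance = torch.exp(-variance)        # ~1 where the heads agree, -> 0 where they disagree
    loss = torch.mean(criterion(pred1, labels) * exp_variance) + torch.mean(variance)
    print(loss.item())
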

@@ -211,15 +193,6 @@ def update_consistency_his(self, labels, pred1, pred2, pred1_ma, pred2_ma, pred1

        variance = torch.sum(kl_distance(self.log_sm(pred1_ma),self.sm(pred1_source)), dim=1)
        exp_variance = torch.exp(-variance)
-       #variance = torch.log( 1 + (torch.mean((pred1-pred2)**2, dim=1)))
-       #torch.mean( kl_distance(self.log_sm(pred1),pred2), dim=1) + 1e-6
-       print(variance.shape)
-       print('variance mean: %.4f'%torch.mean(exp_variance[:]))
-       print('variance min: %.4f'%torch.min(exp_variance[:]))
-       print('variance max: %.4f'%torch.max(exp_variance[:]))
-       #loss = torch.mean(loss/variance) + torch.mean(variance)
-       # loss = torch.mean(loss*exp_variance) + torch.mean(variance)
-       # hcl
-       loss = torch.mean(loss*exp_variance) + torch.mean(variance)
+       loss = torch.mean(loss*exp_variance)
        return loss

@@ -252,7 +225,6 @@ def update_contrast_his(self, label_label_aug1, feature, labels, feature_ma, pre
        return loss

    def gen_update(self, images, images_t, labels, labels_t, i_iter, image_aug1, label_label_aug1):
-       # hcl update ma
        for param_q, param_k in zip(self.G.parameters(), self.G_ma.parameters()):
            param_k.data = param_k.data.clone() * 0.996 + param_q.data.clone() * (1. - 0.996)
        for buffer_q, buffer_k in zip(self.G.buffers(), self.G_ma.buffers()):
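
The first loop above is a standard exponential moving average with momentum 0.996: after every generator step, G_ma drifts slowly toward the current G. The same update written as a standalone helper (a sketch, not the repository's API; the direct buffer copy mirrors what the truncated second loop most plausibly does and is an assumption):

    import torch

    @torch.no_grad()
    def ema_update(student: torch.nn.Module, teacher: torch.nn.Module, m: float = 0.996):
        # teacher <- m * teacher + (1 - m) * student, parameter by parameter
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.data.mul_(m).add_(p_s.data, alpha=1.0 - m)
        # buffers (e.g. BatchNorm running statistics) are not trained, so copy them directly
        for b_s, b_t in zip(student.buffers(), teacher.buffers()):
            b_t.data.copy_(b_s.data)
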
@@ -304,7 +276,6 @@ def gen_update(self, images, images_t, labels, labels_t, i_iter, image_aug1, lab

        loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2

-       # hcl
        if i_iter % 300 == 0 and i_iter != 0:
            print('record historical model ...')
            for param_q, param_k, param_source in zip(self.G.parameters(), self.G_ma.parameters(), self.G_source.parameters()):
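
The guarded block at the end is where the historical model gets refreshed: every 300 iterations the snapshot network is overwritten with current weights. The loop body is cut off in this view, so the copy direction below is an assumption; a sketch of the idea:

    import torch

    @torch.no_grad()
    def record_historical(G_ma: torch.nn.Module, G_source: torch.nn.Module,
                          i_iter: int, period: int = 300):
        # Every `period` iterations, freeze the momentum model's current
        # state into the historical snapshot G_source.
        if i_iter % period == 0 and i_iter != 0:
            print('record historical model ...')
            G_source.load_state_dict(G_ma.state_dict())
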
