forked from JCruan519/VM-UNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine_synapse.py
105 lines (86 loc) · 3.6 KB
/
engine_synapse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast as autocast
import torch
from sklearn.metrics import confusion_matrix
from scipy.ndimage.morphology import binary_fill_holes, binary_opening
from utils import test_single_volume
import time
def train_one_epoch(train_loader,
model,
criterion,
optimizer,
scheduler,
epoch,
logger,
config,
scaler=None):
'''
train model for one epoch
'''
stime = time.time()
model.train()
loss_list = []
for iter, data in enumerate(train_loader):
optimizer.zero_grad()
images, targets = data['image'], data['label']
images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float()
if config.amp:
with autocast():
out = model(images)
loss = criterion(out, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
out = model(images)
loss = criterion(out, targets)
loss.backward()
optimizer.step()
loss_list.append(loss.item())
now_lr = optimizer.state_dict()['param_groups'][0]['lr']
mean_loss = np.mean(loss_list)
if iter % config.print_interval == 0:
log_info = f'train: epoch {epoch}, iter:{iter}, loss: {loss.item():.4f}, lr: {now_lr}'
print(log_info)
logger.info(log_info)
scheduler.step()
etime = time.time()
log_info = f'Finish one epoch train: epoch {epoch}, loss: {mean_loss:.4f}, time(s): {etime-stime:.2f}'
print(log_info)
logger.info(log_info)
return mean_loss
def val_one_epoch(test_datasets,
test_loader,
model,
epoch,
logger,
config,
test_save_path,
val_or_test=False):
# switch to evaluate mode
stime = time.time()
model.eval()
with torch.no_grad():
metric_list = 0.0
i_batch = 0
for data in tqdm(test_loader):
img, msk, case_name = data['image'], data['label'], data['case_name'][0]
metric_i = test_single_volume(img, msk, model, classes=config.num_classes, patch_size=[config.input_size_h, config.input_size_w],
test_save_path=test_save_path, case=case_name, z_spacing=config.z_spacing, val_or_test=val_or_test)
metric_list += np.array(metric_i)
logger.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name,
np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1]))
i_batch += 1
metric_list = metric_list / len(test_datasets)
performance = np.mean(metric_list, axis=0)[0]
mean_hd95 = np.mean(metric_list, axis=0)[1]
for i in range(1, config.num_classes):
logger.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1]))
performance = np.mean(metric_list, axis=0)[0]
mean_hd95 = np.mean(metric_list, axis=0)[1]
etime = time.time()
log_info = f'val epoch: {epoch}, mean_dice: {performance}, mean_hd95: {mean_hd95}, time(s): {etime-stime:.2f}'
print(log_info)
logger.info(log_info)
return performance, mean_hd95