import os
import sys
import time
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import utils as vutils
import utils
from dataset import Dataset
from model import Resnet_Unet as model
#Finalized version (6.9)
###########
#Tunable training hyperparameters
batch_size = 16
val_batch_size = 2
lr = 1e-3
start_epoch = 0
stop_epoch = 10
###########
###########
#Tunable path parameters
title = 'ResNet_final'
path = '/mnt/diskarray/fj/ResNet-Unet/'
data_path = path+'cut_imgs/'
val_path = path+'imgs/train/'
Model_path = path+'log/checkpoints/'+title+'/4.pth'
###########
###########
#Tunable training-related settings
pretrain = True
multi_GPU = False
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
save_step = 10 #how often (in batches) to save visualization results
transform = transforms.Compose([
    #transforms.Resize((256,256)),
    transforms.ToTensor()
])
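#transforms.ToTensor converts a PIL image / uint8 array of shape (H, W, C) with values
#in [0, 255] into a float tensor of shape (C, H, W) with values in [0, 1].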
###########
###########
#Path parameters that do not need adjustment
log_path = path+'log/'
checkpoints_path = path+'log/checkpoints/'+title+'/'
tensorboard_path = path+'log/tensorboard/'+title+'/'
visualize_path = path+'log/visualize/'+title+'/'
###########
utils.path_checker(log_path)
utils.path_checker(checkpoints_path)
utils.path_checker(tensorboard_path)
utils.path_checker(visualize_path)
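#utils.path_checker is assumed to create each directory if it does not already exist,
#roughly equivalent to os.makedirs(path, exist_ok=True).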
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Writer = SummaryWriter(tensorboard_path)
train_set = Dataset(path=data_path, transform=transform, mode='train')
val_set = Dataset(path=data_path, transform=transform, mode='validation', val_path=val_path)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)
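#Dataset is assumed to yield (image, mask) pairs; nn.BCELoss below expects both the
#model output (assumed to already be sigmoid-activated) and the mask to lie in [0, 1].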
Model = model(BN_enable=True, resnet_pretrain=False).to(device)
if pretrain:
    Model.load_state_dict(torch.load(Model_path))
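#Note: if the checkpoint was saved from a different device, torch.load may need
#map_location, e.g. torch.load(Model_path, map_location=device).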
criterion = nn.BCELoss().to(device)
optimizer = torch.optim.Adam(Model.parameters(),lr=lr)
#optimizer = torch.optim.SGD(Model.parameters(),lr=lr,momentum=0.9,weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
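#StepLR with step_size=1 and gamma=0.5 halves the learning rate after every epoch.
#Recent PyTorch versions recommend calling scheduler.step() after the optimizer
#updates (i.e. at the end of each epoch) rather than at the start.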
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
###########
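#utils.analysis is assumed to return the Dice similarity coefficient (DSC), positive
#predictive value (PPV), sensitivity (Sen) and the number of samples in the batch.
#A minimal sketch of the standard definitions (illustrative only, not the project's
#own implementation):
def _segmentation_metrics_sketch(pred, target, eps=1e-7):
    #pred and target: binary tensors of the same shape
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    dsc = 2 * tp / (2 * tp + fp + fn + eps)  #Dice similarity coefficient
    ppv = tp / (tp + fp + eps)  #precision / positive predictive value
    sen = tp / (tp + fn + eps)  #recall / sensitivity
    return dsc, ppv, sen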
#Start training
for epoch in range(start_epoch, stop_epoch):
    scheduler.step()
    batch_sum = len(train_loader)
    #Training phase
    for index, (img, mask) in enumerate(train_loader):
        img = img.to(device)
        mask = mask.to(device)
        Model.train()
        Model.zero_grad()
        output = Model(img)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        if index % save_step == 0:
            output_img = vutils.make_grid(output[0,:,:,:], padding=0, normalize=True, range=(0,255))
            output_tmp = torch.ge(output, 0.5).mul(255)
            Writer.add_scalar('scalar/loss', loss, index)
            Writer.add_image('image/input', img[0,:,:,:])
            Writer.add_image('image/mask', mask[0,:,:,:])
            Writer.add_image('image/predict', output_tmp[0,:,:,:])
            Writer.add_image('image/output', output[0,:,:,:])
        sys.stdout.write("\r[Train] [Epoch {}/{}] [Batch {}/{}] [loss:{:.8f}] [learning rate:{}]".format(epoch+1, stop_epoch, index+1, batch_sum, loss.item(), optimizer.param_groups[0]['lr']))
        sys.stdout.flush()
    #Save the weights once per epoch
    torch.save(Model.state_dict(), checkpoints_path+'{}.pth'.format(epoch+1))
    #Validation phase
    DSC_sum = 0
    PPV_sum = 0
    Sen_sum = 0
    batch_sum = 0
    for index, (img, mask) in enumerate(val_loader):
        Model.eval()
        DSC = 0
        PPV = 0
        Sen = 0
        batch = 0
        img = img.to(device)
        mask = mask.to(device)
        with torch.no_grad():
            output = Model(img)
        output = torch.ge(output, 0.5).type(dtype=torch.float32) #binarize
        output = utils.post_process(output) #post-processing
        DSC, PPV, Sen, batch = utils.analysis(output, mask)
        DSC_sum += DSC*batch
        PPV_sum += PPV*batch
        Sen_sum += Sen*batch
        batch_sum += batch
        if index % save_step == 0:
            img_list = [
                img[0,:,:,:],
                output[0,:,:,:],
                mask[0,:,:,:]
            ]
            img_visualize = vutils.make_grid(img_list)
            visualize_img_path = visualize_path+str(epoch)+'_'+str(index+1)+'.tif'
            vutils.save_image(img_visualize, visualize_img_path)
        sys.stdout.write("\r[Val] [Epoch {}/{}] [Batch {}/{}] [DSC:{:.5f}] [PPV:{:.5f}] [Sen:{:.5f}]".format(epoch+1, stop_epoch, index+1, len(val_loader), DSC, PPV, Sen))
        sys.stdout.flush()
    DSC_sum /= batch_sum
    PPV_sum /= batch_sum
    Sen_sum /= batch_sum
    with open(log_path+title+'.txt', 'a') as f:
        f.write('{}\t{:.5f}\t{:.5f}\t{:.5f}\n'.format(epoch+1, DSC_sum, PPV_sum, Sen_sum))
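#Per-epoch averages are appended to log/<title>.txt as: epoch, DSC, PPV, Sen.
#The curves logged above can be viewed with, e.g.:
#  tensorboard --logdir /mnt/diskarray/fj/ResNet-Unet/log/tensorboard/ResNet_final/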