Skip to content

Commit

Permalink
base4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
wangwb committed Feb 21, 2024
1 parent e6450f3 commit 37d604a
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 91 deletions.
6 changes: 3 additions & 3 deletions inference_pretrain_unet_3d.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

branch_name=unet_3d # pure-stepnet
branch_name=stepnet_patch32_d4g2
iter='10000000000' # nan Appoint in terminal

# 从命令行获取参数
Expand All @@ -20,8 +20,8 @@ done

CUDA_VISIBLE_DEVICES=6 python scripts/inference_pretrained.py \
--expname "${branch_name}_net_g_${iter}" --num_io_consumer 1\
-i /home/wangwb/workspace/sr_3dunet/datasets/rotated_blocks/val_rotated_small\
-i /home/wangwb/workspace/sr_3dunet/datasets/rotated_blocks/val_rotated_big\
-o /home/wangwb/workspace/sr_3dunet/results/${branch_name}_net_g_${iter}\
--model_path /home/wangwb/workspace/sr_3dunet/experiments/${branch_name}/models/net_g_A_${iter}.pth\
--model_back_path /home/wangwb/workspace/sr_3dunet/experiments/${branch_name}/models/net_g_B_${iter}.pth\
--piece_flag True --piece_size 128 --overlap 16 --step_size 16 --rotated_flag False
--piece_flag True --piece_size 128 --overlap 16 --step_size 16 --rotated_flag True
39 changes: 24 additions & 15 deletions options/stepnet.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# general settings
name: pure-stepnet
name: stepnet_patch32_d4g2
model_type: StepNet_Model
num_gpu: 4 # set num_gpu: 0 for cpu mode
num_gpu: auto # set num_gpu: 0 for cpu mode
manual_seed: 0

# # USM the ground-truth
Expand All @@ -23,13 +23,13 @@ datasets:
use_flip: true
use_rot: true
gt_size: [32, 64] # fullsize is 128
gt_probs: [0, 1]
gt_probs: [1, 0]
iso_dimension: -1 # -3/-2/-1 <--> 0/1/2, means anisotropic in dimension 2

# data loader
use_shuffle6: true
num_worker_per_gpu: 6
batch_size_per_gpu: 2 # 4 # 16
batch_size_per_gpu: 16
dataset_enlarge_ratio: 1
prefetch_mode: ~
pin_memory: True
Expand All @@ -40,32 +40,39 @@ network_g_A:
in_channels: 1
out_channels: 1
features: [64, 128, 256, 512]
norm_type: instance
# norm_type: instance
dim: 3

network_g_B:
type: UNet_3d_Generator
in_channels: 1
out_channels: 1
features: [64, 128, 256, 512]
norm_type: instance
# norm_type: instance
dim: 3

network_d_A:
network_d_proj:
type: ProjectionDiscriminator
in_channels: 1
features: [64, 128, 256]
norm_type: 'batch'
dim: 2

network_d_iso:
type: CubeDiscriminator
in_channels: 1
features: [64, 128, 256]
norm_type: 'batch'
dim: 3

network_d_proj:
type: ProjectionDiscriminator
network_d_A1:
type: CubeDiscriminator
in_channels: 1
features: [64, 128, 256]
norm_type: 'batch'
dim: 2
dim: 3

network_d_iso:
network_d_A2:
type: CubeDiscriminator
in_channels: 1
features: [64, 128, 256]
Expand All @@ -74,10 +81,12 @@ network_d_iso:

# path
path:
pretrain_network_g_A: ~ # /home/wangwb/workspace/sr_3dunet/experiments/one_step_mixsize!!save/models/net_g_A_12000.pth
pretrain_network_g_B: ~ # /home/wangwb/workspace/sr_3dunet/experiments/one_step_mixsize!!save/models/net_g_B_12000.pth
pretrain_network_d_proj: ~
pretrain_network_d_iso: ~
pretrain_network_g_A: /home/wangwb/workspace/sr_3dunet/experiments/stepnet_archived_20240221_000600_archived_20240221_123247/models/net_g_A_40000.pth
pretrain_network_g_B: /home/wangwb/workspace/sr_3dunet/experiments/stepnet_archived_20240221_000600_archived_20240221_123247/models/net_g_B_40000.pth
pretrain_network_d_proj: /home/wangwb/workspace/sr_3dunet/experiments/stepnet_archived_20240221_000600_archived_20240221_123247/models/net_d_proj_40000.pth
pretrain_network_d_iso: /home/wangwb/workspace/sr_3dunet/experiments/stepnet_archived_20240221_000600_archived_20240221_123247/models/net_d_iso_40000.pth
pretrain_network_d_A1: /home/wangwb/workspace/sr_3dunet/experiments/stepnet_archived_20240221_000600_archived_20240221_123247/models/net_d_iso_40000.pth # NOTE(review): reuses the net_d_iso checkpoint for d_A1 — confirm this is intentional and not a copy-paste slip
pretrain_network_d_A2: /home/wangwb/workspace/sr_3dunet/experiments/stepnet_archived_20240221_000600_archived_20240221_123247/models/net_d_iso_40000.pth # NOTE(review): same — d_A2 also warm-starts from the net_d_iso weights
param_key_g_A: params
param_key_g_B: params
strict_load_g: true
Expand Down
8 changes: 4 additions & 4 deletions options/train_unet_3d.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# general settings
name: unet_3d
name: unet_3d_old
model_type: Unet_3D_old
num_gpu: 4 # set num_gpu: 0 for cpu mode
manual_seed: 0
Expand Down Expand Up @@ -40,15 +40,15 @@ network_g_A:
in_channels: 1
out_channels: 1
features: [64, 128, 256, 512]
norm_type: instance
# norm_type: instance
dim: 3

network_g_B:
type: UNet_3d_Generator
in_channels: 1
out_channels: 1
features: [64, 128, 256, 512]
norm_type: instance
# norm_type: instance
dim: 3

network_d_A:
Expand Down Expand Up @@ -162,7 +162,7 @@ train:
# logging settings
logger:
print_freq: 10
save_checkpoint_freq: !!float 500
save_checkpoint_freq: !!float 200
use_tb_logger: true
wandb:
project: ~
Expand Down
4 changes: 2 additions & 2 deletions scripts/inference_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def remove_outer_layer(matrix, overlap):
def get_inference_model(args, device) -> UNet_3d_Generator:
"""return an on device model with eval mode"""
# set up model
model = UNet_3d_Generator(in_channels=1, out_channels=1, features=[64, 128, 256, 512], norm_type='instance', dim=3)
model_back = UNet_3d_Generator(in_channels=1, out_channels=1, features=[64, 128, 256, 512], norm_type='instance', dim=3)
model = UNet_3d_Generator(in_channels=1, out_channels=1, features=[64, 128, 256, 512], dim=3)
model_back = UNet_3d_Generator(in_channels=1, out_channels=1, features=[64, 128, 256, 512], dim=3)

model_path = args.model_path
model_back_path = args.model_back_path
Expand Down
4 changes: 2 additions & 2 deletions sr_3dunet/archs/unet_3d_generator_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def __init__(self, in_channels, out_channels, *, norm_type='batch', dim=3):

self.conv = nn.Sequential(
Conv(in_channels, out_channels, kernel_size=3, padding=1, bias=use_bias),
norm_layer(out_channels),
# norm_layer(out_channels),
nn.ReLU(inplace=True),
Conv(out_channels, out_channels, kernel_size=3, padding=1, bias=use_bias),
norm_layer(out_channels),
# norm_layer(out_channels),
nn.ReLU(inplace=True)
)

Expand Down
152 changes: 87 additions & 65 deletions sr_3dunet/models/stepnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,42 @@ def init_training_settings(self):
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()

# define network net_d
self.net_d_proj = build_network(self.opt['network_d_proj'])
self.net_d_proj = self.model_to_device(self.net_d_proj)
self.net_d_iso = build_network(self.opt['network_d_iso'])
self.net_d_iso = self.model_to_device(self.net_d_iso)
def define_load_network(opt_key):
    """Build one discriminator from the options dict, move it to the
    training device, optionally warm-start it from a checkpoint, and
    put it in train mode.

    opt_key: name of the network section in ``self.opt``, e.g.
        ``'network_d_proj'``. The matching checkpoint path is read from
        ``self.opt['path']['pretrain_<opt_key>']`` (the key naming used
        by the options YAML: ``pretrain_network_d_proj`` etc.).
    Returns the ready-to-train discriminator module.
    """
    net_d = build_network(self.opt[opt_key])
    net_d = self.model_to_device(net_d)

    # Load pretrained weights when a path is configured.
    # BUGFIX: the previous code did .get('opt_name', None) — looking up
    # the *literal string* 'opt_name' — so a configured checkpoint was
    # silently never loaded. Derive the real path key from opt_key.
    load_path = self.opt['path'].get(f'pretrain_{opt_key}', None)
    if load_path is not None:
        param_key = self.opt['path'].get('param_key_d', 'params')
        self.load_network(net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
    net_d.train()
    return net_d

# Pass the option *key* (not the resolved dict) so the helper can also
# locate the corresponding pretrain_* checkpoint path.
self.net_d_proj = define_load_network('network_d_proj')
self.net_d_iso = define_load_network('network_d_iso')
self.net_d_A1 = define_load_network('network_d_A1')
self.net_d_A2 = define_load_network('network_d_A2')

# self.net_d_proj = build_network(self.opt['network_d_proj'])
# self.net_d_proj = self.model_to_device(self.net_d_proj)
# self.net_d_iso = build_network(self.opt['network_d_iso'])
# self.net_d_iso = self.model_to_device(self.net_d_iso)

# load pretrained models
load_path = self.opt['path'].get('network_d_proj', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_d', 'params')
self.load_network(self.net_d_proj, load_path, self.opt['path'].get('strict_load_d', True), param_key)
load_path = self.opt['path'].get('pretrain_network_d_iso', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_d', 'params')
self.load_network(self.net_d_iso, load_path, self.opt['path'].get('strict_load_d', True), param_key)
# load_path = self.opt['path'].get('network_d_proj', None)
# if load_path is not None:
# param_key = self.opt['path'].get('param_key_d', 'params')
# self.load_network(self.net_d_proj, load_path, self.opt['path'].get('strict_load_d', True), param_key)
# load_path = self.opt['path'].get('pretrain_network_d_iso', None)
# if load_path is not None:
# param_key = self.opt['path'].get('param_key_d', 'params')
# self.load_network(self.net_d_iso, load_path, self.opt['path'].get('strict_load_d', True), param_key)

self.net_g_A.train()
self.net_g_B.train()
self.net_d_proj.train()
self.net_d_iso.train()
# self.net_g_A.train()
# self.net_g_B.train()
# self.net_d_proj.train()
# self.net_d_iso.train()

# define losses
if train_opt.get('projection_opt'):
Expand Down Expand Up @@ -133,7 +149,7 @@ def setup_optimizers(self):
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, itertools.chain(self.net_d_proj.parameters(),self.net_d_iso.parameters()), **train_opt['optim_d'])
self.optimizer_d = self.get_optimizer(optim_type, itertools.chain(self.net_d_proj.parameters(),self.net_d_iso.parameters(),self.net_d_A1.parameters(),self.net_d_A2.parameters()), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)

def feed_data(self, data):
Expand All @@ -155,60 +171,57 @@ def optimize_parameters(self, current_iter):
self.affine_fakeB = affine_img(self.fakeB)

self.recA1 = self.net_g_B(self.fakeB)
# self.fakeC = self.net_g_B(self.affine_fakeB)
# self.recA2 = self.net_g_B(affine_img(self.net_g_A(self.fakeC)))
self.fakeC = self.net_g_B(self.affine_fakeB)
self.recA2 = self.net_g_B(affine_img(self.net_g_A(self.fakeC)))

# get iso and aniso projection arrays
input_iso_proj, input_aiso_proj0, input_aiso_proj1 = get_projection(self.realA, iso_dimension)
output_iso_proj, output_aiso_proj0, output_aiso_proj1 = get_projection(self.fakeB, iso_dimension)
aiso_proj_index = random.choice(['0', '1'])
proj_index = random.choice(['0', '1'])
match = lambda x: {
'0': (input_aiso_proj0, output_aiso_proj0),
'1': (input_aiso_proj1, output_aiso_proj1)
}.get(x, ('error0', 'error1'))
input_aiso_proj, output_aiso_proj = match(aiso_proj_index) # random.choice([output_aiso_proj0, output_aiso_proj1])
input_aiso_proj, output_aiso_proj = match(proj_index) # random.choice([output_aiso_proj0, output_aiso_proj1])

l_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# cycle loss
if self.cri_cycle:
l_cycle1 = self.cri_cycle(self.realA, self.recA1) + self.cri_cycle_ssim(self.realA, self.recA1)
# l_cycle2 = self.cri_cycle(self.realA, self.recA2) + self.cri_cycle_ssim(self.realA, self.recA2)
l_total += l_cycle1 # + l_cycle2
l_cycle2 = self.cri_cycle(self.realA, self.recA2) + self.cri_cycle_ssim(self.realA, self.recA2)
l_total += l_cycle1 + l_cycle2
loss_dict['l_cycle1'] = l_cycle1
# loss_dict['l_cycle2'] = l_cycle2
loss_dict['l_cycle2'] = l_cycle2
# projection loss
if self.cri_projection:
l_iso_proj = self.cri_projection(output_iso_proj, input_iso_proj) + self.cri_projection_ssim(output_iso_proj, input_iso_proj)
l_total += l_iso_proj
loss_dict['l_iso_proj'] = l_iso_proj

# generator loss
# fakeB_g_anisoproj_pred = self.net_d_proj(output_aiso_proj)
# l_g_A = self.cri_gan(fakeB_g_anisoproj_pred, True, is_disc=False)
# l_total += l_g_A
# loss_dict['l_g_A'] = l_g_A
fakeB_g_anisoproj_pred = self.net_d_proj(output_aiso_proj)
l_g_A = self.cri_gan(fakeB_g_anisoproj_pred, True, is_disc=False)
l_total += l_g_A
loss_dict['l_g_A'] = l_g_A

recA1_g_pred = self.net_d_iso(self.recA1)
recA1_g_pred = self.net_d_A1(self.recA1)
l_g_recA1 = self.cri_gan(recA1_g_pred, True, is_disc=False)
l_total += l_g_recA1
loss_dict['l_g_recA1'] = l_g_recA1

l_total.backward()

# recA2_g_pred = self.net_d_iso(self.recA2)
# l_g_recA2 = self.cri_gan(recA2_g_pred, True, is_disc=False)
# l_total += l_g_recA2
# l_g_recA2.backward(retain_graph=True)
# loss_dict['l_g_recA2'] = l_g_recA2
recA2_g_pred = self.net_d_A2(self.recA2)
l_g_recA2 = self.cri_gan(recA2_g_pred, True, is_disc=False)
l_total += l_g_recA2
loss_dict['l_g_recA2'] = l_g_recA2

# fakeC_g_pred = self.net_d_iso(self.fakeC)
# l_g_A2C = self.cri_gan(fakeC_g_pred, True, is_disc=False)
# l_total += l_g_A2C
# l_g_A2C.backward()
# loss_dict['l_g_A2C'] = l_g_A2C
fakeC_g_pred = self.net_d_iso(self.fakeC)
l_g_A2C = self.cri_gan(fakeC_g_pred, True, is_disc=False)
l_total += l_g_A2C
loss_dict['l_g_A2C'] = l_g_A2C

l_total.backward()
loss_dict['l_total'] = l_total

self.optimizer_g.step()
Expand All @@ -222,33 +235,42 @@ def optimize_parameters(self, current_iter):
self.optimizer_d.zero_grad()
# discriminator loss
# real
# realB_g_anisoproj_pred = self.net_d_proj(input_iso_proj)
# l_d_real_B = self.cri_gan(realB_g_anisoproj_pred, True, is_disc=True)
# loss_dict['l_d_real_B'] = l_d_real_B
# l_d_real_B.backward()
realC_g_pred = self.net_d_iso(self.realA) # same as A
l_d_real_A2C = self.cri_gan(realC_g_pred, True, is_disc=True)
realB_d_anisoproj_pred = self.net_d_proj(input_iso_proj)
l_d_real_B = self.cri_gan(realB_d_anisoproj_pred, True, is_disc=True)
loss_dict['l_d_real_B'] = l_d_real_B
l_d_real_B.backward()
realC_d_pred = self.net_d_iso(self.realA) # same as A
l_d_real_A2C = self.cri_gan(realC_d_pred, True, is_disc=True)
loss_dict['l_d_real_A2C'] = l_d_real_A2C
l_d_real_A2C.backward()

recA1_d_pred = self.net_d_A1(self.realA)
l_d_real_recA1 = self.cri_gan(recA1_d_pred, True, is_disc=True)
loss_dict['l_d_real_recA1'] = l_d_real_recA1
l_d_real_recA1.backward()
recA2_d_pred = self.net_d_A2(self.realA)
l_d_real_recA2 = self.cri_gan(recA2_d_pred, True, is_disc=True)
loss_dict['l_d_real_recA2'] = l_d_real_recA2
l_d_real_recA2.backward()

# fake
# fakeB_d_anisoproj_pred = self.net_d_proj(output_aiso_proj.detach())
# l_d_fake_B = self.cri_gan(fakeB_d_anisoproj_pred, False, is_disc=False)
# loss_dict['l_d_fake_B'] = l_d_fake_B
# l_d_fake_B.backward()
# fakeC_d_pred = self.net_d_iso(self.fakeC.detach())
# l_d_fake_A2C = self.cri_gan(fakeC_d_pred, False, is_disc=False)
# loss_dict['l_d_fake_A2C'] = l_d_fake_A2C
# l_d_fake_A2C.backward()
fakeB_d_anisoproj_pred = self.net_d_proj(output_aiso_proj.detach())
l_d_fake_B = self.cri_gan(fakeB_d_anisoproj_pred, False, is_disc=False)
loss_dict['l_d_fake_B'] = l_d_fake_B
l_d_fake_B.backward()
fakeC_d_pred = self.net_d_iso(self.fakeC.detach())
l_d_fake_A2C = self.cri_gan(fakeC_d_pred, False, is_disc=False)
loss_dict['l_d_fake_A2C'] = l_d_fake_A2C
l_d_fake_A2C.backward()

recA1_d_pred = self.net_d_iso(self.recA1.detach())
l_d_recA1 = self.cri_gan(recA1_d_pred, False, is_disc=False)
loss_dict['l_d_recA1'] = l_d_recA1
l_d_recA1.backward()
# recA2_d_pred = self.net_d_iso(self.recA2.detach())
# l_d_recA2 = self.cri_gan(recA2_d_pred, False, is_disc=False)
# loss_dict['l_d_recA2'] = l_d_recA2
# l_d_recA2.backward()
recA1_d_pred = self.net_d_A1(self.recA1.detach())
l_d_fake_recA1 = self.cri_gan(recA1_d_pred, False, is_disc=False)
loss_dict['l_d_fake_recA1'] = l_d_fake_recA1
l_d_fake_recA1.backward()
recA2_d_pred = self.net_d_A2(self.recA2.detach())
l_d_fake_recA2 = self.cri_gan(recA2_d_pred, False, is_disc=False)
loss_dict['l_d_fake_recA2'] = l_d_fake_recA2
l_d_fake_recA2.backward()

self.optimizer_d.step()

Expand Down

0 comments on commit 37d604a

Please sign in to comment.