
Commit

Fix finetune error
Janspiry committed Jun 1, 2022
1 parent 059c314 commit ed29b1c
Showing 36 changed files with 59 additions and 81 deletions.
Empty file modified .gitignore
100644 → 100755
Empty file.
Empty file modified LICENSE
100644 → 100755
Empty file.
Empty file modified README.md
100644 → 100755
Empty file.
7 changes: 0 additions & 7 deletions config/colorization_mirflickr25k.json
100644 → 100755
@@ -117,13 +117,6 @@
}
}
],
"which_optimizers": [ // len(networks) == len(optimizers) == len(lr_schedulers), it will be deleted after initialization if not used.
{ "name": "Adam", "args":{ "lr": 5e-5, "weight_decay": 0}}
],
"which_lr_schedulers": [ // {} represents None, it will be deleted after initialization.
{}
// { "name": "LinearLR", "args": { "start_factor": 0.2, "total_iters": 1e3 }}
],
"which_losses": [ // import designated list of losses without arguments
"mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
],
23 changes: 7 additions & 16 deletions config/inpainting_celebahq.json
100644 → 100755
@@ -10,8 +10,8 @@
"tb_logger": "tb_logger", // path of tensorboard logger
"results": "results",
"checkpoint": "checkpoint",
// "resume_state": "experiments/train_inpainting_celebahq_220426_233652/checkpoint/200"
"resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration
"resume_state": "experiments/train_inpainting_celebahq_220426_233652/checkpoint/190"
// "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration
},

"datasets": { // train or test
@@ -74,12 +74,10 @@
"ema_start": 1,
"ema_iter": 1,
"ema_decay": 0.9999
}
// "ema_scheduler": { // debug
// "ema_start": 0,
// "ema_iter": 10,
// "ema_decay": 0.9999
// }
},
"optimizers": [
{ "lr": 5e-5, "weight_decay": 0}
]
}
},
"which_networks": [ // import designated list of networks using arguments
@@ -126,13 +124,6 @@
}
}
],
"which_optimizers": [ // len(networks) == len(optimizers) == len(lr_schedulers), it will be deleted after initialization if not used.
{ "name": "Adam", "args":{ "lr": 5e-5, "weight_decay": 0}}
],
"which_lr_schedulers": [ // {} represents None, it will be deleted after initialization.
{}
// { "name": "LinearLR", "args": { "start_factor": 0.2, "total_iters": 1e3 }} // support in newest pytorch vision
],
"which_losses": [ // import designated list of losses without arguments
"mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
],
@@ -146,7 +137,7 @@
"n_iter": 1e8, // max interations
"val_epoch": 5, // valdation every specified number of epochs
"save_checkpoint_epoch": 10,
"log_iter": 1e4, // log every specified number of iterations
"log_iter": 1e3, // log every specified number of iterations
"tensorboard" : true // tensorboardX enable
},

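For context on the config change above: optimizer settings now live inside the model args as a plain "optimizers" list rather than in the top-level "which_optimizers" / "which_lr_schedulers" arrays. A minimal sketch of how such an entry maps onto a PyTorch optimizer, mirroring the new Palette.__init__ line further down in this diff (the one-layer net is a stand-in, not the repo's network):

import torch
import torch.nn as nn

optim_cfg = [{"lr": 5e-5, "weight_decay": 0}]  # the new "optimizers" entry

net = nn.Linear(4, 4)  # stand-in for the diffusion network
params = list(filter(lambda p: p.requires_grad, net.parameters()))
optG = torch.optim.Adam(params, **optim_cfg[0])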
17 changes: 4 additions & 13 deletions config/inpainting_places2.json
100644 → 100755
@@ -74,12 +74,10 @@
"ema_start": 1,
"ema_iter": 1,
"ema_decay": 0.9999
}
// "ema_scheduler": { // debug
// "ema_start": 0,
// "ema_iter": 10,
// "ema_decay": 0.9999
// }
},
"optimizers": [
{ "lr": 5e-5, "weight_decay": 0}
]
}
},
"which_networks": [ // import designated list of networks using arguments
@@ -126,13 +124,6 @@
}
}
],
"which_optimizers": [ // len(networks) == len(optimizers) == len(lr_schedulers), it will be deleted after initialization if not used.
{ "name": "Adam", "args":{ "lr": 5e-5, "weight_decay": 0}}
],
"which_lr_schedulers": [ // {} represents None, it will be deleted after initialization.
{}
// { "name": "LinearLR", "args": { "start_factor": 0.2, "total_iters": 1e3 }}
],
"which_losses": [ // import designated list of losses without arguments
"mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}}
],
5 changes: 4 additions & 1 deletion config/uncropping_places2.json
100644 → 100755
@@ -75,7 +75,10 @@
"ema_start": 1,
"ema_iter": 1,
"ema_decay": 0.9999
}
},
"optimizers": [
{ "lr": 5e-5, "weight_decay": 0}
]
}
},
"which_networks": [ // import designated list of networks using arguments
Empty file modified core/base_dataset.py
100644 → 100755
Empty file.
42 changes: 30 additions & 12 deletions core/base_model.py
100644 → 100755
@@ -6,6 +6,7 @@
import torch
import torch.nn as nn


import core.util as Util
CustomResult = collections.namedtuple('CustomResult', 'name result')

@@ -16,6 +17,10 @@ def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer):
self.phase = opt['phase']
self.set_device = partial(Util.set_device, rank=opt['global_rank'])

''' optimizers and schedulers '''
self.schedulers = []
self.optimizers = []

''' process record '''
self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size']
self.epoch = 0
@@ -53,7 +58,7 @@ def train(self):
if self.epoch % self.opt['train']['val_epoch'] == 0:
self.logger.info("\n\n\n------------------------------Validation Start------------------------------")
if self.val_loader is None:
self.logger.info('Validation stop where dataloader is None, Skip it.')
self.logger.warning('Validation stop where dataloader is None, Skip it.')
else:
val_log = self.val_step()
for key, value in val_log.items():
@@ -103,44 +108,57 @@ def save_network(self, network, network_label):
def load_network(self, network, network_label, strict=True):
if self.opt['path']['resume_state'] is None:
return
self.logger.info('Beign loading pretrained model [{:s}] ...'.format(network_label))

model_path = "{}_{}.pth".format(self. opt['path']['resume_state'], network_label)

if not os.path.exists(model_path):
self.logger.warning('Pretrained model in [{:s}] is not existed, Skip it'.format(model_path))
return

self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path))
if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel):
network = network.module
network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict)

def save_training_state(self, optimizers, schedulers):
def save_training_state(self):
""" saves training state during training, only work on GPU 0 """
if self.opt['global_rank'] !=0:
return
assert isinstance(optimizers, list) and isinstance(schedulers, list), 'optimizers and schedulers must be a list.'
assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.'
state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []}
for s in schedulers:
for s in self.schedulers:
state['schedulers'].append(s.state_dict())
for o in optimizers:
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
save_filename = '{}.state'.format(self.epoch)
save_path = os.path.join(self.opt['path']['checkpoint'], save_filename)
torch.save(state, save_path)

def resume_training(self, optimizers, schedulers):
def resume_training(self):
""" resume the optimizers and schedulers for training, only work when phase is test or resume training enable """
if self.phase!='train' or self. opt['path']['resume_state'] is None:
return
assert isinstance(optimizers, list) and isinstance(schedulers, list), 'optimizers and schedulers must be a list.'
self.logger.info('Beign loading training states'.format())
assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.'

state_path = "{}.state".format(self. opt['path']['resume_state'])

if not os.path.exists(state_path):
self.logger.warning('Training state in [{:s}] is not existed, Skip it'.format(state_path))
return

self.logger.info('Loading training state for [{:s}] ...'.format(state_path))
resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage))
# resume_state = torch.load(state_path)

resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(optimizers))
assert len(resume_schedulers) == len(schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(schedulers))
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers))
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers))
for i, o in enumerate(resume_optimizers):
optimizers[i].load_state_dict(o)
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
schedulers[i].load_state_dict(s)
self.schedulers[i].load_state_dict(s)

self.epoch = resume_state['epoch']
self.iter = resume_state['iter']
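With this change BaseModel owns self.optimizers and self.schedulers, so save_training_state() and resume_training() no longer take arguments. A toy round-trip of the '{epoch}.state' file they write and read, assuming a stand-in network and an empty scheduler list:

import torch
import torch.nn as nn

net = nn.Linear(2, 2)  # stand-in network
optimizers = [torch.optim.Adam(net.parameters(), lr=5e-5)]
schedulers = []  # may be empty; lengths are asserted on resume

state = {
    "epoch": 3,
    "iter": 1200,
    "optimizers": [o.state_dict() for o in optimizers],
    "schedulers": [s.state_dict() for s in schedulers],
}
torch.save(state, "3.state")  # save_training_state writes '{epoch}.state'

resumed = torch.load("3.state")
assert len(resumed["optimizers"]) == len(optimizers)
for o, sd in zip(optimizers, resumed["optimizers"]):
    o.load_state_dict(sd)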
Empty file modified core/base_network.py
100644 → 100755
Empty file.
Empty file modified core/logger.py
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion core/praser.py
100644 → 100755
@@ -19,7 +19,7 @@ def init_obj(opt, logger, *args, default_file_name='default file', given_module=
''' default format is dict with name key '''
if isinstance(opt, str):
opt = {'name': opt}
logger.info('Config is a str, converts to a dict {}'.format(opt))
logger.warning('Config is a str, converts to a dict {}'.format(opt))

name = opt['name']
''' name can be list, indicates the file and class name of function '''
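The only change here is the log level, but for context: init_obj first normalizes a bare string entry into the dict form the rest of the parser expects, as in this two-line illustration (the "mse_loss" value is just an example taken from the configs above):

opt = "mse_loss"  # e.g. an entry from "which_losses"
if isinstance(opt, str):
    opt = {"name": opt}  # now logged as a warning instead of info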
Empty file modified core/util.py
100644 → 100755
Empty file.
Empty file modified data/__init__.py
100644 → 100755
Empty file.
5 changes: 2 additions & 3 deletions data/dataset.py
100644 → 100755
@@ -4,7 +4,6 @@
import os
import torch
import numpy as np
import cv2

from .util.mask import (bbox2mask, brush_stroke_mask, get_irregular_mask, random_bbox, random_cropping_bbox)

@@ -62,7 +61,7 @@ def __getitem__(self, index):
ret['cond_image'] = cond_image
ret['mask_image'] = mask_img
ret['mask'] = mask
ret['path'] = path.rsplit("/")[-1]
ret['path'] = path.rsplit("/")[-1].rsplit("\\")[-1]
return ret

def __len__(self):
@@ -119,7 +118,7 @@ def __getitem__(self, index):
ret['cond_image'] = cond_image
ret['mask_image'] = mask_img
ret['mask'] = mask
ret['path'] = path.rsplit("/")[-1]
ret['path'] = path.rsplit("/")[-1].rsplit("\\")[-1]
return ret

def __len__(self):
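The dataset change chains two rsplit calls so the filename is recovered whether the path uses '/' or '\' separators, which matters when datasets are prepared on Windows. A quick check with a hypothetical path; the ntpath line is a suggested equivalent, not what the repo uses:

import ntpath

path = r"experiments\results\00001.jpg"  # hypothetical Windows-style path

name = path.rsplit("/")[-1].rsplit("\\")[-1]  # what the commit does
assert name == "00001.jpg"

assert ntpath.basename(path) == "00001.jpg"  # splits on '/' and '\' on any OS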
Empty file modified data/util/auto_augment.py
100644 → 100755
Empty file.
Empty file modified data/util/mask.py
100644 → 100755
Empty file.
Empty file modified eval.py
100644 → 100755
Empty file.
Empty file modified misc/Palette Image-to-Image Diffusion Models.pdf
100644 → 100755
Empty file.
Empty file modified misc/image/Process_02323.jpg
100644 → 100755
Empty file modified misc/image/Process_26190.jpg
100644 → 100755
Empty file modified misc/image/Process_Places365_test_00124460.jpg
100644 → 100755
Empty file modified misc/image/Process_Places365_test_00157365.jpg
100644 → 100755
Empty file modified misc/image/Process_Places365_test_00278428.jpg
100644 → 100755
6 changes: 0 additions & 6 deletions models/__init__.py
100644 → 100755
@@ -1,4 +1,3 @@
import torch
from core.praser import init_obj

def create_model(**cfg_model):
@@ -28,8 +27,3 @@ def define_loss(logger, loss_opt):
def define_metric(logger, metric_opt):
return init_obj(metric_opt, logger, default_file_name='models.metric', init_type='Metric')

def define_optimizer(networks, logger, optimizer_opt):
return init_obj(optimizer_opt, logger, networks, given_module=torch.optim, init_type='Optimizer')

def define_scheduler(optimizers, logger, scheduler_opt):
return init_obj(scheduler_opt, logger, optimizers, given_module=torch.optim.lr_scheduler, init_type='Scheduler')
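The deleted define_optimizer / define_scheduler helpers resolved config entries such as { "name": "Adam", "args": {...} } against torch.optim through init_obj. A simplified sketch of that getattr-based lookup (the real init_obj also handles logging, list-style names, and default files):

import torch
import torch.nn as nn

def init_from_module(opt, module, *args):
    cls = getattr(module, opt["name"])  # e.g. torch.optim.Adam
    return cls(*args, **opt.get("args", {}))

net = nn.Linear(2, 2)  # stand-in network
opt_cfg = {"name": "Adam", "args": {"lr": 5e-5, "weight_decay": 0}}
optimizer = init_from_module(opt_cfg, torch.optim, net.parameters())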
Empty file modified models/guided_diffusion_modules/nn.py
100644 → 100755
Empty file.
Empty file modified models/guided_diffusion_modules/unet.py
100644 → 100755
Empty file.
Empty file modified models/loss.py
100644 → 100755
Empty file.
Empty file modified models/metric.py
100644 → 100755
Empty file.
20 changes: 10 additions & 10 deletions models/model.py
100644 → 100755
@@ -17,7 +17,7 @@ def update_average(self, old, new):
return old * self.beta + (1 - self.beta) * new

class Palette(BaseModel):
def __init__(self, networks, optimizers, lr_schedulers, losses, sample_num, task, ema_scheduler=None, **kwargs):
def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler=None, **kwargs):

luislofer89 commented on Jun 20, 2022:

The current state of the main branch raises an error when calling the constructor of the Palette model.

''' must init BaseModel with kwargs '''
super(Palette, self).__init__(**kwargs)

@@ -30,16 +30,17 @@ def __init__(self, networks, optimizers, lr_schedulers, losses, sample_num, task
self.EMA = EMA(beta=self.ema_scheduler['ema_decay'])
else:
self.ema_scheduler = None
''' ddp '''

''' networks can be a list, and must convert by self.set_device function if using multiple GPU. '''
self.netG = self.set_device(self.netG, distributed=self.opt['distributed'])
if self.ema_scheduler is not None:
self.netG_EMA = self.set_device(self.netG_EMA, distributed=self.opt['distributed'])

self.schedulers = lr_schedulers
self.optG = optimizers[0]
self.load_networks()

self.optG = torch.optim.Adam(list(filter(lambda p: p.requires_grad, self.netG.parameters())), **optimizers[0])
self.optimizers.append(self.optG)
self.resume_training()

''' networks can be a list, and must convert by self.set_device function if using multiple GPU. '''
self.load_everything()
if self.opt['distributed']:
self.netG.module.set_loss(self.loss_fn)
self.netG.module.set_new_noise_schedule(phase=self.phase)
@@ -188,7 +189,7 @@ def test(self):
self.writer.add_images(key, value)
self.writer.save_images(self.save_current_results())

def load_everything(self):
def load_networks(self):
""" save pretrained model and training state, which only do on GPU 0. """
if self.opt['distributed']:
netG_label = self.netG.module.__class__.__name__
@@ -197,7 +198,6 @@ def load_everything(self):
self.load_network(network=self.netG, network_label=netG_label, strict=False)
if self.ema_scheduler is not None:
self.load_network(network=self.netG_EMA, network_label=netG_label+'_ema', strict=False)
self.resume_training([self.optG], self.schedulers)

def save_everything(self):
""" load pretrained model and training state, optimizers and schedulers must be a list. """
@@ -208,4 +208,4 @@
self.save_network(network=self.netG, network_label=netG_label)
if self.ema_scheduler is not None:
self.save_network(network=self.netG_EMA, network_label=netG_label+'_ema')
self.save_training_state([self.optG], self.schedulers)
self.save_training_state()
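For readers tracing the EMA path kept above: update_average applies old * beta + (1 - beta) * new parameter-wise. A toy illustration with beta matching the configs' ema_decay (the linear layer stands in for netG):

import copy
import torch
import torch.nn as nn

beta = 0.9999  # the configs' ema_decay
net = nn.Linear(2, 2)  # stand-in for netG
net_ema = copy.deepcopy(net)  # stand-in for netG_EMA

with torch.no_grad():
    for p_ema, p in zip(net_ema.parameters(), net.parameters()):
        p_ema.copy_(p_ema * beta + (1.0 - beta) * p)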
Empty file modified models/network.py
100644 → 100755
Empty file.
Empty file modified models/sr3_modules/unet.py
100644 → 100755
Empty file.
Empty file modified preprocess/mirflickr25k_preprocess.py
100644 → 100755
Empty file.
Empty file modified requirements.txt
100644 → 100755
Empty file.
13 changes: 1 addition & 12 deletions run.py
100644 → 100755
@@ -8,7 +8,7 @@
import core.praser as Praser
import core.util as Util
from data import define_dataloader
from models import create_model, define_network, define_loss, define_metric, define_optimizer, define_scheduler
from models import create_model, define_network, define_loss, define_metric

def main_worker(gpu, ngpus_per_node, opt):
""" threads running on each GPU """
@@ -41,22 +41,11 @@ def main_worker(gpu, ngpus_per_node, opt):
metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']]
losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']]

trian_params = [list(filter(lambda p: p.requires_grad, network.parameters())) for network in networks]
optimizers = [define_optimizer(trian_params[_idx], phase_logger, item_opt)
for _idx, item_opt in enumerate(opt['model']['which_optimizers'])]
optimizers = [optimizer for optimizer in optimizers if optimizer is not None]

lr_schedulers = [define_scheduler(optimizers[_idx], phase_logger, item_opt)
for _idx, item_opt in enumerate(opt['model']['which_lr_schedulers'])]
lr_schedulers = [lr_scheduler for lr_scheduler in lr_schedulers if lr_scheduler is not None]

model = create_model(
opt = opt,
networks = networks,
phase_loader = phase_loader,
val_loader = val_loader,
optimizers = optimizers,
lr_schedulers = lr_schedulers,
losses = losses,
metrics = metrics,
logger = phase_logger,
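Per the deletions above, run.py no longer builds optimizers or schedulers; the "optimizers" list from the model config rides along as a keyword argument, and the model constructs Adam itself. A toy stand-in for that hand-off (ToyPalette only mimics the constructor signature shown in models/model.py above):

import torch
import torch.nn as nn

class ToyPalette:
    def __init__(self, networks, losses, sample_num, task, optimizers, **kwargs):
        self.netG = networks[0]
        params = list(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optG = torch.optim.Adam(params, **optimizers[0])

model = ToyPalette(
    networks=[nn.Linear(2, 2)],
    losses=[nn.MSELoss()],
    sample_num=8,
    task="inpainting",
    optimizers=[{"lr": 5e-5, "weight_decay": 0}],
)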
Empty file modified slurm/inpainting_places2.slurm
100644 → 100755
Empty file.
