Skip to content

Commit

Permalink
Fix index error
Browse files Browse the repository at this point in the history
  • Loading branch information
Janspiry committed Jul 25, 2022
1 parent feca17c commit e598b2e
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def set_input(self, data):
self.mask = self.set_device(data.get('mask'))
self.mask_image = data.get('mask_image')
self.path = data['path']
self.batch_size = len(data['path'])

def get_current_visuals(self, phase='train'):
dict = {
Expand Down Expand Up @@ -151,7 +152,7 @@ def val_step(self):

for met in self.metrics:
key = met.__name__
value = met(self.cond_image, self.output)
value = met(self.gt_image, self.output)
self.val_metrics.update(key, value)
self.writer.add_scalar(key, value)
for key, value in self.get_current_visuals(phase='val').items():
Expand All @@ -163,31 +164,40 @@ def val_step(self):
def test(self):
self.netG.eval()
self.test_metrics.reset()
for phase_data in tqdm.tqdm(self.phase_loader):
self.set_input(phase_data)
if self.opt['distributed']:
if self.task in ['inpainting','uncropping']:
self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image,
y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num)
else:
self.output, self.visuals = self.netG.module.restoration(self.cond_image, sample_num=self.sample_num)
else:
if self.task in ['inpainting','uncropping']:
self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image,
y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num)
with torch.no_grad():
for phase_data in tqdm.tqdm(self.phase_loader):
self.set_input(phase_data)
if self.opt['distributed']:
if self.task in ['inpainting','uncropping']:
self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image,
y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num)
else:
self.output, self.visuals = self.netG.module.restoration(self.cond_image, sample_num=self.sample_num)
else:
self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num)

self.iter += self.batch_size
self.writer.set_iter(self.epoch, self.iter, phase='test')
for met in self.metrics:
key = met.__name__
value = met(self.cond_image, self.output)
self.val_metrics.update(key, value)
self.writer.add_scalar(key, value)
for key, value in self.get_current_visuals(phase='test').items():
self.writer.add_images(key, value)
self.writer.save_images(self.save_current_results())
if self.task in ['inpainting','uncropping']:
self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image,
y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num)
else:
self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num)

self.iter += self.batch_size
self.writer.set_iter(self.epoch, self.iter, phase='test')
for met in self.metrics:
key = met.__name__
value = met(self.gt_image, self.output)
self.test_metrics.update(key, value)
self.writer.add_scalar(key, value)
for key, value in self.get_current_visuals(phase='test').items():
self.writer.add_images(key, value)
self.writer.save_images(self.save_current_results())

test_log = self.test_metrics.result()
''' save logged informations into log dict '''
test_log.update({'epoch': self.epoch, 'iters': self.iter})

''' print logged informations to the screen and tensorboard '''
for key, value in test_log.items():
self.logger.info('{:5s}: {}\t'.format(str(key), value))

def load_networks(self):
""" save pretrained model and training state, which only do on GPU 0. """
Expand All @@ -200,7 +210,7 @@ def load_networks(self):
self.load_network(network=self.netG_EMA, network_label=netG_label+'_ema', strict=False)

def save_everything(self):
""" load pretrained model and training state, optimizers and schedulers must be a list. """
""" load pretrained model and training state. """
if self.opt['distributed']:
netG_label = self.netG.module.__class__.__name__
else:
Expand Down

0 comments on commit e598b2e

Please sign in to comment.