"""
Contains code for logging approximate FID scores during training.
If you want to output ground-truth images from the training dataset, you can
run this file as a script.
"""
import os
import shutil
import torch
import copy
import argparse
from torchvision.utils import save_image
from pytorch_fid import fid_score
from tqdm import tqdm
import datasets
import curriculums

def output_real_images(dataloader, num_imgs, real_dir):
    """Dump `num_imgs` resized real images from `dataloader` into `real_dir`."""
    img_counter = 0
    batch_size = dataloader.batch_size
    dataloader = iter(dataloader)
    for i in range(num_imgs // batch_size):
        real_imgs, _ = next(dataloader)
        for img in real_imgs:
            # Note: newer torchvision versions rename `range` to `value_range`.
            save_image(img, os.path.join(real_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1))
            img_counter += 1

def setup_evaluation(dataset_name, generated_dir, target_size=128, num_imgs=2048, data_path=None):
    # Only make real images if they haven't been made yet
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
    if not os.path.exists(real_dir):
        os.makedirs(real_dir)
        dataloader, CHANNELS = datasets.get_dataset(dataset_name, img_size=target_size, dataset_path=data_path)
        print('outputting real images...')
        output_real_images(dataloader, num_imgs, real_dir)
        print('...done')

    if generated_dir is not None:
        os.makedirs(generated_dir, exist_ok=True)
    return real_dir

def output_images(generator, input_metadata, rank, world_size, output_dir, num_imgs=2048):
    metadata = copy.deepcopy(input_metadata)
    metadata['batch_size'] = metadata.get('batch_size', metadata['batch_size'])  # effectively a no-op: the training batch size is reused
    metadata['h_stddev'] = metadata.get('h_stddev_eval', metadata['h_stddev'])
    metadata['v_stddev'] = metadata.get('v_stddev_eval', metadata['v_stddev'])
    metadata['sample_dist'] = metadata.get('sample_dist_eval', metadata['sample_dist'])
    metadata['psi'] = 1

    img_counter = rank
    generator.eval()

    if rank == 0: pbar = tqdm(desc="generating images", total=num_imgs)

    with torch.no_grad():
        while img_counter < num_imgs:
            z = torch.randn((metadata['batch_size'], 9, metadata["n_basis"]), device=generator.module.device)
            z_noise = torch.randn((metadata['batch_size'], 1, 256), device=generator.module.device)
            metadata['img_size'] = 64
            generated_imgs, _ = generator.module.staged_forward(z, z_noise, **metadata)

            for img in generated_imgs:
                save_image(img, os.path.join(output_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1))
                img_counter += world_size
                if rank == 0: pbar.update(world_size)
    if rank == 0: pbar.close()

def output_images_diff(generator, input_metadata, rank, world_size, output_dir, num_imgs=2048):
    metadata = copy.deepcopy(input_metadata)
    # metadata['img_size'] = 128
    metadata['batch_size'] = 4
    metadata['h_stddev'] = metadata.get('h_stddev_eval', metadata['h_stddev'])
    metadata['v_stddev'] = metadata.get('v_stddev_eval', metadata['v_stddev'])
    metadata['sample_dist'] = metadata.get('sample_dist_eval', metadata['sample_dist'])
    metadata['psi'] = 1

    img_counter = rank
    generator.eval()

    if rank == 0: pbar = tqdm(desc="generating images", total=num_imgs)

    with torch.no_grad():
        while img_counter < num_imgs:
            z = torch.randn((metadata['batch_size'], 8, metadata["n_basis"]), device=generator.module.device)
            generated_imgs, _ = generator.module.staged_forward(z, **metadata)
            for img in generated_imgs:
                save_image(img, os.path.join(output_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1))
                img_counter += world_size
                if rank == 0: pbar.update(world_size)
    if rank == 0: pbar.close()

def calculate_fid(dataset_name, generated_dir, target_size=256):
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
    # calculate_fid_given_paths(paths, batch_size, device, dims): target_size doubles as the
    # Inception batch size here, and 2048 selects the standard pool3 feature dimensionality.
    fid = fid_score.calculate_fid_given_paths([real_dir, generated_dir], target_size, 'cuda', 2048)
    torch.cuda.empty_cache()

    return fid

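# Hypothetical sketch (not part of the original module): one way the helpers above could be
# wired together inside a training loop to log an approximate FID. The names `generator`
# (a DistributedDataParallel-wrapped generator, as output_images() expects via `.module`),
# `metadata` (a curriculum dict) and `generated_dir` are assumptions for illustration.
def evaluate_fid_sketch(generator, metadata, dataset_name, rank, world_size,
                        generated_dir='EvalGenerated', target_size=64, num_imgs=2048):
    # Write (or reuse) the cached real images and create the output folder.
    # output_images() renders at 64x64 (hard-coded above), so target_size=64 keeps
    # real and generated resolutions matched.
    setup_evaluation(dataset_name, generated_dir, target_size=target_size, num_imgs=num_imgs)
    # Every rank renders its share of fake images into `generated_dir`.
    output_images(generator, metadata, rank, world_size, generated_dir, num_imgs=num_imgs)
    # In a real multi-GPU run the ranks would need to synchronize here
    # (e.g. torch.distributed.barrier()) before rank 0 reads the folder.
    if rank == 0:
        return calculate_fid(dataset_name, generated_dir, target_size=target_size)
    return None
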

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='CelebA')
    parser.add_argument('--img_size', type=int, default=128)
    parser.add_argument('--num_imgs', type=int, default=8000)
    opt = parser.parse_args()

    real_images_dir = setup_evaluation(opt.dataset, None, target_size=opt.img_size, num_imgs=opt.num_imgs)