-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_test_data.py
110 lines (83 loc) · 3.25 KB
/
get_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
import json
import torch
import random
from dataset.dataset import ImageFolderDataset
import numpy as np
from torch.utils.data import DataLoader, random_split
import argparse
import torchvision
from tqdm import tqdm
class Inference:
    """Dump the held-out test split of an image dataset to disk.

    Saves up to ``opts.num_samples`` test images into ``opts.output_dir``
    together with a ``dataset.json`` mapping each file name to its camera
    parameters.
    """

    def __init__(self, opts) -> None:
        """Seed all RNGs, build the dataset split, and create the test loader.

        Args:
            opts: parsed CLI namespace; must provide ``data_path``,
                ``synth_data_path``, ``output_dir`` and ``num_samples``.
        """
        self.opts = opts
        self.global_step = 0
        torch.backends.cudnn.benchmark = True
        # Seed every RNG source so the random_split below (and anything
        # downstream) is reproducible across runs.
        SEED = 2107
        random.seed(SEED)
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        _, self.test_dataset = self.configure_datasets()
        self.configure_dataloaders(batch_size=1)

    def configure_datasets(self):
        """Build the full dataset and split it 95%/5% into train/test.

        Returns:
            Tuple ``(train_dataset, test_dataset)`` of ``torch`` Subsets.
        """
        full_dataset = ImageFolderDataset(
            path=self.opts.data_path,
            synth_path=self.opts.synth_data_path,
            load_original_ws=False,
            synth_data_ratio=0,
            use_labels=True,
        )
        # Fractional lengths require torch >= 1.13; the torch seed set in
        # __init__ makes this split deterministic.
        test_fraction = 0.05
        train_dataset, test_dataset = random_split(
            full_dataset, [1 - test_fraction, test_fraction]
        )
        print(f"Number of test samples: {len(test_dataset)}")
        return train_dataset, test_dataset

    def configure_dataloaders(self, batch_size):
        """Create ``self.test_dataloader`` over the test split.

        ``shuffle=False`` keeps the sample order stable across runs.
        """
        self.test_dataloader = DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=True,
        )

    def generate_from_images(self):
        """Save up to ``opts.num_samples`` test images and their labels.

        Writes each image into ``opts.output_dir`` and finally a
        ``dataset.json`` with ``{"labels": [[fname, camera_params], ...]}``.
        """
        torch.cuda.empty_cache()
        dataset_json = {"labels": []}
        for batch_idx, batch in tqdm(enumerate(self.test_dataloader), total=self.opts.num_samples):
            if batch_idx >= self.opts.num_samples:
                print("Finished generating samples.")
                break
            (
                x_resized,
                x,
                camera_param,
                x_mirror_resized,
                x_mirror,
                camera_param_mirror,
                w_original,
                fname,
                synth_data,
            ) = batch
            # Map the network's [-1, 1] range into [0, 1] image range.
            # BUGFIX: the original also passed value_range=(0, 1), which
            # save_image only honours when normalize=True — it was a dead,
            # misleading argument; the tensor is already clamped to [0, 1].
            torchvision.utils.save_image(
                ((x + 1) / 2).clamp(0, 1),
                f"{self.opts.output_dir}/{fname[0]}",
            )
            # batch_size is 1, hence the [0] indexing on fname/camera_param.
            dataset_json["labels"].append([fname[0], camera_param.numpy().tolist()[0]])
        print("Done creating samples!")
        with open(f"{self.opts.output_dir}/dataset.json", "w") as f:
            json.dump(dataset_json, f)
        print("Saved dataset.json!")
if __name__ == "__main__":
    # CLI entry point: parse paths/sample count, then dump test samples.
    parser = argparse.ArgumentParser(description="Generate samples from a trained PanoHead model")
    parser.add_argument("--output_dir", type=str, help="Path to the output directory", required=True)
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
    parser.add_argument("--data_path", type=str, help="Path to data images.", required=True)
    # Typo fix in the user-facing help string: "syhtn" -> "synth".
    parser.add_argument("--synth_data_path", type=str, help="Path to synth data images.", required=True)
    opts = parser.parse_args()
    # Create the output directory up front so later image/json writes
    # cannot fail on a missing directory.
    os.makedirs(opts.output_dir, exist_ok=True)
    inference = Inference(opts)
    inference.generate_from_images()