eval_loader.py
"""
Copyright (c) 2024-present Naver Cloud Corp.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import numpy as np
import torch
import json
from PIL import Image
import torch.utils.data as data
from torch.utils.data.distributed import DistributedSampler
def get_evalloader(config):
    """Build one evaluation DataLoader per data type listed in the config."""
    loader_dict = {}
    for data_type in config.dataset.data_type:
        dataset = Dataset(
            config.data_root,
            config.dataset,
            data_type,
        )
        if config.local_rank == 0:
            print(f"LOG) ZIM Dataset: {data_type} ({len(dataset)})")

        # Shard the dataset across processes when running with DDP.
        sampler = None
        if config.use_ddp:
            sampler = DistributedSampler(
                dataset,
                rank=config.local_rank,
                num_replicas=config.world_size,
            )

        # Evaluate one sample per batch and keep every sample (no dropping).
        dataloader = data.DataLoader(
            dataset,
            batch_size=1,
            num_workers=config.eval.workers,
            sampler=sampler,
            shuffle=False,
            pin_memory=True,
            drop_last=False,
        )
        loader_dict[data_type] = dataloader

    return loader_dict


class Dataset(data.Dataset):
    """Evaluation dataset yielding an image, its matte, and prompt metadata."""

    def __init__(
        self,
        data_root,
        dataset_config,
        data_type,
    ):
        super(Dataset, self).__init__()
        self.root = os.path.join(data_root, dataset_config.valset)

        # Each line of the data list holds four space-separated relative
        # paths: image, matte, json (prompts), and segmentation (unused here).
        with open(os.path.join(self.root, dataset_config.data_list_txt), "r") as f:
            f_list = f.read().splitlines()
        f_list = [p for p in f_list if data_type in p]

        self.images = []
        self.mattes = []
        self.jsons = []
        for fname in f_list:
            img_path, matte_path, json_path, seg_path = fname.split(" ")
            self.images.append(os.path.join(self.root, img_path))
            self.mattes.append(os.path.join(self.root, matte_path))
            self.jsons.append(os.path.join(self.root, json_path))

        assert len(self.images) == len(self.mattes) == len(self.jsons)

    def __getitem__(self, index):
        fname = os.path.basename(self.mattes[index])

        img = Image.open(self.images[index]).convert("RGB")
        matte = Image.open(self.mattes[index]).convert("L")
        orig_w, orig_h = img.size  # original resolution (unused here)

        img = np.float32(img)
        matte = np.float32(matte) / 255.0

        # Fraction of pixels whose matte value exceeds 0.3, i.e. a rough
        # measure of how much of the image the target object covers.
        ratio = (matte > 0.3).sum() / matte.size

        with open(self.jsons[index], "r") as f:
            meta_data = json.load(f)

        # Pad the point prompts to a fixed length of 50 so every sample has
        # the same shape; (-1, -1, -1) marks an invalid (padding) point.
        points = meta_data["point"]
        points += [(-1, -1, -1) for _ in range(50 - len(points))]
        bbox = meta_data["bbox"]

        output = {
            "images": torch.tensor(img, dtype=torch.float),
            "mattes": torch.tensor(matte, dtype=torch.float),
            "points": torch.tensor(points, dtype=torch.float),
            "bboxes": torch.tensor(bbox, dtype=torch.float),
            "fname": fname,
            "ratio": ratio,
        }
        return output

    def __len__(self):
        return len(self.images)
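

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file). It builds a config
# object exposing the attribute names this module actually reads, then pulls
# one batch per data type. Every concrete value below (the paths, folder
# names, data-type tags, and worker count) is an illustrative assumption and
# must match a real dataset layout on disk before this will run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        data_root="/path/to/data",          # assumed dataset root
        dataset=SimpleNamespace(
            valset="valset",                # assumed validation-set folder
            data_list_txt="data_list.txt",  # assumed list-file name
            data_type=["type_a", "type_b"], # assumed data-type tags
        ),
        eval=SimpleNamespace(workers=4),
        use_ddp=False,                      # single-process evaluation
        local_rank=0,
        world_size=1,
    )

    loaders = get_evalloader(config)
    for data_type, loader in loaders.items():
        for batch in loader:
            print(data_type, batch["fname"], batch["images"].shape)
            break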