-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshow_triplets.py
35 lines (30 loc) · 1.03 KB
/
show_triplets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import os

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from data import load_dataset, DATASETS
# Colour of the frame drawn around the highlighted image; it is also installed
# as matplotlib's default axes edge colour.
color = 'lime'
plt.rcParams.update({'axes.edgecolor': color})

# Command-line interface: where the datasets live and which one to visualise.
cli_parser = argparse.ArgumentParser()
cli_parser.add_argument('--dataset', default='cifar100-coarse', choices=DATASETS)
cli_parser.add_argument('--data-dir', default='resources/datasets')
args = cli_parser.parse_args()

# Load the requested dataset and ask it for its triplets.
# NOTE(review): each triplet is presumably three indices into `dataset` — confirm
# against the dataset implementation.
dataset = load_dataset(data_dir=args.data_dir, name=args.dataset)
triplets = dataset.get_triplets()
# Render the first triplets as 1x3 image strips and save each strip as
# '<output_dir>/<k>.png'.
num_triplets = 20
output_dir = 'resources/plots/cifar_images'

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(output_dir, exist_ok=True)

# Clamp to the number of available triplets so a small dataset cannot raise
# an IndexError.
for k in range(min(num_triplets, len(triplets))):
    i1, i2, i3 = triplets[k]
    fig, ax = plt.subplots(1, 3)
    for i, idx in enumerate([i1, i2, i3]):
        # Only the image is displayed; the second item of the sample is unused.
        img, _ = dataset[idx]
        # NOTE(review): assumes samples are PIL images (PIL-style resize) — confirm.
        img = img.resize((224, 224))
        ax[i].imshow(img)
        if i == 2:
            # The third image is highlighted with a coloured frame. Its axes
            # must stay switched on for the frame to be drawn, so only the
            # ticks are hidden.
            ax[i].patch.set_edgecolor(color)
            ax[i].patch.set_linewidth(5)
            ax[i].set_xticks([])
            ax[i].set_yticks([])
        else:
            ax[i].axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f'{k}.png'))
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it so memory does not grow with each iteration.
    plt.close(fig)