-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
148 lines (130 loc) · 4.58 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import random
import numpy as np
import os
import struct
import gzip
from PIL import Image
def set_seeds(seed: int = 42) -> None:
    """Seed every random number generator this module knows about.

    Applies the same seed to Python's ``random`` module, NumPy, and PyTorch
    (CPU generator, current CUDA device and all CUDA devices), exports it as
    ``PYTHONHASHSEED`` and configures cuDNN for deterministic behaviour.

    Parameters
    ----------
    seed : int
        Value used to seed all generators. Defaults to 42.
    """
    # Environment values must be strings, hence the conversion.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, the current CUDA device, then every device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def mmd(x, y, gammas, device):
    """Compute a non-negative kernel discrepancy between two sample sets.

    The value is ``mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))`` where
    ``K`` is the kernel evaluated by ``gram_matrix``. Negative results are
    clamped to zero.

    Parameters
    ----------
    x, y : torch.Tensor
        Sample sets; they are passed unchanged to ``gram_matrix`` (assumed to
        be 2-D ``(n_samples, n_features)`` tensors on ``device`` -- confirm
        against callers).
    gammas : torch.Tensor
        Kernel bandwidth coefficients; moved to ``device`` before use.
    device : torch.device or str
        Device on which the result is returned.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the clamped discrepancy. It keeps the floating
        dtype of the kernel values and stays attached to the autograd graph.
    """
    gammas = gammas.to(device)

    k_xx = torch.mean(gram_matrix(x, x, gammas=gammas))
    k_yy = torch.mean(gram_matrix(y, y, gammas=gammas))
    k_xy = torch.mean(gram_matrix(x, y, gammas=gammas))
    cost = (k_xx + k_yy - 2 * k_xy).to(device)

    # Clamp instead of returning a fresh ``torch.tensor(0)``: the latter is an
    # int64 tensor with no gradient history, so the return dtype would depend
    # on the data and calling ``backward()`` on it would fail.
    return torch.clamp(cost, min=0)
def gram_matrix(x, y, gammas):
    """Evaluate a sum of Gaussian kernels between the rows of ``x`` and ``y``.

    For every pair ``(x_i, y_j)`` the result is
    ``sum_g exp(-gammas[g] * ||x_i - y_j||^2)``.

    Parameters
    ----------
    x, y : torch.Tensor
        Inputs accepted by ``torch.cdist``.
    gammas : torch.Tensor
        Kernel coefficients (assumed 1-D; it is turned into a column vector
        below -- confirm against callers).

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as ``torch.cdist(x, y)``.
    """
    sq_dists = torch.square(torch.cdist(x, y, p=2.0))

    # Outer product: one row of scaled squared distances per gamma.
    scaled = torch.matmul(gammas.unsqueeze(1), sq_dists.reshape(1, -1))

    # Sum the per-gamma kernels, then restore the pairwise layout.
    kernel_sum = torch.exp(-scaled).sum(0)
    return kernel_sum.reshape(sq_dists.shape)
def resize_and_crop(img, size=(100, 100), crop_type='middle'):
    """Scale an image to cover ``size`` and crop the overflow.

    The image is resized, preserving its aspect ratio, so that it fully
    covers the target size; the excess along the longer axis is then
    cropped away. When the aspect ratios already match, the image is only
    resized.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to transform.
    size : sequence of two ints
        Target ``(width, height)`` in pixels. Defaults to ``(100, 100)``.
    crop_type : {'top', 'middle', 'bottom'}
        Which part to keep when cropping. For a vertical crop these are the
        top, centre and bottom; for a horizontal crop they map to the left,
        centre and right.

    Returns
    -------
    PIL.Image.Image
        The resized (and possibly cropped) image.

    Raises
    ------
    ValueError
        If cropping is required and ``crop_type`` is not one of the accepted
        values.
    """
    # ``Image.ANTIALIAS`` was an alias of LANCZOS and was removed in
    # Pillow 10; ``Image.LANCZOS`` selects the same filter.
    resample = Image.LANCZOS

    # Current and desired aspect ratios (width / height).
    img_ratio = img.size[0] / float(img.size[1])
    ratio = size[0] / float(size[1])

    if ratio > img_ratio:
        # Image is relatively taller than the target: match the width and
        # crop vertically.
        img = img.resize(
            (size[0], int(round(size[0] * img.size[1] / img.size[0]))),
            resample)
        if crop_type == 'top':
            box = (0, 0, img.size[0], size[1])
        elif crop_type == 'middle':
            # Derive the lower edge from the upper one so the crop is always
            # exactly size[1] tall; rounding both edges independently can be
            # off by one pixel.
            upper = int(round((img.size[1] - size[1]) / 2))
            box = (0, upper, img.size[0], upper + size[1])
        elif crop_type == 'bottom':
            box = (0, img.size[1] - size[1], img.size[0], img.size[1])
        else:
            raise ValueError('ERROR: invalid value for crop_type')
        img = img.crop(box)
    elif ratio < img_ratio:
        # Image is relatively wider than the target: match the height and
        # crop horizontally.
        img = img.resize(
            (int(round(size[1] * img.size[0] / img.size[1])), size[1]),
            resample)
        if crop_type == 'top':
            box = (0, 0, size[0], img.size[1])
        elif crop_type == 'middle':
            # Same reasoning as above: keep the crop exactly size[0] wide.
            left = int(round((img.size[0] - size[0]) / 2))
            box = (left, 0, left + size[0], img.size[1])
        elif crop_type == 'bottom':
            box = (img.size[0] - size[0], 0, img.size[0], img.size[1])
        else:
            raise ValueError('ERROR: invalid value for crop_type')
        img = img.crop(box)
    else:
        # Aspect ratios match, so a plain resize is enough.
        img = img.resize((size[0], size[1]), resample)
    return img
def _load_uint8(f):
idx_dtype, ndim = struct.unpack('BBBB', f.read(4))[2:]
shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
buffer_length = int(np.prod(shape))
data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
return data
def _save_uint8(data, f):
data = np.asarray(data, dtype=np.uint8)
f.write(struct.pack('BBBB', 0, 0, 0x08, data.ndim))
f.write(struct.pack('>' + 'I' * data.ndim, *data.shape))
f.write(data.tobytes())
def save_idx(data: np.ndarray, path: str):
    """Store an array on disk using the IDX file format.

    Parameters
    ----------
    data : array_like
        Array to write. Expected to be ``uint8``; other dtypes are coerced.
    path : str
        Destination file. A path ending in '.gz' is written through `gzip`,
        anything else is written uncompressed.

    References
    ----------
    http://yann.lecun.com/exdb/mnist/
    """
    # Pick the opener from the file extension.
    if path.endswith('.gz'):
        opener = gzip.open
    else:
        opener = open

    with opener(path, 'wb') as stream:
        _save_uint8(data, stream)
def load_idx(path: str) -> np.ndarray:
    """Load an array stored on disk in the IDX file format.

    Parameters
    ----------
    path : str
        Source file. A path ending in '.gz' is read through `gzip`, anything
        else is read as-is.

    Returns
    -------
    np.ndarray
        The stored array, with dtype ``uint8``.

    References
    ----------
    http://yann.lecun.com/exdb/mnist/
    """
    # Pick the opener from the file extension.
    if path.endswith('.gz'):
        opener = gzip.open
    else:
        opener = open

    with opener(path, 'rb') as stream:
        array = _load_uint8(stream)
    return array