-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
35 lines (29 loc) · 1012 Bytes
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path
from committee import net_dirs, nets
from cuda import *
from data import *
from net import *
from torch.utils.data import DataLoader
# Evaluate the committee of networks (`nets`, with weights stored under the
# matching `net_dirs`) on the validation split by majority vote, and print the
# fraction of samples classified correctly.
#
# NOTE(review): `yaml`, `torch`, `np`, `tqdm`, `get_cuda_if_available` and
# `train_validation_split` are not imported explicitly in this file; they are
# presumably supplied by the star imports (cuda / data / net) — confirm.

config_path = Path('train_config.yaml')
with config_path.open('r') as f:
    # safe_load: bare `yaml.load(f)` without a Loader is deprecated since
    # PyYAML 5.1, is a TypeError in PyYAML 6, and can build arbitrary objects.
    config = yaml.safe_load(f)

_, device = get_cuda_if_available()

# NOTE(review): the FIRST value returned by train_validation_split is used as
# the validation set; verify this matches that function's return order
# (helpers with this name often return (train, validation)).
validation_dataset, _ = train_validation_split(
    config['train_path'], pretransform=True)
# batch_size=1 is relied on below: each vote and label is read with .item().
validation_loader = DataLoader(dataset=validation_dataset, batch_size=1,
                               shuffle=False, num_workers=1, pin_memory=True)

# Load every committee member's weights from <net_dir>/net.state.
for model, net_dir in zip(nets, net_dirs):
    model.eval()
    model.to(device)
    net_state_path = net_dir.joinpath('net.state')
    with net_state_path.open(mode='rb') as f:
        # map_location lets weights saved on a GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(f, map_location=device))

correct = 0
total = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for x, y_true in tqdm(validation_loader):
        x = x.to(device)
        # One vote per member; .item() turns each 1-element (possibly CUDA)
        # tensor into a Python int that np.bincount can count.
        votes = [int(model(x).argmax(dim=1).item()) for model in nets]
        # Majority vote; on a tie np.argmax returns the smallest class index.
        y_pred = int(np.argmax(np.bincount(votes)))
        if int(y_true.item()) == y_pred:
            correct += 1
        total += 1

# Divide by the number of samples actually evaluated instead of a hard-coded
# dataset size; max(..., 1) keeps an empty loader from dividing by zero.
print(correct / max(total, 1))