Skip to content

Commit

Permalink
Refactor committee architecture to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcCha committed May 11, 2018
1 parent bff5b16 commit b6dd81c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
5 changes: 5 additions & 0 deletions committee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from net import *

nets = [CNN() for _ in range(35)]
net_dirs = [Path('./' + type(net).__name__ + str(i))
for i, net in enumerate(nets)]
4 changes: 1 addition & 3 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from committee import net_dirs, nets
from cuda import *
from data import *
from net import *
Expand All @@ -14,9 +15,6 @@
shuffle=False, num_workers=1, pin_memory=True)

_, device = get_cuda_if_available()
nets = [CNN(), CNN(), CNN()]
net_dirs = [Path('./' + type(net).__name__ + str(i))
for i, net in enumerate(nets)]

for net, net_dir in zip(nets, net_dirs):
net.eval()
Expand Down
4 changes: 1 addition & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time

import yaml
from committee import net_dirs, nets
from cuda import *
from data import *
from net import *
Expand Down Expand Up @@ -46,9 +47,6 @@
validation_classes[idx] += counts
H['validation_classes'] = validation_classes.tolist()

nets = [CNN(), CNN(), CNN()]
net_dirs = [Path('./' + type(net).__name__ + str(i))
for i, net in enumerate(nets)]
for net_dir in net_dirs:
net_dir.mkdir(parents=True, exist_ok=True)

Expand Down
5 changes: 2 additions & 3 deletions validate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

from committee import net_dirs, nets
from cuda import *
from data import *
from net import *
Expand All @@ -16,9 +17,7 @@
config['train_path'], pretransform=True)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=1,
shuffle=False, num_workers=1, pin_memory=True)
nets = [CNN(), CNN(), CNN()]
net_dirs = [Path('./' + type(net).__name__ + str(i))
for i, net in enumerate(nets)]

for net, net_dir in zip(nets, net_dirs):
net.eval()
net.to(device)
Expand Down

0 comments on commit b6dd81c

Please sign in to comment.