Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcCha committed May 10, 2018
1 parent c5d6928 commit 6b1a5b6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
13 changes: 9 additions & 4 deletions plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""
isort:skip_file
"""
import itertools

import matplotlib
matplotlib.use('Agg')
import numpy as np
from matplotlib import pyplot as plt

Expand All @@ -20,7 +25,7 @@ def plot_confusion_matrix(cm, classes, normalize=False,

if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
print('Normalized confusion matrix')
else:
print('Confusion matrix, without normalization')

Expand All @@ -34,16 +39,16 @@ def plot_confusion_matrix(cm, classes, normalize=False,
yticks = []
for i in (range(cm.shape[0])):
acc = cm[i, i] / np.sum(cm[i])
yticks.append("{} (acc={:.10f})".format(i, acc))
yticks.append('{} (acc={:.10f})'.format(i, acc))

plt.yticks(tick_marks, yticks)

fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
horizontalalignment='center',
color='white' if cm[i, j] > thresh else 'black')

plt.ylabel('True label')
plt.xlabel('Predicted label')
Expand Down
34 changes: 22 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import time

import yaml
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm

from cuda import *
from data import *
from net import *
from plot import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

config_path = Path('train_config.yaml')
config = None
Expand All @@ -20,16 +21,25 @@
H['cuda'] = CUDA

train_path = config['train_path']
H['data_num'] = config['data_num']
H['validation_num'] = config['validation_num']
H['batch_size'] = config['batch_size']
train_dataset, validation_dataset = train_validation_split(
train_path, max_rows=config['data_num'], validation_num=config['validation_num'], pretransform=True)

# train_dataset, validation_dataset = train_validation_split(
# train_path, max_rows=config['data_num'], validation_num=config['validation_num'], pretransform=True)
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.Grayscale(),
transforms.ToTensor()])
train_dataset = ImageFolder('augment/out/', transform=transform)
validation_dataset, _ = train_validation_split(train_path, pretransform=True)

train_loader = DataLoader(dataset=train_dataset,
batch_size=config['batch_size'], shuffle=True, num_workers=4, pin_memory=True)
validation_loader = DataLoader(
dataset=validation_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=1, pin_memory=True)

H['data_num'] = len(train_loader)
H['validation_num'] = len(validation_loader)

validation_classes = np.zeros(10)
for x, y in tqdm(validation_loader, desc='Validation stats'):
idx, counts = np.unique(y, return_counts=True)
Expand Down Expand Up @@ -85,8 +95,8 @@ def is_last_epoch():
y_pred = net(x).argmax(dim=1)
acc += y_true.eq(y_pred).sum()
if is_last_epoch():
true_train += y_true.numpy().tolist()
predicted_train += y_pred.numpy().tolist()
true_train += y_true.to(torch.device('cpu')).numpy().tolist()
predicted_train += y_pred.to(torch.device('cpu')).numpy().tolist()

acc = float(acc) / (len(train_loader) * config['batch_size'])
H['train_acc'].append(acc)
Expand All @@ -97,8 +107,8 @@ def is_last_epoch():
y_pred = net(x).argmax(dim=1)
acc += y_true.eq(y_pred).sum()
if is_last_epoch():
true_test += y_true.numpy().tolist()
predicted_test += y_pred.numpy().tolist()
true_test += y_true.to(torch.device('cpu')).numpy().tolist()
predicted_test += y_pred.to(torch.device('cpu')).numpy().tolist()

acc = float(acc) / config['validation_num']
H['test_acc'].append(acc)
Expand Down

0 comments on commit 6b1a5b6

Please sign in to comment.