Skip to content

Commit

Permalink
Fix predict
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcCha committed May 5, 2018
1 parent 95ce3fa commit 51d2576
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 31 deletions.
10 changes: 10 additions & 0 deletions cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch


def get_cuda_if_available(device_index=0, min_capability_major=4):
    """Pick the CUDA device when it is usable, otherwise fall back to the CPU.

    Args:
        device_index: Index of the CUDA device to probe and return.
        min_capability_major: Smallest major compute capability the device
            must report (first element of ``get_device_capability``) to be
            accepted.

    Returns:
        A ``(cuda_enabled, device)`` tuple. ``cuda_enabled`` is ``True`` only
        when CUDA is available and the probed device meets the capability
        requirement; in every other case it is ``False`` and ``device`` is
        ``torch.device('cpu')``.
    """
    if torch.cuda.is_available():
        major = torch.cuda.get_device_capability(device_index)[0]
        if major >= min_capability_major:
            return (True, torch.device('cuda', device_index))
    # Either CUDA is missing or the device is below the required capability.
    return (False, torch.device('cpu'))
22 changes: 15 additions & 7 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,44 @@


class DigitRecognizerDataset(Dataset):
    """Dataset of digit images, optionally paired with labels.

    Each raw sample is converted with ``Image.fromarray``, resized to 32x32
    and turned into a tensor. In test mode no labels are stored and
    ``__getitem__`` yields only the image.
    """

    def __init__(self, X, Y, pretransform=False, test=False):
        """Store the samples and build the image transform.

        Args:
            X: Sequence of image arrays accepted by ``Image.fromarray``.
            Y: Sequence of labels aligned with ``X``; ignored when ``test``
                is true.
            pretransform: When true, every image is transformed once here
                instead of on each ``__getitem__`` call.
            test: When true, the dataset is label-free.
        """
        self.test = test
        self.pretransform = pretransform
        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor()])
        if self.pretransform:
            X = list(map(self.transform, map(Image.fromarray, X)))
        self.X = X
        # Keep the attribute defined in test mode too, so it can always be
        # inspected without an AttributeError.
        self.Y = None if test else Y
        self.len = len(X)

    def __getitem__(self, i):
        """Return ``(image, label)`` for sample ``i``, or just the image in test mode."""
        x = self.X[i]
        if not self.pretransform:
            # Samples were stored raw; convert them lazily on access.
            x = Image.fromarray(x)
            x = self.transform(x)
        if self.test:
            return x
        return (x, int(self.Y[i]))

    def __len__(self):
        """Return the number of samples."""
        return self.len


def train_validation_split(csv_file, max_rows=None, validation_num=None):
def train_validation_split(csv_file, max_rows=None, validation_num=None, test=False):
if max_rows:
data = np.genfromtxt(csv_file, delimiter=',',
skip_header=1, max_rows=max_rows)
else:
data = np.genfromtxt(csv_file, delimiter=',', skip_header=1)
Y, X = np.split(data, [1], axis=1)
if not test:
Y, X = np.split(data, [1], axis=1)
else:
X = data
Y = None
X = X.reshape(-1, 28, 28)
if validation_num:
idx = np.random.randint(0, len(X), validation_num)
Expand All @@ -46,5 +54,5 @@ def train_validation_split(csv_file, max_rows=None, validation_num=None):
validation_dataset = DigitRecognizerDataset(V, VY)
else:
validation_dataset = None
train_dataset = DigitRecognizerDataset(X, Y)
train_dataset = DigitRecognizerDataset(X, Y, test=test)
return (train_dataset, validation_dataset)
41 changes: 25 additions & 16 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
from pathlib import Path

import numpy as np
import torch
from net import Net
from PIL import Image
from torchvision import transforms
from cuda import *
from data import *
from net import *
from torch.utils.data import DataLoader
from tqdm import tqdm

# Prediction script: load the trained network, label every test image and
# write the results as 'ImageId,Label' rows.
test_path = '/home/arccha/.kaggle/competitions/digit-recognizer/test.csv'

# test=True makes the dataset label-free, so the loader yields image batches only.
test_dataset, _ = train_validation_split(test_path, test=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=5,
                         shuffle=False, num_workers=1, pin_memory=True)

_, device = get_cuda_if_available()
net = SimpleCNN()
# Weights and results live in a directory named after the network class.
net_dir = Path('./' + type(net).__name__)
net.to(device)
net_state_path = net_dir.joinpath('net.state')
with net_state_path.open(mode='rb') as f:
    net.load_state_dict(torch.load(f))
# Inference only: switch layers such as dropout/batch-norm to evaluation mode.
net.eval()

result_path = net_dir.joinpath('result.txt')
with result_path.open('w') as f:
    f.write('ImageId,Label\n')
    image_id = 1  # ImageId is 1-based and counts images, not batches.
    with torch.no_grad():
        for x in tqdm(test_loader):
            x = x.to(device)
            # One label per sample. A bare .argmax() would reduce over the
            # whole flattened batch and yield a single, meaningless index.
            # NOTE(review): assumes the net output is (batch, classes) - confirm.
            labels = net(x).argmax(dim=1)
            for label in labels.tolist():
                f.write('{},{}\n'.format(image_id, int(label)))
                image_id += 1
12 changes: 4 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from cuda import *
from data import *
from net import *
from torch.autograd import Variable
Expand All @@ -15,15 +16,10 @@

H = {}  # Training history and statistics
USE_CUDA = True

# Device selection is delegated to get_cuda_if_available(); setting
# USE_CUDA to False forces the CPU without probing CUDA at all.
if USE_CUDA:
    CUDA, device = get_cuda_if_available()
else:
    CUDA = False
    device = torch.device('cpu')
H['cuda'] = CUDA  # Record whether this run actually uses CUDA.

Expand Down

0 comments on commit 51d2576

Please sign in to comment.