-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathEvaluateCIFAR.py
51 lines (45 loc) · 1.72 KB
/
EvaluateCIFAR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Libraries
import numpy as np
import pickle
from skimage.color import rgb2gray
# Other python files
import constants as k
from evaluation_metrics import *
import experiments as exp
import dataset as ds
from NetworkTF import *
import GIST as cust_gist # You have to install FFTW and lear-gist first before you can import
# Documentation here: https://github.com/tuttieee/lear-gist-python
def convertToGIST(images):
    """Convert a batch of RGB images into GIST descriptor vectors.

    Each image is converted to grayscale with ``skimage.color.rgb2gray``
    before extraction. With the parameters below (4 scales x 8 orientations
    = 32 filters, pooled over a 4x4 grid of blocks) each image yields a
    32 * 16 = 512-dimensional descriptor.

    Relies on the ``GIST`` wrapper from lear-gist-python
    (https://github.com/tuttieee/lear-gist-python), which requires FFTW and
    lear-gist to be installed.

    Args:
        images: Iterable of RGB images. The original notes describe them as
            32x32x3 numpy arrays (CIFAR); other sizes are not exercised here.

    Returns:
        numpy.ndarray stacking one GIST vector per image (presumably shape
        ``(len(images), 512)`` if each extraction returns a flat vector --
        TODO confirm against the wrapper).
    """
    param = {
        "orientationsPerScale": np.array([8, 8, 8, 8]),
        "numberBlocks": [4, 4],
        "fc_prefilt": 4,
        "boundaryExtension": 10,
    }
    gist = cust_gist.GIST(param)
    # NOTE: _gist_extract is a private method of the wrapper; kept as in the
    # original implementation.
    gist_vectors = [gist._gist_extract(rgb2gray(img)) for img in images]
    # np.asarray instead of np.array(..., copy=False): building an array from
    # a Python list always requires a copy, and NumPy >= 2.0 raises ValueError
    # when copy=False cannot be honoured.
    return np.asarray(gist_vectors)
def evaluateCIFAR():
    """Run the supervised CIFAR experiment on GIST features.

    Steps:
      1. Load the per-class data dictionary from ``cifar_classdict.pkl`` if
         it exists. Otherwise load raw images/labels from
         ``cifar_numpy.npz``, convert the images to GIST vectors, group them
         with ``ds.break_classes`` and cache the result to
         ``cifar_classdict.pkl``.
      2. Build experiment inputs with ``ds.generateInputs(..., "supervised")``.
      3. Run ``exp.runExperiment("cifar_model", "supervised", inputs)``.

    Side effects: reads/writes files relative to the current working
    directory and prints progress messages.
    """
    # Preprocess CIFAR data into standardized form (classDict), reusing a
    # cached pickle when one is available.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this is acceptable only because the cache is produced by this
    # script itself -- never point it at an untrusted file.
    try:
        with open("cifar_classdict.pkl", "rb") as cache:
            classDict = pickle.load(cache)
    except FileNotFoundError:
        print("First breaking classes since no existing dump was found")
        images, labels = ds.load_ds_raw('cifar_numpy.npz')
        gistData = np.asarray(convertToGIST(images))
        classDict = ds.break_classes(gistData, labels)
        # `with` guarantees the pickle is fully flushed and the handle closed,
        # so a later run never reads a truncated cache.
        with open("cifar_classdict.pkl", "wb") as cache:
            pickle.dump(classDict, cache)
    else:
        print("Classdict dump found, loading previous")

    # Generate inputs (supervised is a superset of unsupervised + +/- pairs)
    print("Generating Inputs")
    inputs = ds.generateInputs(classDict, "supervised")

    # Run the supervised experiment; the progress message now matches the
    # mode that is actually executed.
    print("Running supervised experiment")
    exp.runExperiment("cifar_model", "supervised", inputs)
if __name__ == "__main__":
    # Script entry point: run the full CIFAR evaluation when executed
    # directly (not when imported as a module).
    evaluateCIFAR()