forked from Markus-Goetz/block-prediction
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
executable file
·65 lines (48 loc) · 1.61 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import tensorflow as tf
import h5py
import numpy as np
def parse_cli(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the original behaviour; passing a list allows programmatic use
            and testing.

    Returns:
        An ``argparse.Namespace`` with two attributes:
          * ``model_file`` -- path of the saved model (``-m/--model``,
            default ``./latest_model.keras``);
          * ``predict_file`` -- path of the HDF5 file holding the input data
            (``-f/--file``, default ``artificial.h5``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-m', '--model',
        metavar='MODEL',
        type=str,
        default='./latest_model.keras',
        dest='model_file',
        help='path to the model checkpoint file',
    )
    parser.add_argument(
        '-f', '--file',
        metavar='FILE',
        type=str,
        default='artificial.h5',
        dest='predict_file',
        help='path to the HDF5 file with the prediction data',
    )
    return parser.parse_args(argv)
def load_data(path):
    """Read the input and label arrays from the HDF5 file at *path*.

    The ``diagonalset`` and ``vectorset`` datasets are copied fully into
    in-memory NumPy arrays before the file is closed.

    Returns:
        A ``(data, labels)`` tuple: ``diagonalset`` first, ``vectorset`` second.
    """
    dataset_names = ('diagonalset', 'vectorset')
    with h5py.File(path, 'r') as h5_file:
        arrays = tuple(np.array(h5_file[name]) for name in dataset_names)
    return arrays
def load_model(model_file):
    """Load and return the Keras model saved at *model_file*.

    Delegates directly to ``tf.keras.models.load_model``.
    """
    keras_loader = tf.keras.models.load_model
    model = keras_loader(model_file)
    return model
def preprocess(data, labels):
    """Bring the raw arrays into the layout used for inference.

    * ``data`` gets a new size-1 axis inserted at position 3 to act as the
      channel dimension (for rank-3 input of shape ``(a, b, c)`` the result
      is ``(a, b, c, 1)``).
    * ``labels`` has its first axis moved to the last position (for a 2-D
      array this is a transpose).

    Returns:
        The transformed ``(data, labels)`` pair.
    """
    channelled = np.expand_dims(data, 3)
    reordered = np.moveaxis(labels, source=0, destination=-1)
    return channelled, reordered
def predict(data, model, *, batch_size=1):
    """Run *model* over *data* and return its predictions.

    Args:
        data: Model input, forwarded unchanged to ``model.predict``.
        model: Object with a Keras-style
            ``predict(x, batch_size=..., verbose=...)`` method.
        batch_size: Number of samples per inference batch. Defaults to 1,
            the value that used to be hard-coded, so existing callers are
            unaffected; larger values generally trade memory for speed.

    Returns:
        Whatever ``model.predict`` returns.
    """
    # verbose=True is kept from the original so progress is still printed.
    return model.predict(data, batch_size=batch_size, verbose=True)
def store(prediction, path):
    """Write *prediction* to the ``predictionset`` dataset of the HDF5 file at *path*.

    The file is opened in ``'r+'`` (read/write, must already exist) mode.

    * If a dataset of the same shape and dtype already exists, it is
      overwritten in place.
    * Otherwise any existing ``predictionset`` object is removed and a new
      dataset is created.

    Why: HDF5 does not reclaim the space of deleted objects, so
    unconditionally deleting and recreating the dataset makes the file grow
    on every run.

    The prediction is converted to an array *before* the file is touched, so
    an unconvertible input no longer deletes the previous predictions.
    """
    prediction_dataset = 'predictionset'
    prediction = np.asarray(prediction)
    with h5py.File(path, 'r+') as handle:
        existing = handle.get(prediction_dataset)
        if (
            isinstance(existing, h5py.Dataset)
            and existing.shape == prediction.shape
            and existing.dtype == prediction.dtype
        ):
            # Same layout: reuse the existing storage instead of leaking it.
            existing[...] = prediction
            return
        if prediction_dataset in handle:
            del handle[prediction_dataset]
        handle[prediction_dataset] = prediction
if __name__ == '__main__':
    # Load the inputs, run the saved model over them and write the
    # predictions back into the same HDF5 file the inputs came from.
    cli_args = parse_cli()
    raw_data, raw_labels = load_data(cli_args.predict_file)
    inputs, _ = preprocess(raw_data, raw_labels)  # preprocessed labels are unused here
    network = load_model(cli_args.model_file)
    store(predict(inputs, network), cli_args.predict_file)