-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathe_track_unet.py
145 lines (117 loc) · 5.11 KB
/
e_track_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# J. Akeret, C. Chang, A. Lucchi, A. Refregier,
# Radio frequency interference mitigation using deep convolutional neural networks,
# Astronomy and Computing, Volume 18, 2017, Pages 35-39,
# ISSN 2213-1337, https://doi.org/10.1016/j.ascom.2017.01.002.
# unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with unet. If not, see <http://www.gnu.org/licenses/>.
import logging
import numpy as np
import tensorflow as tf
import unet
import timeit
from tensorflow.keras import backend as K
from unet import custom_objects, utils
from dataset import e_track_dataset
# Network input size in pixels (x, y). The original comments recorded 346 and
# 260 next to these values — presumably the native sensor resolution that the
# data is padded/cropped from; TODO confirm.
img_x_size, img_y_size = 352, 256

# Optimizer learning rate handed to unet.finalize_model().
LEARNING_RATE = 0.001

# Timestamped INFO-level logging for the whole script.
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

# NOTE(review): this seeds NumPy's global RNG only; TensorFlow ops (weight
# initialisation, dataset shuffling) are not affected — consider also calling
# tf.random.set_seed if reproducible training is the goal.
np.random.seed(98765)
def weighted_categorical_crossentropy(weights):
    """Build a class-weighted categorical cross-entropy loss for Keras.

    Args:
        weights: Per-class weights, one entry per class on the last axis of
            the predictions (e.g. ``np.array([0.1, 0.9])`` for two classes).

    Returns:
        A ``loss(y_true, y_pred)`` callable. It renormalises ``y_pred`` along
        the last axis, clips it away from 0 and 1, and returns the negative
        weighted log-likelihood summed over that axis — one value per
        remaining position (e.g. per pixel for H x W x C prediction maps).
    """
    weight_tensor = K.variable(weights)

    # Keep this inner function named ``loss``: Keras records a loss by its
    # __name__, and saved models are reloaded with custom_objects['loss'].
    def loss(y_true, y_pred):
        # Renormalise the class scores so they sum to one on the class axis.
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
        # Clip away from 0 and 1 so the logarithm stays finite.
        probabilities = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        weighted_terms = K.log(probabilities) * y_true * weight_tensor
        return -K.sum(weighted_terms, -1)

    return loss
def train():
    """Build, compile and fit the U-Net on the event-camera TFRecord shards.

    Training, validation and test shards are read from ``data/tfrecord_0``,
    ``data/tfrecord_1`` and ``data/tfrecord_2`` respectively, restricted to
    files whose names contain ``user-<id>`` for ids 4..27.

    Returns:
        The fitted Keras model returned by ``unet.build_model``.

    Raises:
        FileNotFoundError: if a split directory matches no TFRecord file.
    """
    users = range(4, 28)  # ids matched against "user-<id>" in shard file names

    def _tfrecords(split_dir):
        """Return the shards of ``users`` found under *split_dir* as an array."""
        patterns = [f"{split_dir}/*user-{user}*.tfrec" for user in users]
        files = np.array(tf.io.gfile.glob(patterns))
        if files.size == 0:
            # Fail loudly here instead of silently fitting on an empty dataset.
            raise FileNotFoundError(
                f"No TFRecord files match {split_dir}/*user-<id>*.tfrec")
        return files

    unet_model = unet.build_model(nx=img_x_size,
                                  ny=img_y_size,
                                  channels=3,
                                  num_classes=2,
                                  layer_depth=3,
                                  filters_root=32,
                                  padding="same")
    unet.finalize_model(unet_model,
                        # Class weights [0.1, 0.9] up-weight class 1; presumably
                        # the rarer foreground class — TODO confirm label order.
                        loss=weighted_categorical_crossentropy(np.array([0.1, 0.9])),
                        auc=False,
                        epsilon=0.000001,
                        learning_rate=LEARNING_RATE)
    unet_model.summary()

    # Early stopping must watch a *validation* metric: the training-set
    # "mean_iou" keeps improving while the model over-fits, and
    # restore_best_weights would then restore the best *training* weights.
    # Keras logs the validation counterpart of "mean_iou" as "val_mean_iou"
    # (assumes unet.Trainer.fit forwards valid_dataset as validation data).
    callback = tf.keras.callbacks.EarlyStopping(
        monitor="val_mean_iou",
        min_delta=0,
        patience=3,
        verbose=1,
        mode="max",
        baseline=None,
        restore_best_weights=True,
    )
    trainer = unet.Trainer(name="pupil_event", callbacks=[callback])

    train_dataset = e_track_dataset.load_data(_tfrecords("data/tfrecord_0"))
    valid_dataset = e_track_dataset.load_data(_tfrecords("data/tfrecord_1"))
    test_dataset = e_track_dataset.load_data(_tfrecords("data/tfrecord_2"))

    # Sanity-print the structure of one (image, label) element.
    for feat in train_dataset.take(1):
        print(f"Image shape: {feat[0].shape}")
        print(f"Label shape: {feat[1].shape}")
        print(f"Label dtype: {feat[1].dtype}")

    print("Start Fit")
    trainer.fit(unet_model,
                train_dataset,
                valid_dataset,
                test_dataset,
                verbose=2,
                epochs=40,
                batch_size=8)
    print("End Fit")
    return unet_model
def predict(model_path='trained_model/2023-01-24T00-11_42', n_timed_runs=4):
    """Load a saved U-Net, evaluate it briefly and time inference on the test split.

    Args:
        model_path: Directory of the saved Keras model to load. The default is
            the checkpoint that was previously hard-coded here.
        n_timed_runs: Number of full prediction passes over the test split
            that are included in the timing measurement (must be >= 1).

    Returns:
        The prediction array produced by the last timed pass.

    Raises:
        ValueError: if ``n_timed_runs`` is smaller than 1.
    """
    if n_timed_runs < 1:
        raise ValueError(f"n_timed_runs must be >= 1, got {n_timed_runs}")

    # Register the custom loss on a *copy*: assigning into the imported
    # ``custom_objects`` dict would mutate shared state of the unet package.
    objects = dict(custom_objects)
    objects['loss'] = weighted_categorical_crossentropy(np.array([0.1, 0.9]))
    unet_model = tf.keras.models.load_model(model_path, custom_objects=objects)
    unet_model.summary()

    users = range(4, 28)  # ids matched against "user-<id>" in shard file names
    test_tfrecs = np.array(tf.io.gfile.glob(
        [f"data/tfrecord_2/*user-{user}*.tfrec" for user in users]))
    test_dataset = e_track_dataset.load_data(test_tfrecs)

    # Counting requires a full pass over the dataset.
    count = sum(1 for _ in test_dataset)

    # Warm-up before timing: one dummy batch, then one full pass over the
    # test split, so one-off setup cost is excluded from the measurement.
    unet_model.predict(tf.zeros((1, img_x_size, img_y_size, 3)))
    unet_model.predict(test_dataset.batch(batch_size=1))
    print(f'***** test_dataset len: {count}')

    results = unet_model.evaluate(test_dataset.take(3).batch(batch_size=1))
    print(results)

    start_time = timeit.default_timer()
    for _ in range(n_timed_runs):
        prediction = unet_model.predict(test_dataset.batch(batch_size=1))
    run_time = timeit.default_timer() - start_time
    print(f'***** time: {run_time}')
    print(f'***** time per run: {run_time / n_timed_runs}')
    print(prediction.shape)
    return prediction
if __name__ == '__main__':
    # Select the mode from the command line instead of (un)commenting code.
    # With no flag the script behaves as before: CPU-pinned prediction.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train the U-Net or run timed prediction with a saved model.")
    parser.add_argument("--train", action="store_true",
                        help="train a new model instead of running prediction")
    args = parser.parse_args()

    if args.train:
        train()
    else:
        with tf.device('/cpu:0'):
            predict()