This repository has been archived by the owner on Aug 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train.py
66 lines (54 loc) · 2.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from modelNet import *
import numpy as np
import tensorflow as tf
# from tensorflow.keras import layers
# from tensorflow.keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D, GaussianNoise
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing import image
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.keras.applications.imagenet_utils import preprocess_input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
# import pydot
# from IPython.display import SVG
from tensorflow.python.keras.utils.vis_utils import model_to_dot
from tensorflow.keras.utils import plot_model
# from tensorflow.resnets_utils import *
from tensorflow.keras.initializers import glorot_uniform
import scipy.misc
from matplotlib.pyplot import imshow
from sklearn.model_selection import train_test_split
import datetime
# %matplotlib inline
import tensorflow.keras.backend as K
# Global Keras configuration, applied once at import time.
# Image tensors use the channels_last layout, i.e. (batch, height, width, channels);
# main() builds the model with input_shape=(64, 64, 3) accordingly.
K.set_image_data_format('channels_last')
# Force the global Keras learning phase to 1 (training mode).
# NOTE(review): set_learning_phase is deprecated in TF 2.x and removed in newer
# Keras releases; Model.fit passes `training` explicitly to its train/test steps,
# so this call is likely redundant here. Confirm against the TF version in use
# (and how modelNet.ResNet50 reads the learning phase) before removing it.
K.set_learning_phase(1)
# Number of target classes; must agree with the `classes` argument of the model
# and with the width of the one-hot label arrays.
NUM_CLASSES = 5
# Input image shape, channels_last: (height, width, channels).
INPUT_SHAPE = (64, 64, 3)


def scale_images(images):
    """Scale raw pixel values into roughly [-1, 1) as float32.

    Applies the same ``x / 128 - 1`` mapping the script has always used, so
    0 -> -1.0 and 255 -> 0.9921875 (not exactly 1.0). The scaling is kept as-is
    because any inference code that mirrors it would otherwise disagree with
    the trained checkpoint.

    The array is cast to float32 *before* dividing: the model consumes float32
    anyway, and dividing a uint8 array directly would allocate a float64 copy
    of the whole dataset (twice the memory). For integer pixel values the
    result is bit-identical, since dividing by a power of two is exact.

    Args:
        images: array-like of pixel values, any shape.

    Returns:
        A float32 ``np.ndarray`` of the same shape.
    """
    return np.asarray(images).astype(np.float32) / 128.0 - 1.0


def _load_dataset():
    """Load ./data/x.npy and ./data/y.npy and return a scaled 90/10 split.

    Returns:
        (x_train, x_test, y_train, y_test). The images are scaled with
        `scale_images`, and the labels are one-hot encoded with exactly
        NUM_CLASSES columns.
    """
    x = np.load('./data/x.npy')
    y = np.load('./data/y.npy')
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.1, random_state=400)
    x_train = scale_images(x_train)
    x_test = scale_images(x_test)
    # Pin num_classes: without it, to_categorical infers the width from
    # max(label) + 1 separately for each split. A split that happens to lack
    # the highest class would then get fewer columns than the model's outputs.
    y_train = to_categorical(y_train, num_classes=NUM_CLASSES)
    y_test = to_categorical(y_test, num_classes=NUM_CLASSES)
    return x_train, x_test, y_train, y_test


def _make_callbacks():
    """Build the training callbacks: best-model checkpointing and TensorBoard logging."""
    # Saves the full model (not only weights) whenever validation accuracy improves.
    # The 'val_acc' key exists because main() compiles with metrics=['acc'].
    checkpoint = ModelCheckpoint("./models/model.h5",
                                 save_best_only=True, monitor='val_acc', mode='max')
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
    return [checkpoint, tensorboard_callback]


def main():
    """Train ResNet50 on ./data and keep the checkpoint with the best val accuracy.

    Side effects:
        Reads ./data/x.npy and ./data/y.npy.
        Writes the best model to ./models/model.h5.
        Writes TensorBoard logs under logs/fit/<timestamp>.
    """
    # ResNet50 here comes from the project's modelNet module (star import),
    # not from tf.keras.applications.
    model = ResNet50(input_shape=INPUT_SHAPE, classes=NUM_CLASSES)
    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # optimizers reject.
    opt = Adam(learning_rate=0.0001)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

    x_train, x_test, y_train, y_test = _load_dataset()

    model.fit(x_train, y_train, epochs=100, batch_size=32,
              validation_data=(x_test, y_test), callbacks=_make_callbacks())


if __name__ == '__main__':
    main()