-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcallbacks.py
41 lines (31 loc) · 1.63 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, Callback, ReduceLROnPlateau
from sklearn.metrics import roc_auc_score
class RocAucEvaluation(Callback):
"""
Считает рок_аук из сета который получается при автоматическом разделении модели,
либо из заданного сета
validation_data - tuple or list в котором 2 элемента, первый это X а второй y
"""
def __init__(self, validation_data=[], interval=1):
super(Callback, self).__init__()
self.interval = interval
if validation_data:
self.validation_data = validation_data
def on_epoch_end(self, epoch, logs={}):
if epoch % self.interval == 0:
y_pred = self.model.predict(self.validation_data[0], verbose=0)
score = roc_auc_score(self.validation_data[1], y_pred, average='micro')
print(f"ROC-AUC micro avg - epoch: {epoch + 1} - score: {score}\n")
def get_callbacks(patience=2, model_path='temp_.hdf5'):
checkpoint = ModelCheckpoint(model_path,
monitor='val_loss',
verbose=1,
save_best_only=True,
mode='auto')
callbacks_list = [
EarlyStopping(monitor="val_loss", mode="min", patience=patience, verbose=1),
RocAucEvaluation(),
CSVLogger('keras.log', separator=',', append=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=patience, min_lr=0.0001)
]
return callbacks_list