-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkeras_custom_callbacks.py
73 lines (60 loc) · 2.83 KB
/
keras_custom_callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
File name: keras_custom_callbacks.py
Author: Benjamin Planche
Date created: 14.02.2019
Date last modified: 14:49 14.02.2019
Python Version: 3.6
Copyright = "Copyright (C) 2018-2019 of Packt"
Credits = ["Eliot Andres, Benjamin Planche"]
License = "MIT"
Version = "1.0.0"
Maintainer = "non"
Status = "Prototype" # "Prototype", "Development", or "Production"
"""
#==============================================================================
# Imported Modules
#==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import tensorflow as tf
#==============================================================================
# Class Definitions
#==============================================================================
class SimpleLogCallback(tf.keras.callbacks.Callback):
    """ Keras callback for simple, denser console logs.

    At the end of selected epochs, prints one line gathering the requested metric values
    read from the `logs` dict Keras passes in, e.g.
    "Epoch  0/10: loss = 1.000; val-loss = 2.000" (with ANSI color codes by default).
    """

    def __init__(self, metrics_dict, num_epochs='?', log_frequency=1,
                 metric_string_template='\033[1m[[name]]\033[0m = \033[94m{[[value]]:5.3f}\033[0m'):
        """
        Initialize the Callback.
        :param metrics_dict:  Dictionary (or iterable of key/value pairs) mapping the display names
                              of the metrics to their keys in the Keras `logs` dict,
                              e.g. {"accuracy": "acc", "val. accuracy": "val_acc"}
        :param num_epochs:    Number of training epochs; used for display, and to always log the
                              last epoch (the default '?' only affects display)
        :param log_frequency: Log frequency (in epochs)
        :param metric_string_template: (opt.) String template to print each metric. "[[name]]" is
                              replaced by the metric display name, and "[[value]]" by the positional
                              index of its value in the final `str.format()` call.
        """
        super().__init__()
        self.metrics_dict = collections.OrderedDict(metrics_dict)
        self.num_epochs = num_epochs
        self.log_frequency = log_frequency

        # We build a format string to later print the metrics. Positional fields 0 and 1 hold
        # the epoch index and the number of epochs, so metric values start at field 2
        # (e.g. "Epoch  0/9: loss = 1.000; val-loss = 2.000").
        metric_templates = []
        for value_index, metric_name in enumerate(self.metrics_dict, start=2):
            # Double any braces in the name, so that str.format() prints them literally:
            safe_name = str(metric_name).replace('{', '{{').replace('}', '}}')
            # Substitute "[[value]]" before "[[name]]", so that a name which happens to
            # contain the "[[value]]" marker is not altered:
            templ = metric_string_template.replace('[[value]]', str(value_index))
            metric_templates.append(templ.replace('[[name]]', safe_name))

        log_string_template = 'Epoch {0:2}/{1}'
        if metric_templates:
            log_string_template += ': ' + '; '.join(metric_templates)
        self.log_string_template = log_string_template

    def on_train_begin(self, logs=None):
        """Print a colored "start" message when training begins."""
        print("Training: \033[92mstart\033[0m.")

    def on_train_end(self, logs=None):
        """Print a colored "end" message when training ends."""
        print("Training: \033[91mend\033[0m.")

    def on_epoch_end(self, epoch, logs=None):
        """
        Print the requested metric values, if this epoch is one to be logged.
        :param epoch: Index of the epoch which just ended, as passed by Keras (0-based)
        :param logs:  Dict of metric results for this epoch, as passed by Keras
        :raises KeyError: if a key of `metrics_dict` values is missing from `logs`
        """
        # Avoid a shared mutable default argument:
        if logs is None:
            logs = {}
        # Keras epoch indices are 0-based: log the first epoch, then every `log_frequency`
        # epochs, and always the last one (when `num_epochs` is a known number).
        is_last_epoch = epoch + 1 == self.num_epochs
        if epoch % self.log_frequency == 0 or is_last_epoch:
            values = [logs[log_key] for log_key in self.metrics_dict.values()]
            print(self.log_string_template.format(epoch, self.num_epochs, *values))