diff --git a/precise/functions.py b/precise/functions.py index be59596a..7bcf69a8 100644 --- a/precise/functions.py +++ b/precise/functions.py @@ -33,7 +33,7 @@ def weighted_log_loss(yt, yp) -> Any: yt: Target yp: Prediction """ - from keras import backend as K + import tensorflow.keras.backend as K pos_loss = -(0 + yt) * K.log(0 + yp + K.epsilon()) neg_loss = -(1 - yt) * K.log(1 - yp + K.epsilon()) @@ -42,7 +42,7 @@ def weighted_log_loss(yt, yp) -> Any: def weighted_mse_loss(yt, yp) -> Any: - from keras import backend as K + import tensorflow.keras.backend as K total = K.sum(K.ones_like(yt)) neg_loss = total * K.sum(K.square(yp * (1 - yt))) / K.sum(1 - yt) @@ -52,12 +52,12 @@ def weighted_mse_loss(yt, yp) -> Any: def false_pos(yt, yp) -> Any: - from keras import backend as K + import tensorflow.keras.backend as K return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(1 - yt)) def false_neg(yt, yp) -> Any: - from keras import backend as K + import tensorflow.keras.backend as K return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(0 + yt)) diff --git a/precise/model.py b/precise/model.py index 6ee90ca6..63a1df74 100644 --- a/precise/model.py +++ b/precise/model.py @@ -19,7 +19,7 @@ from precise.params import inject_params, pr if TYPE_CHECKING: - from keras.models import Sequential + from tensorflow.keras.models import Sequential @attr.s() @@ -45,7 +45,8 @@ def load_precise_model(model_name: str) -> Any: print('Warning: Unknown model type, ', model_name) inject_params(model_name) - return load_keras().models.load_model(model_name) + from tensorflow.keras.models import load_model + return load_model(model_name, custom_objects=globals()) def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential': @@ -63,9 +64,8 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential' print('Loading from ' + model_name + '...') model = load_precise_model(model_name) else: - from keras.layers.core import Dense - from keras.layers.recurrent import GRU - from keras.models import Sequential + from tensorflow.keras.layers import Dense, GRU + from tensorflow.keras.models import Sequential model = Sequential() model.add(GRU( @@ -74,7 +74,6 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential' )) model.add(Dense(1, activation='sigmoid')) - load_keras() metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg] set_loss_bias(params.loss_bias) for i in model.layers[:params.freeze_till]: diff --git a/precise/network_runner.py b/precise/network_runner.py index 2283b9db..5fe7b624 100644 --- a/precise/network_runner.py +++ b/precise/network_runner.py @@ -41,15 +41,15 @@ def __init__(self, model_name: str): print('Warning: ', model_name, 'looks like a Keras model.') self.tf = import_module('tensorflow') self.graph = self.load_graph(model_name) + with self.graph.as_default(): + self.inp_var = self.graph.get_operation_by_name('import/net_input').outputs[0] + self.out_var = self.graph.get_operation_by_name('import/net_output').outputs[0] - self.inp_var = self.graph.get_operation_by_name('import/net_input').outputs[0] - self.out_var = self.graph.get_operation_by_name('import/net_output').outputs[0] - - self.sess = self.tf.Session(graph=self.graph) + self.sess = self.tf.compat.v1.Session(graph=self.graph) def load_graph(self, model_file: str) -> 'tf.Graph': graph = self.tf.Graph() - graph_def = self.tf.GraphDef() + graph_def = self.tf.compat.v1.GraphDef() with open(model_file, "rb") as f: graph_def.ParseFromString(f.read()) @@ -68,24 +68,46 @@ def run(self, inp: np.ndarray) -> float: class KerasRunner(Runner): def __init__(self, model_name: str): - import tensorflow as tf - # ISSUE 88 - Following 3 lines added to resolve issue 88 - JM 2020-02-04 per liny90626 - from tensorflow.python.keras.backend import set_session # ISSUE 88 - self.sess = tf.Session() # ISSUE 88 - set_session(self.sess) # ISSUE 88 self.model = load_precise_model(model_name) - self.graph = tf.get_default_graph() def predict(self, inputs: np.ndarray): - from tensorflow.python.keras.backend import set_session # ISSUE 88 - with self.graph.as_default(): - set_session(self.sess) # ISSUE 88 - return self.model.predict(inputs) + return self.model.predict(inputs) def run(self, inp: np.ndarray) -> float: return self.predict(inp[np.newaxis])[0][0] +class TFLiteRunner(Runner): + def __init__(self, model_name: str): + import tensorflow as tf + # Setup tflite environment + self.interpreter = tf.lite.Interpreter(model_path=model_name) + self.interpreter.allocate_tensors() + + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + def predict(self, inputs: np.ndarray): + # Format output to match Keras's model.predict output + count = 0 + output_data = np.ndarray((inputs.shape[0],1), dtype=np.float32) + + # Support for multiple inputs + for input in inputs: + # Format as float32. Add a wrapper dimension. + current = np.array([input]).astype(np.float32) + + # Load data, run inference and extract output from tensor + self.interpreter.set_tensor(self.input_details[0]['index'], current) + self.interpreter.invoke() + output_data[count] = self.interpreter.get_tensor(self.output_details[0]['index']) + count += 1 + + return output_data + + def run(self, inp: np.ndarray) -> float: + return self.predict(inp[np.newaxis])[0][0] + class Listener: """Listener that preprocesses audio into MFCC vectors and executes neural networks""" @@ -102,7 +124,8 @@ def __init__(self, model_name: str, chunk_size: int = -1, runner_cls: type = Non def find_runner(model_name: str) -> Type[Runner]: runners = { '.net': KerasRunner, - '.pb': TensorFlowRunner + '.pb': TensorFlowRunner, + '.tflite': TFLiteRunner } ext = splitext(model_name)[-1] if ext not in runners: diff --git a/precise/scripts/add_noise.py b/precise/scripts/add_noise.py old mode 100755 new mode 100644 diff --git a/precise/scripts/convert.py b/precise/scripts/convert.py old mode 100755 new mode 100644 index d3bacb43..cece7b2b --- a/precise/scripts/convert.py +++ b/precise/scripts/convert.py @@ -20,7 +20,6 @@ from precise.scripts.base_script import BaseScript - class ConvertScript(BaseScript): usage = Usage(''' Convert wake word model from Keras to TensorFlow @@ -28,8 +27,8 @@ class ConvertScript(BaseScript): :model str Input Keras model (.net) - :-o --out str {model}.pb - Custom output TensorFlow protobuf filename + :-o --out str {model}.tflite + Custom output TensorFlow Lite filename ''') def run(self): @@ -39,49 +38,36 @@ def run(self): def convert(self, model_path: str, out_file: str): """ - Converts an HD5F file from Keras to a .pb for use with TensorFlow + Converts an HD5F file from Keras to a .tflite for use with TensorFlow Runtime Args: model_path: location of Keras model - out_file: location to write protobuf + out_file: location to write TFLite model """ print('Converting', model_path, 'to', out_file, '...') - import tensorflow as tf + import tensorflow as tf # Using tensorflow v2.2 + from tensorflow import keras as K from precise.model import load_precise_model - from keras import backend as K + from precise.functions import weighted_log_loss out_dir, filename = split(out_file) out_dir = out_dir or '.' os.makedirs(out_dir, exist_ok=True) - K.set_learning_phase(0) - model = load_precise_model(model_path) - - out_name = 'net_output' - tf.identity(model.output, name=out_name) - print('Output node name:', out_name) - print('Output folder:', out_dir) - - sess = K.get_session() - - # Write the graph in human readable - tf.train.write_graph(sess.graph.as_graph_def(), out_dir, filename + 'txt', as_text=True) - print('Saved readable graph to:', filename + 'txt') - - # Write the graph in binary .pb file - from tensorflow.python.framework import graph_util - from tensorflow.python.framework import graph_io - - cgraph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [out_name]) - graph_io.write_graph(cgraph, out_dir, filename, as_text=False) - - if isfile(model_path + '.params'): - copyfile(model_path + '.params', out_file + '.params') + # Load custom loss function with model + model = K.models.load_model(model_path, custom_objects={'weighted_log_loss': weighted_log_loss}) + model.summary() - print('Saved graph to:', filename) + # Support for freezing Keras models to .pb has been removed in TF 2.0. - del sess + # Converting instead to TFLite model + print('Starting TFLite conversion.') + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS] + tflite_model = converter.convert() + open(out_file, "wb").write(tflite_model) + print('Wrote to ' + out_file) main = ConvertScript.run_main diff --git a/precise/scripts/engine.py b/precise/scripts/engine.py old mode 100755 new mode 100644 diff --git a/precise/scripts/eval.py b/precise/scripts/eval.py old mode 100755 new mode 100644 diff --git a/precise/scripts/graph.py b/precise/scripts/graph.py old mode 100755 new mode 100644 diff --git a/precise/scripts/listen.py b/precise/scripts/listen.py old mode 100755 new mode 100644 diff --git a/precise/scripts/test.py b/precise/scripts/test.py old mode 100755 new mode 100644 diff --git a/precise/scripts/train.py b/precise/scripts/train.py old mode 100755 new mode 100644 index b89b044b..7b59626c --- a/precise/scripts/train.py +++ b/precise/scripts/train.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from fitipy import Fitipy -from keras.callbacks import LambdaCallback +from tensorflow.keras.callbacks import LambdaCallback from os.path import splitext, isfile from prettyparse import Usage from typing import Any, Tuple @@ -85,7 +85,7 @@ def __init__(self, args): self.model = create_model(args.model, params) self.train, self.test = self.load_data(self.args) - from keras.callbacks import ModelCheckpoint, TensorBoard + from tensorflow.keras.callbacks import ModelCheckpoint checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor, save_best_only=args.save_best) epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch') @@ -104,9 +104,8 @@ def on_epoch_end(_a, _b): self.hash_to_ind = {} self.callbacks = [ - checkpoint, TensorBoard( - log_dir=self.model_base + '.logs', - ), LambdaCallback(on_epoch_end=on_epoch_end) + checkpoint, + LambdaCallback(on_epoch_end=on_epoch_end), ] @staticmethod @@ -160,7 +159,9 @@ def run(self): self.model.fit( train_inputs, train_outputs, self.args.batch_size, self.epoch + self.args.epochs, validation_data=self.test, - initial_epoch=self.epoch, callbacks=self.callbacks + initial_epoch=self.epoch, callbacks=self.callbacks, + use_multiprocessing=True, validation_freq=5, + verbose=1 ) diff --git a/precise/scripts/train_generated.py b/precise/scripts/train_generated.py index 9c275dfe..aa981188 100644 --- a/precise/scripts/train_generated.py +++ b/precise/scripts/train_generated.py @@ -18,7 +18,8 @@ import numpy as np from contextlib import suppress from fitipy import Fitipy -from keras.callbacks import LambdaCallback +from tensorflow.keras.callbacks import LambdaCallback +from os import rename from os.path import splitext, join, basename from prettyparse import Usage from random import random, shuffle @@ -90,8 +91,8 @@ def __init__(self, args): self.model = create_model(args.model, params) self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None) - from keras.callbacks import ModelCheckpoint, TensorBoard - checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor, + from tensorflow.keras.callbacks import ModelCheckpoint + checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor, save_best_only=args.save_best) epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch') self.epoch = epoch_fiti.read().read(0, int) @@ -103,9 +104,8 @@ def on_epoch_end(_a, _b): self.model_base = splitext(self.args.model)[0] self.callbacks = [ - checkpoint, TensorBoard( - log_dir=self.model_base + '.logs', - ), LambdaCallback(on_epoch_end=on_epoch_end) + checkpoint, + LambdaCallback(on_epoch_end=on_epoch_end) ] self.data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder) @@ -225,16 +225,17 @@ def generate_samples(self): def run(self): """Train the model on randomly generated batches""" - _, test_data = self.data.load(train=False, test=True) + _, test_data = self.data.load(train=True, test=True) try: - self.model.fit_generator( + self.model.fit( self.samples_to_batches(self.generate_samples(), self.args.batch_size), steps_per_epoch=self.args.steps_per_epoch, epochs=self.epoch + self.args.epochs, validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch ) finally: - self.model.save(self.args.model) + self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format + rename(self.args.model + '.h5', self.args.model) # Rename with original save_params(self.args.model) diff --git a/precise/scripts/train_incremental.py b/precise/scripts/train_incremental.py old mode 100755 new mode 100644 index d0238cef..67d744ee --- a/precise/scripts/train_incremental.py +++ b/precise/scripts/train_incremental.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np -from os import makedirs +from os import makedirs, rename from os.path import basename, splitext, isfile, join from prettyparse import Usage from random import random @@ -107,7 +107,8 @@ def retrain(self): validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch ) finally: - self.listener.runner.model.save(self.args.model) + self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format + rename(self.args.model + '.h5', self.args.model) # Rename with original def train_on_audio(self, fn: str): """Run through a single audio file""" diff --git a/precise/scripts/train_optimize.py b/precise/scripts/train_optimize.py index fc1bee6b..2ceaa55f 100644 --- a/precise/scripts/train_optimize.py +++ b/precise/scripts/train_optimize.py @@ -52,7 +52,7 @@ def __init__(self, args): data = TrainData.from_both(self.args.tags_file, self.args.tags_folder, self.args.folder) _, self.test = data.load(False, True) - from keras.callbacks import ModelCheckpoint + from tensorflow.keras.callbacks import ModelCheckpoint for i in list(self.callbacks): if isinstance(i, ModelCheckpoint): self.callbacks.remove(i) diff --git a/precise/scripts/train_sampled.py b/precise/scripts/train_sampled.py old mode 100755 new mode 100644 diff --git a/setup.py b/setup.py index a55da052..44b22d98 100644 --- a/setup.py +++ b/setup.py @@ -71,10 +71,9 @@ }, install_requires=[ 'numpy', - 'tensorflow>=1.13,<1.14', # Must be on piwheels + 'tensorflow-gpu==2.2.0', 'sonopy', 'pyaudio', - 'keras<=2.1.5', 'h5py', 'wavio', 'typing',