Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrades to TF 2.2. Compat with TFLite. #141

Open
wants to merge 21 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e748075
Removes pocketsphinx
andreselizondo-adestech Mar 19, 2020
3df8f7a
Convert all to tf2 using tf_upgrade_v2.
andreselizondo-adestech Mar 25, 2020
317a091
Upgrade TF to 2.1, Keras to 2.3.1 and remove pocketsphinx.
andreselizondo-adestech Mar 25, 2020
920641a
Adjusts training parameters for 20 units in GRU. Adds early stop, red…
andreselizondo-adestech Mar 25, 2020
285499e
Adds TFLiteRunner with enough compatility for inference.
andreselizondo-adestech Mar 30, 2020
eeaebf4
Bugfix: Increments counter for each prediction.
andreselizondo-adestech Apr 2, 2020
8f32394
Adjusts KerasRunner for compat with tf2.
andreselizondo-adestech Apr 2, 2020
15a5bf4
Sets TF version to 2.2.0rc2. Update when 2.2.0 is released.
andreselizondo-adestech Apr 2, 2020
51fc562
Adapts convert script for tf2 and tflite.
andreselizondo-adestech Apr 2, 2020
7d1ad8d
Changes precise-convert default output file extension to tflite.
andreselizondo-adestech Apr 20, 2020
1c3fe87
General cleanup for merge.
andreselizondo-adestech Apr 20, 2020
26e6d25
Restore pocketsphinx files.
andreselizondo-adestech Apr 21, 2020
149a649
Adds quick fix to allow save as h5.
andreselizondo-adestech Apr 21, 2020
2b1bb37
Bugfix: Imports rename from os.
andreselizondo-adestech Apr 21, 2020
e2a3c2d
Fixes saving model to .h5 in train_generated.py
andreselizondo-adestech Apr 29, 2020
3baf5da
Allows for generator training.
andreselizondo-adestech Jul 8, 2020
9e7d884
Swap out keras for tf.keras
MatthewScholefield Jul 17, 2020
06ed742
Merge pull request #1 from MatthewScholefield/tf2
andreselizondo-adestech Aug 19, 2020
04502b2
Merge branch 'dev' into incremental_training
andreselizondo-adestech Aug 19, 2020
1979f14
Merge pull request #2 from andreselizondo-adestech/incremental_training
andreselizondo-adestech Aug 19, 2020
0e0ac5b
Restores pocketsphinx.
andreselizondo-adestech Aug 19, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions precise/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def weighted_log_loss(yt, yp) -> Any:
yt: Target
yp: Prediction
"""
from keras import backend as K
import tensorflow.keras.backend as K

pos_loss = -(0 + yt) * K.log(0 + yp + K.epsilon())
neg_loss = -(1 - yt) * K.log(1 - yp + K.epsilon())
Expand All @@ -42,7 +42,7 @@ def weighted_log_loss(yt, yp) -> Any:


def weighted_mse_loss(yt, yp) -> Any:
from keras import backend as K
import tensorflow.keras.backend as K

total = K.sum(K.ones_like(yt))
neg_loss = total * K.sum(K.square(yp * (1 - yt))) / K.sum(1 - yt)
Expand All @@ -52,12 +52,12 @@ def weighted_mse_loss(yt, yp) -> Any:


def false_pos(yt, yp) -> Any:
from keras import backend as K
import tensorflow.keras.backend as K
return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(1 - yt))


def false_neg(yt, yp) -> Any:
from keras import backend as K
import tensorflow.keras.backend as K
return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.maximum(1.0, K.sum(0 + yt))


Expand Down
11 changes: 5 additions & 6 deletions precise/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from precise.params import inject_params, pr

if TYPE_CHECKING:
from keras.models import Sequential
from tensorflow.keras.models import Sequential


@attr.s()
Expand All @@ -45,7 +45,8 @@ def load_precise_model(model_name: str) -> Any:
print('Warning: Unknown model type, ', model_name)

inject_params(model_name)
return load_keras().models.load_model(model_name)
from tensorflow.keras.models import load_model
return load_model(model_name, custom_objects=globals())


def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential':
Expand All @@ -63,9 +64,8 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
print('Loading from ' + model_name + '...')
model = load_precise_model(model_name)
else:
from keras.layers.core import Dense
from keras.layers.recurrent import GRU
from keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU
from tensorflow.keras.models import Sequential

model = Sequential()
model.add(GRU(
Expand All @@ -74,7 +74,6 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
))
model.add(Dense(1, activation='sigmoid'))

load_keras()
metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg]
set_loss_bias(params.loss_bias)
for i in model.layers[:params.freeze_till]:
Expand Down
55 changes: 39 additions & 16 deletions precise/network_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def __init__(self, model_name: str):
print('Warning: ', model_name, 'looks like a Keras model.')
self.tf = import_module('tensorflow')
self.graph = self.load_graph(model_name)
with self.graph.as_default():
self.inp_var = self.graph.get_operation_by_name('import/net_input').outputs[0]
self.out_var = self.graph.get_operation_by_name('import/net_output').outputs[0]

self.inp_var = self.graph.get_operation_by_name('import/net_input').outputs[0]
self.out_var = self.graph.get_operation_by_name('import/net_output').outputs[0]

self.sess = self.tf.Session(graph=self.graph)
self.sess = self.tf.compat.v1.Session(graph=self.graph)

def load_graph(self, model_file: str) -> 'tf.Graph':
graph = self.tf.Graph()
graph_def = self.tf.GraphDef()
graph_def = self.tf.compat.v1.GraphDef()

with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
Expand All @@ -68,24 +68,46 @@ def run(self, inp: np.ndarray) -> float:

class KerasRunner(Runner):
def __init__(self, model_name: str):
import tensorflow as tf
# ISSUE 88 - Following 3 lines added to resolve issue 88 - JM 2020-02-04 per liny90626
from tensorflow.python.keras.backend import set_session # ISSUE 88
self.sess = tf.Session() # ISSUE 88
set_session(self.sess) # ISSUE 88
self.model = load_precise_model(model_name)
self.graph = tf.get_default_graph()

def predict(self, inputs: np.ndarray):
from tensorflow.python.keras.backend import set_session # ISSUE 88
with self.graph.as_default():
set_session(self.sess) # ISSUE 88
return self.model.predict(inputs)
return self.model.predict(inputs)

def run(self, inp: np.ndarray) -> float:
return self.predict(inp[np.newaxis])[0][0]


class TFLiteRunner(Runner):
def __init__(self, model_name: str):
import tensorflow as tf
# Setup tflite environment
self.interpreter = tf.lite.Interpreter(model_path=model_name)
self.interpreter.allocate_tensors()

self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()

def predict(self, inputs: np.ndarray):
# Format output to match Keras's model.predict output
count = 0
output_data = np.ndarray((inputs.shape[0],1), dtype=np.float32)

# Support for multiple inputs
for input in inputs:
# Format as float32. Add a wrapper dimension.
current = np.array([input]).astype(np.float32)

# Load data, run inference and extract output from tensor
self.interpreter.set_tensor(self.input_details[0]['index'], current)
self.interpreter.invoke()
output_data[count] = self.interpreter.get_tensor(self.output_details[0]['index'])
count += 1

return output_data

def run(self, inp: np.ndarray) -> float:
return self.predict(inp[np.newaxis])[0][0]

class Listener:
"""Listener that preprocesses audio into MFCC vectors and executes neural networks"""

Expand All @@ -102,7 +124,8 @@ def __init__(self, model_name: str, chunk_size: int = -1, runner_cls: type = Non
def find_runner(model_name: str) -> Type[Runner]:
runners = {
'.net': KerasRunner,
'.pb': TensorFlowRunner
'.pb': TensorFlowRunner,
'.tflite': TFLiteRunner
}
ext = splitext(model_name)[-1]
if ext not in runners:
Expand Down
Empty file modified precise/scripts/add_noise.py
100755 → 100644
Empty file.
50 changes: 18 additions & 32 deletions precise/scripts/convert.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@

from precise.scripts.base_script import BaseScript


class ConvertScript(BaseScript):
usage = Usage('''
Convert wake word model from Keras to TensorFlow

:model str
Input Keras model (.net)

:-o --out str {model}.pb
Custom output TensorFlow protobuf filename
:-o --out str {model}.tflite
Custom output TensorFlow Lite filename
''')

def run(self):
Expand All @@ -39,49 +38,36 @@ def run(self):

def convert(self, model_path: str, out_file: str):
"""
Converts an HD5F file from Keras to a .pb for use with TensorFlow
Converts an HD5F file from Keras to a .tflite for use with TensorFlow Runtime

Args:
model_path: location of Keras model
out_file: location to write protobuf
out_file: location to write TFLite model
"""
print('Converting', model_path, 'to', out_file, '...')

import tensorflow as tf
import tensorflow as tf # Using tensorflow v2.2
from tensorflow import keras as K
from precise.model import load_precise_model
from keras import backend as K
from precise.functions import weighted_log_loss

out_dir, filename = split(out_file)
out_dir = out_dir or '.'
os.makedirs(out_dir, exist_ok=True)

K.set_learning_phase(0)
model = load_precise_model(model_path)

out_name = 'net_output'
tf.identity(model.output, name=out_name)
print('Output node name:', out_name)
print('Output folder:', out_dir)

sess = K.get_session()

# Write the graph in human readable
tf.train.write_graph(sess.graph.as_graph_def(), out_dir, filename + 'txt', as_text=True)
print('Saved readable graph to:', filename + 'txt')

# Write the graph in binary .pb file
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io

cgraph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [out_name])
graph_io.write_graph(cgraph, out_dir, filename, as_text=False)

if isfile(model_path + '.params'):
copyfile(model_path + '.params', out_file + '.params')
# Load custom loss function with model
model = K.models.load_model(model_path, custom_objects={'weighted_log_loss': weighted_log_loss})
model.summary()

print('Saved graph to:', filename)
# Support for freezing Keras models to .pb has been removed in TF 2.0.

del sess
# Converting instead to TFLite model
print('Starting TFLite conversion.')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open(out_file, "wb").write(tflite_model)
print('Wrote to ' + out_file)


main = ConvertScript.run_main
Expand Down
Empty file modified precise/scripts/engine.py
100755 → 100644
Empty file.
Empty file modified precise/scripts/eval.py
100755 → 100644
Empty file.
Empty file modified precise/scripts/graph.py
100755 → 100644
Empty file.
Empty file modified precise/scripts/listen.py
100755 → 100644
Empty file.
Empty file modified precise/scripts/test.py
100755 → 100644
Empty file.
13 changes: 7 additions & 6 deletions precise/scripts/train.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from tensorflow.keras.callbacks import LambdaCallback
andreselizondo-adestech marked this conversation as resolved.
Show resolved Hide resolved
from os.path import splitext, isfile
from prettyparse import Usage
from typing import Any, Tuple
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, args):
self.model = create_model(args.model, params)
self.train, self.test = self.load_data(self.args)

from keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
Expand All @@ -104,9 +104,8 @@ def on_epoch_end(_a, _b):
self.hash_to_ind = {}

self.callbacks = [
checkpoint, TensorBoard(
log_dir=self.model_base + '.logs',
), LambdaCallback(on_epoch_end=on_epoch_end)
checkpoint,
LambdaCallback(on_epoch_end=on_epoch_end),
]

@staticmethod
Expand Down Expand Up @@ -160,7 +159,9 @@ def run(self):
self.model.fit(
train_inputs, train_outputs, self.args.batch_size,
self.epoch + self.args.epochs, validation_data=self.test,
initial_epoch=self.epoch, callbacks=self.callbacks
initial_epoch=self.epoch, callbacks=self.callbacks,
use_multiprocessing=True, validation_freq=5,
verbose=1
)


Expand Down
19 changes: 10 additions & 9 deletions precise/scripts/train_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import numpy as np
from contextlib import suppress
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from tensorflow.keras.callbacks import LambdaCallback
from os import rename
from os.path import splitext, join, basename
from prettyparse import Usage
from random import random, shuffle
Expand Down Expand Up @@ -90,8 +91,8 @@ def __init__(self, args):
self.model = create_model(args.model, params)
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)

from keras.callbacks import ModelCheckpoint, TensorBoard
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor,
save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
self.epoch = epoch_fiti.read().read(0, int)
Expand All @@ -103,9 +104,8 @@ def on_epoch_end(_a, _b):
self.model_base = splitext(self.args.model)[0]

self.callbacks = [
checkpoint, TensorBoard(
log_dir=self.model_base + '.logs',
), LambdaCallback(on_epoch_end=on_epoch_end)
checkpoint,
LambdaCallback(on_epoch_end=on_epoch_end)
]

self.data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
Expand Down Expand Up @@ -225,16 +225,17 @@ def generate_samples(self):

def run(self):
"""Train the model on randomly generated batches"""
_, test_data = self.data.load(train=False, test=True)
_, test_data = self.data.load(train=True, test=True)
try:
self.model.fit_generator(
self.model.fit(
self.samples_to_batches(self.generate_samples(), self.args.batch_size),
steps_per_epoch=self.args.steps_per_epoch,
epochs=self.epoch + self.args.epochs, validation_data=test_data,
callbacks=self.callbacks, initial_epoch=self.epoch
)
finally:
self.model.save(self.args.model)
self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
rename(self.args.model + '.h5', self.args.model) # Rename with original
save_params(self.args.model)


Expand Down
5 changes: 3 additions & 2 deletions precise/scripts/train_incremental.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from os import makedirs
from os import makedirs, rename
from os.path import basename, splitext, isfile, join
from prettyparse import Usage
from random import random
Expand Down Expand Up @@ -107,7 +107,8 @@ def retrain(self):
validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch
)
finally:
self.listener.runner.model.save(self.args.model)
self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
rename(self.args.model + '.h5', self.args.model) # Rename with original

def train_on_audio(self, fn: str):
"""Run through a single audio file"""
Expand Down
2 changes: 1 addition & 1 deletion precise/scripts/train_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, args):
data = TrainData.from_both(self.args.tags_file, self.args.tags_folder, self.args.folder)
_, self.test = data.load(False, True)

from keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint
for i in list(self.callbacks):
if isinstance(i, ModelCheckpoint):
self.callbacks.remove(i)
Expand Down
Empty file modified precise/scripts/train_sampled.py
100755 → 100644
Empty file.
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@
},
install_requires=[
'numpy',
'tensorflow>=1.13,<1.14', # Must be on piwheels
'tensorflow-gpu==2.2.0',
'sonopy',
'pyaudio',
'keras<=2.1.5',
'h5py',
'wavio',
'typing',
Expand Down