Updated Automatic Speech Recognition using CTC example for Keras v3 #1768

Open · wants to merge 4 commits into base: master
examples/audio/ctc_asr.py (72 changes: 43 additions & 29 deletions)
@@ -55,11 +55,16 @@
## Setup
"""

+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+import tensorflow as tf
import pandas as pd
import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer
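
Note on the new import block: the `KERAS_BACKEND` environment variable is read once, when `keras` is first imported, so it has to be set before `import keras`. A minimal sketch of the pattern (the alternative backend names are illustrative; this example pins `"tensorflow"` because the data pipeline below still uses `tf.*` ops directly):

```python
import os

# Must be set before `import keras`; once Keras is imported, the backend is
# fixed for the process. Other accepted values include "jax" and "torch".
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

print(keras.backend.backend())  # "tensorflow"
```
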
@@ -118,9 +123,9 @@
# The set of characters accepted in the transcription.
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
# Mapping characters to integers
-char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
+char_to_num = layers.StringLookup(vocabulary=characters, oov_token="")
# Mapping integers back to original characters
-num_to_char = keras.layers.StringLookup(
+num_to_char = layers.StringLookup(
vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
)
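
For context, the two lookup layers are exact inverses, which is what the decoding step at the end of the example relies on. A small usage sketch (the sample transcript is made up) showing a round trip from text to integer IDs and back:

```python
import tensorflow as tf

# Hypothetical round-trip check, not part of the diff
ids = char_to_num(tf.strings.unicode_split("hello world", input_encoding="UTF-8"))
text = tf.strings.reduce_join(num_to_char(ids)).numpy().decode("utf-8")
print(ids.numpy())  # one integer ID per character
print(text)         # "hello world"
```
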

@@ -151,9 +156,9 @@ def encode_single_sample(wav_file, label):
file = tf.io.read_file(wavs_path + wav_file + ".wav")
# 2. Decode the wav file
audio, _ = tf.audio.decode_wav(file)
-audio = tf.squeeze(audio, axis=-1)
+audio = keras.ops.squeeze(audio, axis=-1)
# 3. Change type to float
-audio = tf.cast(audio, tf.float32)
+audio = keras.ops.cast(audio, tf.float32)
# 4. Get the spectrogram
spectrogram = tf.signal.stft(
audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
@@ -162,8 +167,8 @@ def encode_single_sample(wav_file, label):
spectrogram = tf.abs(spectrogram)
spectrogram = tf.math.pow(spectrogram, 0.5)
# 6. normalisation
-means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
-stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
+means = keras.ops.mean(spectrogram, 1, keepdims=True)
+stddevs = keras.ops.std(spectrogram, 1, keepdims=True)
spectrogram = (spectrogram - means) / (stddevs + 1e-10)
###########################################
## Process the label
@@ -240,24 +245,6 @@
"""
## Model

-We first define the CTC Loss function.
-"""
-
-
-def CTCLoss(y_true, y_pred):
-    # Compute the training-time loss value
-    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
-    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
-    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
-
-    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
-    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
-
-    loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
-    return loss
-
-
-"""
We now define our model. We will define a model similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
"""
Expand Down Expand Up @@ -320,7 +307,7 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
# Optimizer
opt = keras.optimizers.Adam(learning_rate=1e-4)
# Compile the model and return
-model.compile(optimizer=opt, loss=CTCLoss)
+model.compile(optimizer=opt, loss=keras.losses.CTC())
return model
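
This one-line swap is what makes the deleted `CTCLoss` wrapper unnecessary: `keras.losses.CTC` is the built-in CTC loss in Keras 3 and handles the batch, input-length, and label-length bookkeeping itself (label lengths come from the zero padding, input lengths from the time dimension). A rough sketch of calling it directly, with shapes matching this example's zero-padded integer labels and softmax outputs (all values random, purely illustrative):

```python
import numpy as np

y_true = np.random.randint(1, 31, size=(8, 50)).astype("int32")  # (batch, max_label_len), 0 = pad/blank
y_pred = keras.ops.softmax(np.random.randn(8, 200, 32).astype("float32"), axis=-1)  # (batch, time, classes)
print(float(keras.losses.CTC()(y_true, y_pred)))
```
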


@@ -337,11 +324,38 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
"""


+# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L715-L739
+
+
+def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
+    input_shape = tf.shape(y_pred)
+    num_samples, num_steps = input_shape[0], input_shape[1]
+    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
+    input_length = tf.cast(input_length, tf.int32)
+
+    if greedy:
+        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
+            inputs=y_pred, sequence_length=input_length
+        )
+    else:
+        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
+            inputs=y_pred,
+            sequence_length=input_length,
+            beam_width=beam_width,
+            top_paths=top_paths,
+        )
+    decoded_dense = []
+    for st in decoded:
+        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
+        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
+    return (decoded_dense, log_prob)

Review comment on the `tf.nn.ctc_greedy_decoder` call (Contributor):
> So, we're going to have to use TF for this and ctc_beam_search_decoder I guess, unless we implement them as new backend ops.

Reply (Contributor Author):
> Again, thanks for the feedback 👍
> I created an issue to address this.
> Please let me know if I should change the description or add/remove details.
> Thanks!


# A utility function to decode the output of the network
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
-results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
+results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
# Iterate over the results and get back the text
output_text = []
for result in results: