From aa6108c48e54f47e41efac9152e9a6df8d08dbce Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 24 Oct 2024 12:19:22 +0800 Subject: [PATCH] Handle NaN embeddings in speaker diarization. See also https://github.com/thewh1teagle/sherpa-rs/issues/33 --- cxx-api-examples/sense-voice-cxx-api.cc | 2 +- .../streaming-zipformer-cxx-api.cc | 2 +- cxx-api-examples/whisper-cxx-api.cc | 2 +- scripts/check_style_cpplint.sh | 5 ++- ...ffline-speaker-diarization-pyannote-impl.h | 39 ++++++++++++++++++- .../speaker-embedding-extractor-nemo-impl.h | 2 +- 6 files changed, 45 insertions(+), 7 deletions(-) diff --git a/cxx-api-examples/sense-voice-cxx-api.cc b/cxx-api-examples/sense-voice-cxx-api.cc index 15d752058..ea642b980 100644 --- a/cxx-api-examples/sense-voice-cxx-api.cc +++ b/cxx-api-examples/sense-voice-cxx-api.cc @@ -19,7 +19,7 @@ #include "sherpa-onnx/c-api/cxx-api.h" int32_t main() { - using namespace sherpa_onnx::cxx; + using namespace sherpa_onnx::cxx; // NOLINT OfflineRecognizerConfig config; config.model_config.sense_voice.model = diff --git a/cxx-api-examples/streaming-zipformer-cxx-api.cc b/cxx-api-examples/streaming-zipformer-cxx-api.cc index 5a49dcfc9..ac4abc479 100644 --- a/cxx-api-examples/streaming-zipformer-cxx-api.cc +++ b/cxx-api-examples/streaming-zipformer-cxx-api.cc @@ -20,7 +20,7 @@ #include "sherpa-onnx/c-api/cxx-api.h" int32_t main() { - using namespace sherpa_onnx::cxx; + using namespace sherpa_onnx::cxx; // NOLINT OnlineRecognizerConfig config; // please see diff --git a/cxx-api-examples/whisper-cxx-api.cc b/cxx-api-examples/whisper-cxx-api.cc index 82f0ddb53..348d115bd 100644 --- a/cxx-api-examples/whisper-cxx-api.cc +++ b/cxx-api-examples/whisper-cxx-api.cc @@ -19,7 +19,7 @@ #include "sherpa-onnx/c-api/cxx-api.h" int32_t main() { - using namespace sherpa_onnx::cxx; + using namespace sherpa_onnx::cxx; // NOLINT OfflineRecognizerConfig config; config.model_config.whisper.encoder = diff --git a/scripts/check_style_cpplint.sh b/scripts/check_style_cpplint.sh index eedc9afc1..ea419242a 100755 --- a/scripts/check_style_cpplint.sh +++ b/scripts/check_style_cpplint.sh @@ -71,6 +71,9 @@ function is_source_code_file() { } function check_style() { + if [[ $1 == mfc-example* ]]; then + return + fi python3 $cpplint_src $1 || abort $1 } @@ -99,7 +102,7 @@ function do_check() { ;; 2) echo "Check all files" - files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") + files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") ;; *) echo "Check last commit" diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index aaedc3be0..62a663706 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -5,6 +5,7 @@ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ #include +#include #include #include #include @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl } auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); + + // The embedding model may output NaN. valid_indexes contains indexes + // in chunk_speaker_samples_list_pair.second that don't lead to + // NaN embeddings. + std::vector valid_indexes; + valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size()); + Matrix2D embeddings = ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, - std::move(callback), callback_arg); + &valid_indexes, std::move(callback), callback_arg); + + if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) { + std::vector chunk_speaker_pair; + std::vector> sample_indexes; + + chunk_speaker_pair.reserve(valid_indexes.size()); + sample_indexes.reserve(valid_indexes.size()); + for (auto i : valid_indexes) { + chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]); + sample_indexes.push_back( + std::move(chunk_speaker_samples_list_pair.second[i])); + } + + chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair); + chunk_speaker_samples_list_pair.second = std::move(sample_indexes); + } std::vector cluster_labels = clustering_->Cluster( &embeddings(0, 0), embeddings.rows(), embeddings.cols()); @@ -431,6 +455,7 @@ class OfflineSpeakerDiarizationPyannoteImpl Matrix2D ComputeEmbeddings( const float *audio, int32_t n, const std::vector> &sample_indexes, + std::vector *valid_indexes, OfflineSpeakerDiarizationProgressCallback callback, void *callback_arg) const { const auto &meta_data = segmentation_model_.GetModelMetaData(); @@ -438,6 +463,7 @@ class OfflineSpeakerDiarizationPyannoteImpl Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); int32_t k = 0; + int32_t cur_row_index = 0; for (const auto &v : sample_indexes) { auto stream = embedding_extractor_.CreateStream(); for (const auto &p : v) { @@ -459,7 +485,16 @@ class OfflineSpeakerDiarizationPyannoteImpl std::vector embedding = embedding_extractor_.Compute(stream.get()); - std::copy(embedding.begin(), embedding.end(), &ans(k, 0)); + float sum = + Eigen::Map(embedding.data(), embedding.size()) + .sum(); + + if (!isnan(sum)) { + // a valid embedding + std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0)); + cur_row_index += 1; + valid_indexes->push_back(k); + } k += 1; diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h index 66ad15af3..7e0883085 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { auto variance = EX2 - EX.array().pow(2); auto stddev = variance.array().sqrt(); - m = (m.rowwise() - EX).array().rowwise() / stddev.array(); + m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5); } private: