Skip to content

Commit

Permalink
Support non-streaming WeNet CTC models. (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 15, 2023
1 parent d34640e commit b83b3e3
Show file tree
Hide file tree
Showing 21 changed files with 469 additions and 32 deletions.
41 changes: 41 additions & 0 deletions .github/scripts/test-offline-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,47 @@ echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run Wenet models"
log "------------------------------------------------------------"
# Pretrained WeNet CTC model repos hosted on Hugging Face under
# csukuangfj/; each entry is a repo name appended to the base URL below.
wenet_models=(
sherpa-onnx-zh-wenet-aishell
sherpa-onnx-zh-wenet-aishell2
sherpa-onnx-zh-wenet-wenetspeech
sherpa-onnx-zh-wenet-multi-cn
sherpa-onnx-en-wenet-librispeech
sherpa-onnx-en-wenet-gigaspeech
)
for name in ${wenet_models[@]}; do
repo_url=https://huggingface.co/csukuangfj/$name
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
# Clone metadata only, then fetch just the *.onnx weights via git-lfs
# to avoid downloading the full repo contents.
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd

# $EXE is exported by the calling CI workflow (sherpa-onnx-offline).
log "test float32 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

log "test int8 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model.int8.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

# Free CI disk space before testing the next model.
rm -rf $repo
done

log "------------------------------------------------------------"
log "Run tdnn yesno (Hebrew)"
log "------------------------------------------------------------"
Expand Down
11 changes: 0 additions & 11 deletions .github/workflows/export-wenet-to-onnx.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
name: export-wenet-to-onnx

on:
push:
branches:
- master
paths:
- 'scripts/wenet/**'
- '.github/workflows/export-wenet-to-onnx.yaml'
pull_request:
paths:
- 'scripts/wenet/**'
- '.github/workflows/export-wenet-to-onnx.yaml'

workflow_dispatch:

concurrency:
Expand Down
16 changes: 8 additions & 8 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline TTS
shell: bash
run: |
Expand All @@ -115,14 +123,6 @@ jobs:
.github/scripts/test-offline-whisper.sh
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer
shell: bash
run: |
Expand Down
3 changes: 2 additions & 1 deletion scripts/wenet/export-onnx-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def main():
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
url = os.environ.get("WENET_URL", "")
meta_data = {
"model_type": "wenet-ctc",
"model_type": "wenet_ctc",
"version": "1",
"model_author": "wenet",
"comment": "streaming",
Expand All @@ -185,6 +185,7 @@ def main():
"cnn_module_kernel": cnn_module_kernel,
"right_context": right_context,
"subsampling_factor": subsampling_factor,
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
}
add_meta_data(filename=filename, meta_data=meta_data)

Expand Down
4 changes: 3 additions & 1 deletion scripts/wenet/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ def main():
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
url = os.environ.get("WENET_URL", "")
meta_data = {
"model_type": "wenet-ctc",
"model_type": "wenet_ctc",
"version": "1",
"model_author": "wenet",
"comment": "non-streaming",
"subsampling_factor": torch_model.encoder.embed.subsampling_rate,
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
"url": url,
}
add_meta_data(filename=filename, meta_data=meta_data)
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-wenet-ctc-model-config.cc
offline-wenet-ctc-model.cc
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
Expand Down
19 changes: 18 additions & 1 deletion sherpa-onnx/csrc/offline-ctc-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

Expand All @@ -21,10 +22,11 @@ enum class ModelType {
kEncDecCTCModelBPE,
kTdnn,
kZipformerCtc,
kWenetCtc,
kUnkown,
};

}
} // namespace

namespace sherpa_onnx {

Expand Down Expand Up @@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"If you are using models from NeMo, please refer to\n"
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
Expand All @@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
Expand All @@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
Expand All @@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
Expand All @@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
Expand All @@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class OfflineCtcModel {
* for the features.
*/
virtual std::string FeatureNormalizationMethod() const { return {}; }

// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
};

} // namespace sherpa_onnx
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
whisper.Register(po);
tdnn.Register(po);
zipformer_ctc.Register(po);
wenet_ctc.Register(po);

po->Register("tokens", &tokens, "Path to tokens.txt");

Expand Down Expand Up @@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const {
return zipformer_ctc.Validate();
}

if (!wenet_ctc.model.empty()) {
return wenet_ctc.Validate();
}

return transducer.Validate();
}

Expand All @@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const {
os << "whisper=" << whisper.ToString() << ", ";
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"

Expand All @@ -22,6 +23,7 @@ struct OfflineModelConfig {
OfflineWhisperModelConfig whisper;
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;

std::string tokens;
int32_t num_threads = 2;
Expand All @@ -46,6 +48,7 @@ struct OfflineModelConfig {
const OfflineWhisperModelConfig &whisper,
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
Expand All @@ -54,6 +57,7 @@ struct OfflineModelConfig {
whisper(whisper),
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
Expand Down
55 changes: 53 additions & 2 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif

void Init() {
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}

config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();

Expand All @@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
!symbol_table_.contains("<eps>") &&
!symbol_table_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> and its ID.");
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}

Expand All @@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
} else if (symbol_table_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = symbol_table_["<eps>"];
} else if (symbol_table_.contains("<blank>")) {
// for Wenet CTC models
blank_id = symbol_table_["<blank>"];
}

decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
Expand All @@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}

void DecodeStreams(OfflineStream **ss, int32_t n) const override {
if (!model_->SupportBatchProcessing()) {
// If the model does not support batch process,
// we process each stream independently.
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
return;
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

Expand Down Expand Up @@ -164,6 +183,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}

private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
void DecodeStream(OfflineStream *s) const {
  // Feature matrix of this stream, row-major with shape
  // (num_frames, feat_dim).
  std::vector<float> features = s->GetFrames();
  const int32_t feat_dim = config_.feat_config.feature_dim;
  const int32_t num_frames = features.size() / feat_dim;

  auto cpu_mem =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // Batch size is fixed to 1 since this path serves models that cannot
  // handle batches.
  std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
  Ort::Value x =
      Ort::Value::CreateTensor(cpu_mem, features.data(), features.size(),
                               x_shape.data(), x_shape.size());

  // Length tensor holding the number of valid frames for the single input.
  int64_t num_frames64 = num_frames;
  std::array<int64_t, 1> len_shape{1};
  Ort::Value x_length = Ort::Value::CreateTensor(
      cpu_mem, &num_frames64, 1, len_shape.data(), len_shape.size());

  // t[0]: log-probs, t[1]: output lengths (see OfflineCtcModel::Forward).
  auto t = model_->Forward(std::move(x), std::move(x_length));
  auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));

  int32_t frame_shift_ms = 10;  // feature frame shift in milliseconds
  s->SetResult(Convert(results[0], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor()));
}

private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
Expand Down
Loading

0 comments on commit b83b3e3

Please sign in to comment.