Skip to content

Commit

Permalink
Support decoding with byte-level BPE (bbpe) models. (#1633)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 20, 2024
1 parent 7192e57 commit b76cd90
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 10 deletions.
1 change: 1 addition & 0 deletions scripts/bbpe/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bbpe.cc
67 changes: 67 additions & 0 deletions scripts/bbpe/generate_bbpe_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28
# and
# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py
#
# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall

import re

BPE_UNK = chr(8263)
PRINTABLE_BASE_CHARS = (
list(range(256, 287 + 1))
+ list(range(32, 126 + 1))
+ list(range(288, 305 + 1))
+ list(range(308, 318 + 1))
+ list(range(321, 328 + 1))
+ list(range(330, 382 + 1))
+ list(range(384, 422 + 1))
)


BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
BCHAR_TO_BYTE[BPE_UNK] = 32 # map unk to space


def main():
s = ""
s += "// sherpa-onnx/csrc/bbpe.cc\n"
s += "//\n"
s += "// Copyright (c) 2024 Xiaomi Corporation\n"
s += "\n"
s += "// Auto-generated! DO NOT EDIT\n"
s += "\n"
s += '#include "sherpa-onnx/csrc/bbpe.h"\n'
s += "\n"
s += "#include <cstdint>\n"
s += "#include <string>\n"
s += "#include <unordered_map>\n"
s += "\n"
s += "const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {\n"
s += " static const std::unordered_map<std::string, uint8_t> table = {\n"

s += " "
for i, (k, v) in enumerate(BCHAR_TO_BYTE.items()):
s += "{"
if k in ["\\", '"']:
s += f'"\{k}", {v}'
else:
s += f'"{k}", {v}'
s += "}, "
if i > 0 and i % 7 == 0:
s += "\n"
s += " "
s += "};\n"
s += "\n"
s += " return table\n;"
s += "}\n"

with open("bbpe.cc", "w", encoding="utf-8") as f:
f.write(s)


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ endif()

set(sources
base64-decode.cc
bbpe.cc
cat.cc
circular-buffer.cc
context-graph.cc
Expand Down Expand Up @@ -78,11 +79,11 @@ set(sources
online-stream.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
online-transducer-greedy-search-nemo-decoder.cc
online-transducer-model-config.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-nemo-model.cc
online-transducer-greedy-search-nemo-decoder.cc
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
Expand Down
61 changes: 61 additions & 0 deletions sherpa-onnx/csrc/bbpe.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// sherpa-onnx/csrc/bbpe.cc
//
// Copyright (c) 2024 Xiaomi Corporation

// Auto-generated! DO NOT EDIT

#include "sherpa-onnx/csrc/bbpe.h"

#include <cstdint>
#include <string>
#include <unordered_map>

const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {
static const std::unordered_map<std::string, uint8_t> table = {
{"Ā", 0}, {"ā", 1}, {"Ă", 2}, {"ă", 3}, {"Ą", 4}, {"ą", 5},
{"Ć", 6}, {"ć", 7}, {"Ĉ", 8}, {"ĉ", 9}, {"Ċ", 10}, {"ċ", 11},
{"Č", 12}, {"č", 13}, {"Ď", 14}, {"ď", 15}, {"Đ", 16}, {"đ", 17},
{"Ē", 18}, {"ē", 19}, {"Ĕ", 20}, {"ĕ", 21}, {"Ė", 22}, {"ė", 23},
{"Ę", 24}, {"ę", 25}, {"Ě", 26}, {"ě", 27}, {"Ĝ", 28}, {"ĝ", 29},
{"Ğ", 30}, {"ğ", 31}, {" ", 32}, {"!", 33}, {"\"", 34}, {"#", 35},
{"$", 36}, {"%", 37}, {"&", 38}, {"'", 39}, {"(", 40}, {")", 41},
{"*", 42}, {"+", 43}, {",", 44}, {"-", 45}, {".", 46}, {"/", 47},
{"0", 48}, {"1", 49}, {"2", 50}, {"3", 51}, {"4", 52}, {"5", 53},
{"6", 54}, {"7", 55}, {"8", 56}, {"9", 57}, {":", 58}, {";", 59},
{"<", 60}, {"=", 61}, {">", 62}, {"?", 63}, {"@", 64}, {"A", 65},
{"B", 66}, {"C", 67}, {"D", 68}, {"E", 69}, {"F", 70}, {"G", 71},
{"H", 72}, {"I", 73}, {"J", 74}, {"K", 75}, {"L", 76}, {"M", 77},
{"N", 78}, {"O", 79}, {"P", 80}, {"Q", 81}, {"R", 82}, {"S", 83},
{"T", 84}, {"U", 85}, {"V", 86}, {"W", 87}, {"X", 88}, {"Y", 89},
{"Z", 90}, {"[", 91}, {"\\", 92}, {"]", 93}, {"^", 94}, {"_", 95},
{"`", 96}, {"a", 97}, {"b", 98}, {"c", 99}, {"d", 100}, {"e", 101},
{"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107},
{"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113},
{"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119},
{"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125},
{"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131},
{"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137},
{"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143},
{"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149},
{"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155},
{"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161},
{"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167},
{"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173},
{"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179},
{"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185},
{"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191},
{"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197},
{"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203},
{"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209},
{"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215},
{"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221},
{"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227},
{"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233},
{"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239},
{"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245},
{"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251},
{"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255}, {"", 32},
};

return table;
}
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/bbpe.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/csrc/bbpe.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_BBPE_H_
#define SHERPA_ONNX_CSRC_BBPE_H_
#include <cstdint>
#include <string>
#include <unordered_map>

// It is equivalent to the map BCHAR_TO_BYTE
// from
// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280
const std::unordered_map<std::string, uint8_t> &GetByteBpeTable();

#endif // SHERPA_ONNX_CSRC_BBPE_H_
7 changes: 6 additions & 1 deletion sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
text.append(sym);

if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models
// for bpe models with byte_fallback
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
Expand All @@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,

r.tokens.push_back(std::move(sym));
}

if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}

r.text = std::move(text);

float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
Expand Down
6 changes: 5 additions & 1 deletion sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert(
text.append(sym);

if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models,
// for bpe models with byte_fallback,
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
Expand All @@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert(

r.tokens.push_back(std::move(sym));
}
if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}

r.text = std::move(text);

float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
Expand Down
11 changes: 9 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());

std::string text;
for (auto i : src.tokens) {
auto sym = sym_table[i];

r.text.append(sym);
text.append(sym);

if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models
// for bpe models with byte_fallback
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
Expand All @@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
r.tokens.push_back(std::move(sym));
}

if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}

r.text = std::move(text);

float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
Expand Down
11 changes: 9 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());

std::string text;
for (auto i : src.tokens) {
auto sym = sym_table[i];

r.text.append(sym);
text.append(sym);

if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
// for byte bpe models
// for bpe models with byte_fallback
// (but don't rewrite printable characters 0x20..0x7e,
// which collide with standard BPE units)
std::ostringstream os;
Expand All @@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.tokens.push_back(std::move(sym));
}

if (sym_table.IsByteBpe()) {
text = sym_table.DecodeByteBpe(text);
}

r.text = std::move(text);

float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
Expand Down
Loading

0 comments on commit b76cd90

Please sign in to comment.