From 5f655be47a79f1745955a1de3f8c93f92d3eeef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=9B=9B=E5=BC=BA?= Date: Sat, 13 Jan 2024 21:01:31 +0800 Subject: [PATCH] [http] Update http server --- runtime/core/bin/http_server_main.cc | 43 +++++++++++++++++--- runtime/core/bin/tts_main.cc | 3 +- runtime/core/http/http_server.cc | 61 ++++++++++++++++++---------- runtime/core/http/http_server.h | 2 +- runtime/core/model/tts_model.cc | 3 +- runtime/core/model/tts_model.h | 5 ++- runtime/core/utils/timer.h | 39 ++++++++++++++++++ 7 files changed, 125 insertions(+), 31 deletions(-) create mode 100644 runtime/core/utils/timer.h diff --git a/runtime/core/bin/http_server_main.cc b/runtime/core/bin/http_server_main.cc index d843e37..7609ffd 100644 --- a/runtime/core/bin/http_server_main.cc +++ b/runtime/core/bin/http_server_main.cc @@ -16,36 +16,69 @@ #include "glog/logging.h" #include "http/http_server.h" -#include "utils/log.h" +#include "processor/wetext_processor.h" +#include "frontend/g2p_en.h" +#include "frontend/g2p_prosody.h" +#include "frontend/wav.h" +#include "model/tts_model.h" +#include "utils/string.h" + +// Flags +DEFINE_string(frontend_flags, "", "frontend flags file"); +DEFINE_string(vits_flags, "", "vits flags file"); + +// Text Normalization DEFINE_string(tagger, "", "tagger fst file"); DEFINE_string(verbalizer, "", "verbalizer fst file"); +// Tokenizer DEFINE_string(vocab, "", "tokenizer vocab file"); +// G2P for English +DEFINE_string(cmudict, "", "cmudict for english words"); +DEFINE_string(g2p_en_model, "", "english g2p fst model for oov"); +DEFINE_string(g2p_en_sym, "", "english g2p symbol table for oov"); + +// G2P for Chinese DEFINE_string(char2pinyin, "", "chinese character to pinyin"); DEFINE_string(pinyin2id, "", "pinyin to id"); DEFINE_string(pinyin2phones, "", "pinyin to phones"); DEFINE_string(g2p_prosody_model, "", "g2p prosody model file"); +// VITS DEFINE_string(speaker2id, "", "speaker to id"); DEFINE_string(phone2id, "", "phone to id"); DEFINE_string(vits_model, "", "e2e tts model file"); +DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm"); + +// port DEFINE_int32(port, 10086, "http listening port"); int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); + gflags::ReadFromFlagsFile(FLAGS_frontend_flags, "", false); + gflags::ReadFromFlagsFile(FLAGS_vits_flags, "", false); auto tn = std::make_shared(FLAGS_tagger, FLAGS_verbalizer); + + bool has_en = !FLAGS_g2p_en_model.empty() && !FLAGS_g2p_en_sym.empty() && + !FLAGS_g2p_en_sym.empty(); + std::shared_ptr g2p_en = + has_en ? std::make_shared(FLAGS_cmudict, FLAGS_g2p_en_model, + FLAGS_g2p_en_sym) + : nullptr; + auto g2p_prosody = std::make_shared( FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id, - FLAGS_pinyin2phones); - auto tts_model = std::make_shared( - FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, tn, g2p_prosody); + FLAGS_pinyin2phones, g2p_en); + auto model = std::make_shared( + FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate, + tn, g2p_prosody); - wetts::HttpServer server(FLAGS_port, tts_model); + wetts::HttpServer server(FLAGS_port, model); LOG(INFO) << "Listening at port " << FLAGS_port; server.Start(); return 0; diff --git a/runtime/core/bin/tts_main.cc b/runtime/core/bin/tts_main.cc index fd93eb9..5063e28 100644 --- a/runtime/core/bin/tts_main.cc +++ b/runtime/core/bin/tts_main.cc @@ -73,7 +73,8 @@ int main(int argc, char* argv[]) { FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id, FLAGS_pinyin2phones, g2p_en); auto model = std::make_shared( - FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, tn, g2p_prosody); + FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate, + tn, g2p_prosody); std::vector audio; int sid = model->GetSid(FLAGS_sname); diff --git a/runtime/core/http/http_server.cc b/runtime/core/http/http_server.cc index ed22ffe..696745b 100644 --- a/runtime/core/http/http_server.cc +++ b/runtime/core/http/http_server.cc @@ -25,29 +25,25 @@ #include "frontend/wav.h" #include "utils/string.h" +#include "utils/timer.h" namespace wetts { namespace urls = boost::urls; namespace uuids = boost::uuids; -http::message_generator ConnectionHandler::handle_request( - const std::string& wav_path) { - // Attempt to open the file +http::message_generator ConnectionHandler::HandleRequest( + char* wav_data, int data_size) { beast::error_code ec; - http::file_body::value_type body; - body.open(wav_path.c_str(), beast::file_mode::scan, ec); - - // Cache the size since we need it after the move - auto const size = body.size(); - // Respond to GET request - http::response res{ - std::piecewise_construct, std::make_tuple(std::move(body)), - std::make_tuple(http::status::ok, request_.version())}; + http::response res; + res.result(http::status::ok); res.set(http::field::server, BOOST_BEAST_VERSION_STRING); res.set(http::field::content_type, "audio/wav"); - res.content_length(size); res.keep_alive(request_.keep_alive()); + res.body().data = wav_data; + res.body().size = data_size; + res.body().more = false; + res.prepare_payload(); return res; } @@ -80,16 +76,37 @@ void ConnectionHandler::operator()() { std::string name = (*params.find("name")).value; int sid = tts_model_->GetSid(name); // 2. Synthesis audio from text - std::vector audio; - tts_model_->Synthesis(text, sid, &audio); - wetts::WavWriter wav_writer(audio.data(), audio.size(), 1, 22050, 16); - // 3. Write samples to file named uuid.wav - std::string wav_path = - uuids::to_string(uuids::random_generator()()) + ".wav"; - wav_writer.Write(wav_path); - + int sample_rate = tts_model_->sampling_rate(); + int num_channels = 1; + int bits_per_sample = 16; + LOG(INFO) << "Sample rate: " << sample_rate; + LOG(INFO) << "Num of channels: " << num_channels; + LOG(INFO) << "Bit per sample: " << bits_per_sample; + int extract_time = 0; + wetts::Timer timer; + std::vector pcm; + tts_model_->Synthesis(text, sid, &pcm); + int pcm_size = pcm.size(); + extract_time = timer.Elapsed(); + LOG(INFO) << "TTS pcm duration: " + << pcm_size * 1000 / num_channels / sample_rate << "ms"; + LOG(INFO) << "Cost time: " << static_cast(extract_time) << "ms"; + // 3. Convert pcm to wav + std::vector audio(pcm_size); + for (int i = 0; i < pcm_size; ++i) { + audio[i] = static_cast(pcm[i]); + } + int audio_size = pcm_size * sizeof(int16_t); + int data_size = audio_size + 44; + WavHeader header(pcm_size, num_channels, sample_rate, bits_per_sample); + std::vector wav_data; + wav_data.insert(wav_data.end(), reinterpret_cast(&header), + reinterpret_cast(&header) + 44); + wav_data.insert(wav_data.end(), reinterpret_cast(audio.data()), + reinterpret_cast(audio.data()) + audio_size); // Handle request - http::message_generator msg = handle_request(wav_path); + http::message_generator msg = + HandleRequest(wav_data.data(), data_size); // Determine if we should close the connection bool keep_alive = msg.keep_alive(); // Send the response diff --git a/runtime/core/http/http_server.h b/runtime/core/http/http_server.h index a57eb5d..30105eb 100644 --- a/runtime/core/http/http_server.h +++ b/runtime/core/http/http_server.h @@ -38,7 +38,7 @@ class ConnectionHandler { ConnectionHandler(tcp::socket&& socket, std::shared_ptr tts_model) : socket_(std::move(socket)), tts_model_(std::move(tts_model)) {} void operator()(); - http::message_generator handle_request(const std::string& wav_path); + http::message_generator HandleRequest(char* wav_data, int data_size); private: tcp::socket socket_; diff --git a/runtime/core/model/tts_model.cc b/runtime/core/model/tts_model.cc index bce430d..368aa06 100644 --- a/runtime/core/model/tts_model.cc +++ b/runtime/core/model/tts_model.cc @@ -26,12 +26,13 @@ namespace wetts { TtsModel::TtsModel(const std::string& model_path, const std::string& speaker2id, - const std::string& phone2id, + const std::string& phone2id, const int sampling_rate, std::shared_ptr tn, std::shared_ptr g2p_prosody) : OnnxModel(model_path), tn_(std::move(tn)), g2p_prosody_(std::move(g2p_prosody)) { + sampling_rate_ = sampling_rate; ReadTableFile(phone2id, &phone2id_); ReadTableFile(speaker2id, &speaker2id_); } diff --git a/runtime/core/model/tts_model.h b/runtime/core/model/tts_model.h index 16856ff..714c43e 100644 --- a/runtime/core/model/tts_model.h +++ b/runtime/core/model/tts_model.h @@ -30,7 +30,9 @@ namespace wetts { class TtsModel : public OnnxModel { public: explicit TtsModel(const std::string& model_path, - const std::string& speaker2id, const std::string& phone2id, + const std::string& speaker2id, + const std::string& phone2id, + const int sampling_rate, std::shared_ptr processor, std::shared_ptr g2p_prosody); void Forward(const std::vector& phonemes, const int sid, @@ -38,6 +40,7 @@ class TtsModel : public OnnxModel { void Synthesis(const std::string& text, const int sid, std::vector* audio); int GetSid(const std::string& name); + int sampling_rate() const { return sampling_rate_; } private: int sampling_rate_; diff --git a/runtime/core/utils/timer.h b/runtime/core/utils/timer.h new file mode 100644 index 0000000..4a626c6 --- /dev/null +++ b/runtime/core/utils/timer.h @@ -0,0 +1,39 @@ +// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_TIMER_H_ +#define UTILS_TIMER_H_ + +#include + +namespace wetts { + +class Timer { + public: + Timer() : time_start_(std::chrono::steady_clock::now()) {} + void Reset() { time_start_ = std::chrono::steady_clock::now(); } + // return int in milliseconds + int Elapsed() const { + auto time_now = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(time_now - + time_start_) + .count(); + } + + private: + std::chrono::time_point time_start_; +}; +} // namespace wetts + +#endif // UTILS_TIMER_H_