Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[http] Update http server #189

Merged
merged 1 commit into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions runtime/core/bin/http_server_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,69 @@
#include "glog/logging.h"

#include "http/http_server.h"
#include "utils/log.h"
#include "processor/wetext_processor.h"

#include "frontend/g2p_en.h"
#include "frontend/g2p_prosody.h"
#include "frontend/wav.h"
#include "model/tts_model.h"
#include "utils/string.h"

// Flags
DEFINE_string(frontend_flags, "", "frontend flags file");
DEFINE_string(vits_flags, "", "vits flags file");

// Text Normalization
DEFINE_string(tagger, "", "tagger fst file");
DEFINE_string(verbalizer, "", "verbalizer fst file");

// Tokenizer
DEFINE_string(vocab, "", "tokenizer vocab file");

// G2P for English
DEFINE_string(cmudict, "", "cmudict for english words");
DEFINE_string(g2p_en_model, "", "english g2p fst model for oov");
DEFINE_string(g2p_en_sym, "", "english g2p symbol table for oov");

// G2P for Chinese
DEFINE_string(char2pinyin, "", "chinese character to pinyin");
DEFINE_string(pinyin2id, "", "pinyin to id");
DEFINE_string(pinyin2phones, "", "pinyin to phones");
DEFINE_string(g2p_prosody_model, "", "g2p prosody model file");

// VITS
DEFINE_string(speaker2id, "", "speaker to id");
DEFINE_string(phone2id, "", "phone to id");
DEFINE_string(vits_model, "", "e2e tts model file");
DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm");


// port
DEFINE_int32(port, 10086, "http listening port");

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
gflags::ReadFromFlagsFile(FLAGS_frontend_flags, "", false);
gflags::ReadFromFlagsFile(FLAGS_vits_flags, "", false);

auto tn = std::make_shared<wetext::Processor>(FLAGS_tagger, FLAGS_verbalizer);

bool has_en = !FLAGS_g2p_en_model.empty() && !FLAGS_g2p_en_sym.empty() &&
!FLAGS_g2p_en_sym.empty();
std::shared_ptr<wetts::G2pEn> g2p_en =
has_en ? std::make_shared<wetts::G2pEn>(FLAGS_cmudict, FLAGS_g2p_en_model,
FLAGS_g2p_en_sym)
: nullptr;

auto g2p_prosody = std::make_shared<wetts::G2pProsody>(
FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
FLAGS_pinyin2phones);
auto tts_model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, tn, g2p_prosody);
FLAGS_pinyin2phones, g2p_en);
auto model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
tn, g2p_prosody);

wetts::HttpServer server(FLAGS_port, tts_model);
wetts::HttpServer server(FLAGS_port, model);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
Expand Down
3 changes: 2 additions & 1 deletion runtime/core/bin/tts_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ int main(int argc, char* argv[]) {
FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
FLAGS_pinyin2phones, g2p_en);
auto model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, tn, g2p_prosody);
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
tn, g2p_prosody);

std::vector<float> audio;
int sid = model->GetSid(FLAGS_sname);
Expand Down
61 changes: 39 additions & 22 deletions runtime/core/http/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,25 @@

#include "frontend/wav.h"
#include "utils/string.h"
#include "utils/timer.h"

namespace wetts {

namespace urls = boost::urls;
namespace uuids = boost::uuids;

http::message_generator ConnectionHandler::handle_request(
const std::string& wav_path) {
// Attempt to open the file
http::message_generator ConnectionHandler::HandleRequest(
char* wav_data, int data_size) {
beast::error_code ec;
http::file_body::value_type body;
body.open(wav_path.c_str(), beast::file_mode::scan, ec);

// Cache the size since we need it after the move
auto const size = body.size();
// Respond to GET request
http::response<http::file_body> res{
std::piecewise_construct, std::make_tuple(std::move(body)),
std::make_tuple(http::status::ok, request_.version())};
http::response<http::buffer_body> res;
res.result(http::status::ok);
res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "audio/wav");
res.content_length(size);
res.keep_alive(request_.keep_alive());
res.body().data = wav_data;
res.body().size = data_size;
res.body().more = false;
res.prepare_payload();
return res;
}

Expand Down Expand Up @@ -80,16 +76,37 @@ void ConnectionHandler::operator()() {
std::string name = (*params.find("name")).value;
int sid = tts_model_->GetSid(name);
// 2. Synthesis audio from text
std::vector<float> audio;
tts_model_->Synthesis(text, sid, &audio);
wetts::WavWriter wav_writer(audio.data(), audio.size(), 1, 22050, 16);
// 3. Write samples to file named uuid.wav
std::string wav_path =
uuids::to_string(uuids::random_generator()()) + ".wav";
wav_writer.Write(wav_path);

int sample_rate = tts_model_->sampling_rate();
int num_channels = 1;
int bits_per_sample = 16;
LOG(INFO) << "Sample rate: " << sample_rate;
LOG(INFO) << "Num of channels: " << num_channels;
LOG(INFO) << "Bit per sample: " << bits_per_sample;
int extract_time = 0;
wetts::Timer timer;
std::vector<float> pcm;
tts_model_->Synthesis(text, sid, &pcm);
int pcm_size = pcm.size();
extract_time = timer.Elapsed();
LOG(INFO) << "TTS pcm duration: "
<< pcm_size * 1000 / num_channels / sample_rate << "ms";
LOG(INFO) << "Cost time: " << static_cast<float>(extract_time) << "ms";
// 3. Convert pcm to wav
std::vector<int16_t> audio(pcm_size);
for (int i = 0; i < pcm_size; ++i) {
audio[i] = static_cast<int16_t>(pcm[i]);
}
int audio_size = pcm_size * sizeof(int16_t);
int data_size = audio_size + 44;
WavHeader header(pcm_size, num_channels, sample_rate, bits_per_sample);
std::vector<char> wav_data;
wav_data.insert(wav_data.end(), reinterpret_cast<char*>(&header),
reinterpret_cast<char*>(&header) + 44);
wav_data.insert(wav_data.end(), reinterpret_cast<char*>(audio.data()),
reinterpret_cast<char*>(audio.data()) + audio_size);
// Handle request
http::message_generator msg = handle_request(wav_path);
http::message_generator msg =
HandleRequest(wav_data.data(), data_size);
// Determine if we should close the connection
bool keep_alive = msg.keep_alive();
// Send the response
Expand Down
2 changes: 1 addition & 1 deletion runtime/core/http/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConnectionHandler {
ConnectionHandler(tcp::socket&& socket, std::shared_ptr<TtsModel> tts_model)
: socket_(std::move(socket)), tts_model_(std::move(tts_model)) {}
void operator()();
http::message_generator handle_request(const std::string& wav_path);
http::message_generator HandleRequest(char* wav_data, int data_size);

private:
tcp::socket socket_;
Expand Down
3 changes: 2 additions & 1 deletion runtime/core/model/tts_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
namespace wetts {

TtsModel::TtsModel(const std::string& model_path, const std::string& speaker2id,
const std::string& phone2id,
const std::string& phone2id, const int sampling_rate,
std::shared_ptr<wetext::Processor> tn,
std::shared_ptr<G2pProsody> g2p_prosody)
: OnnxModel(model_path),
tn_(std::move(tn)),
g2p_prosody_(std::move(g2p_prosody)) {
sampling_rate_ = sampling_rate;
ReadTableFile(phone2id, &phone2id_);
ReadTableFile(speaker2id, &speaker2id_);
}
Expand Down
5 changes: 4 additions & 1 deletion runtime/core/model/tts_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ namespace wetts {
class TtsModel : public OnnxModel {
public:
explicit TtsModel(const std::string& model_path,
const std::string& speaker2id, const std::string& phone2id,
const std::string& speaker2id,
const std::string& phone2id,
const int sampling_rate,
std::shared_ptr<wetext::Processor> processor,
std::shared_ptr<G2pProsody> g2p_prosody);
void Forward(const std::vector<int64_t>& phonemes, const int sid,
std::vector<float>* audio);
void Synthesis(const std::string& text, const int sid,
std::vector<float>* audio);
int GetSid(const std::string& name);
int sampling_rate() const { return sampling_rate_; }

private:
int sampling_rate_;
Expand Down
39 changes: 39 additions & 0 deletions runtime/core/utils/timer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef UTILS_TIMER_H_
#define UTILS_TIMER_H_

#include <chrono>

namespace wetts {

class Timer {
public:
Timer() : time_start_(std::chrono::steady_clock::now()) {}
void Reset() { time_start_ = std::chrono::steady_clock::now(); }
// return int in milliseconds
int Elapsed() const {
auto time_now = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(time_now -
time_start_)
.count();
}

private:
std::chrono::time_point<std::chrono::steady_clock> time_start_;
};
} // namespace wetts

#endif // UTILS_TIMER_H_
Loading