diff --git a/examples/KWS_ASR/KWS_ASR.ino b/examples/KWS_ASR/KWS_ASR.ino index 0633850..93d4506 100644 --- a/examples/KWS_ASR/KWS_ASR.ino +++ b/examples/KWS_ASR/KWS_ASR.ino @@ -54,7 +54,9 @@ void setup() /* Setup ASR module and save returned work id */ M5.Display.printf(">> Setup asr..\n"); - asr_work_id = module_llm.asr.setup(); + m5_module_llm::ApiAsrSetupConfig_t asr_config; + asr_config.input = {"sys.pcm", kws_work_id}; + asr_work_id = module_llm.asr.setup(asr_config); M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str()); } diff --git a/src/M5ModuleLLM.h b/src/M5ModuleLLM.h index 9b2a113..31f2a72 100644 --- a/src/M5ModuleLLM.h +++ b/src/M5ModuleLLM.h @@ -11,6 +11,7 @@ #include "api/api_llm.h" #include "api/api_audio.h" #include "api/api_tts.h" +#include "api/api_melotts.h" #include "api/api_kws.h" #include "api/api_asr.h" @@ -63,6 +64,12 @@ class M5ModuleLLM { */ m5_module_llm::ApiTts tts; + /** + * @brief MELOTTS module api set + * + */ + m5_module_llm::ApiMelotts melotts; + /** * @brief KWS module api set * @@ -163,6 +170,7 @@ class M5ModuleLLM_VoiceAssistant { String asr; String llm; String tts; + String melotts; }; WorkId_t _work_id; diff --git a/src/api/api_asr.cpp b/src/api/api_asr.cpp index 85283eb..43a86e9 100644 --- a/src/api/api_asr.cpp +++ b/src/api/api_asr.cpp @@ -23,12 +23,14 @@ String ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id) doc["object"] = "asr.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; - doc["data"]["rule1"] = config.rule1; - doc["data"]["rule2"] = config.rule2; - doc["data"]["rule3"] = config.rule3; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["rule1"] = config.rule1; + doc["data"]["rule2"] = config.rule2; + doc["data"]["rule3"] = config.rule3; serializeJson(doc, cmd); } diff --git a/src/api/api_asr.h b/src/api/api_asr.h index 87499f8..7924850 100644 --- a/src/api/api_asr.h +++ b/src/api/api_asr.h @@ -10,14 +10,13 @@ namespace m5_module_llm { struct ApiAsrSetupConfig_t { - String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17"; - String response_format = "asr.utf-8.stream"; - String input = "sys.pcm"; - bool enoutput = true; - bool enkws = true; - float rule1 = 2.4; - float rule2 = 1.2; - float rule3 = 30.0; + String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17"; + String response_format = "asr.utf-8.stream"; + std::vector input = {"sys.pcm"}; + bool enoutput = true; + float rule1 = 2.4; + float rule2 = 1.2; + float rule3 = 30.0; }; class ApiAsr { diff --git a/src/api/api_kws.cpp b/src/api/api_kws.cpp index 47f2d32..1431732 100644 --- a/src/api/api_kws.cpp +++ b/src/api/api_kws.cpp @@ -23,9 +23,12 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id) doc["object"] = "kws.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["kws"] = config.kws; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["kws"] = config.kws; serializeJson(doc, cmd); } diff --git a/src/api/api_kws.h b/src/api/api_kws.h index 2e09553..b1e21e4 100644 --- a/src/api/api_kws.h +++ b/src/api/api_kws.h @@ -10,11 +10,11 @@ namespace m5_module_llm { struct ApiKwsSetupConfig_t { - String kws = "HELLO"; - String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"; - String response_format = "kws.bool"; - String input = "sys.pcm"; - bool enoutput = true; + String kws = "HELLO"; + String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"; + String response_format = "kws.bool"; + std::vector input = {"sys.pcm"}; + bool enoutput = true; }; class ApiKws { diff --git a/src/api/api_llm.cpp b/src/api/api_llm.cpp index e2f54ac..af5c539 100644 --- a/src/api/api_llm.cpp +++ b/src/api/api_llm.cpp @@ -23,11 +23,13 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id) doc["object"] = "llm.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; - doc["data"]["max_token_len"] = config.max_token_len; - doc["data"]["prompt"] = config.prompt; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["max_token_len"] = config.max_token_len; + doc["data"]["prompt"] = config.prompt; serializeJson(doc, cmd); } diff --git a/src/api/api_llm.h b/src/api/api_llm.h index 32c7eef..28d57d4 100644 --- a/src/api/api_llm.h +++ b/src/api/api_llm.h @@ -11,12 +11,12 @@ namespace m5_module_llm { struct ApiLlmSetupConfig_t { String prompt; - String model = "qwen2.5-0.5b"; - String response_format = "llm.utf-8.stream"; - String input = "llm.utf-8.stream"; - bool enoutput = true; - bool enkws = true; - int max_token_len = 127; + String model = "qwen2.5-0.5B-prefill-20e"; + String response_format = "llm.utf-8.stream"; + std::vector input = {"llm.utf-8.stream"}; + bool enoutput = true; + int max_token_len = 127; + // int max_token_len = 512; }; class ApiLlm { diff --git a/src/api/api_melotts.cpp b/src/api/api_melotts.cpp new file mode 100644 index 0000000..9fafe4b --- /dev/null +++ b/src/api/api_melotts.cpp @@ -0,0 +1,75 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include "api_melotts.h" + +using namespace m5_module_llm; + +void ApiMelotts::init(ModuleMsg* moduleMsg) +{ + _module_msg = moduleMsg; +} + +String ApiMelotts::setup(ApiMelottsSetupConfig_t config, String request_id) +{ + String cmd; + { + JsonDocument doc; + doc["request_id"] = request_id; + doc["work_id"] = "melotts"; + doc["action"] = "setup"; + doc["object"] = "melotts.setup"; + doc["data"]["model"] = config.model; + doc["data"]["response_format"] = config.response_format; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enaudio"] = config.enaudio; + serializeJson(doc, cmd); + } + + String work_id; + _module_msg->sendCmdAndWaitToTakeMsg( + cmd.c_str(), request_id, + [&work_id](ResponseMsg_t& msg) { + // Copy work id + work_id = msg.work_id; + }, + 10000); + return work_id; +} + +int ApiMelotts::inference(String work_id, String input, uint32_t timeout, String request_id) +{ + String cmd; + { + JsonDocument doc; + doc["request_id"] = request_id; + doc["work_id"] = work_id; + doc["action"] = "inference"; + doc["object"] = "melotts.utf-8.stream"; + doc["data"]["delta"] = input; + doc["data"]["index"] = 0; + doc["data"]["finish"] = true; + serializeJson(doc, cmd); + } + + if (timeout == 0) { + _module_msg->sendCmd(cmd.c_str()); + return MODULE_LLM_OK; + } + + int ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; + _module_msg->sendCmdAndWaitToTakeMsg( + cmd.c_str(), request_id, + [&ret](ResponseMsg_t& msg) { + // Copy error code + ret = msg.error.code; + }, + timeout); + return ret; +} diff --git a/src/api/api_melotts.h b/src/api/api_melotts.h new file mode 100644 index 0000000..1550b01 --- /dev/null +++ b/src/api/api_melotts.h @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#pragma once +#include "../utils/msg.h" +#include + +namespace m5_module_llm { + +struct ApiMelottsSetupConfig_t { + String model = "melotts_zh-cn"; + String response_format = "sys.pcm"; + std::vector input = {"sys.pcm"}; + bool enoutput = false; + bool enaudio = true; +}; + +class ApiMelotts { +public: + void init(ModuleMsg* moduleMsg); + + /** + * @brief Setup module TTS, return TTS work_id + * + * @param config + * @param request_id + * @return String + */ + String setup(ApiMelottsSetupConfig_t config = ApiMelottsSetupConfig_t(), String request_id = "tts_setup"); + + /** + * @brief Inference input data by TTS module + * + * @param work_id + * @param input + * @param timeout wait response timeout, default 0 (do not wait response) + * @param request_id + * @return int + */ + int inference(String work_id, String input, uint32_t timeout = 0, String request_id = "tts_inference"); + +private: + ModuleMsg* _module_msg = nullptr; +}; + +} // namespace m5_module_llm diff --git a/src/api/api_tts.cpp b/src/api/api_tts.cpp index 5dd8877..5cfa2e6 100644 --- a/src/api/api_tts.cpp +++ b/src/api/api_tts.cpp @@ -23,9 +23,12 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id) doc["object"] = "tts.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); } @@ -50,7 +53,7 @@ int ApiTts::inference(String work_id, String input, uint32_t timeout, String req doc["action"] = "inference"; doc["object"] = "tts.utf-8.stream"; doc["data"]["delta"] = input; - doc["data"]["index"] = 1; + doc["data"]["index"] = 0; doc["data"]["finish"] = true; serializeJson(doc, cmd); } diff --git a/src/api/api_tts.h b/src/api/api_tts.h index c3dd1b9..5ad601a 100644 --- a/src/api/api_tts.h +++ b/src/api/api_tts.h @@ -10,11 +10,11 @@ namespace m5_module_llm { struct ApiTtsSetupConfig_t { - String model = "single_speaker_english_fast"; - String response_format = "tts.base64.wav"; - String input = "tts.utf-8.stream"; - bool enoutput = true; - bool enkws = true; + String model = "single_speaker_english_fast"; + String response_format = "sys.pcm"; + std::vector input = {"sys.pcm"}; + bool enoutput = false; + bool enaudio = true; }; class ApiTts { diff --git a/src/presets/voice_assistant.cpp b/src/presets/voice_assistant.cpp index d34cc0d..7b1913c 100644 --- a/src/presets/voice_assistant.cpp +++ b/src/presets/voice_assistant.cpp @@ -37,7 +37,11 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) } _debug("setup module asr.."); - _work_id.asr = _m5_module_llm->asr.setup(); + { + ApiAsrSetupConfig_t config; + config.input = {"sys.pcm", _work_id.kws}; + _work_id.asr = _m5_module_llm->asr.setup(config); + } if (_work_id.asr.isEmpty()) { return MODULE_LLM_ERROR_NONE; } @@ -45,7 +49,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) _debug("setup module llm.."); { ApiLlmSetupConfig_t config; - config.input = _work_id.asr; + config.input = {_work_id.asr, _work_id.kws}; config.prompt = prompt; _work_id.llm = _m5_module_llm->llm.setup(config); } @@ -56,7 +60,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) _debug("setup module tts.."); { ApiTtsSetupConfig_t config; - config.input = _work_id.llm; + config.input = {_work_id.llm, _work_id.kws}; _work_id.tts = _m5_module_llm->tts.setup(config); } if (_work_id.tts.isEmpty()) {