From a9406ae5d811ee7ad6cd72334f2f7a7a63b375af Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 20 Nov 2024 09:54:24 +0800 Subject: [PATCH 1/4] update supports stackflow. --- examples/KWS_ASR/KWS_ASR.ino | 4 +++- src/M5ModuleLLM.h | 8 ++++++++ src/api/api_asr.cpp | 7 +++++-- src/api/api_asr.h | 15 +++++++-------- src/api/api_kws.cpp | 6 +++++- src/api/api_kws.h | 10 +++++----- src/api/api_llm.cpp | 7 +++++-- src/api/api_llm.h | 12 ++++++------ src/api/api_tts.cpp | 10 +++++++--- src/api/api_tts.h | 10 +++++----- src/presets/voice_assistant.cpp | 10 +++++++--- 11 files changed, 63 insertions(+), 36 deletions(-) diff --git a/examples/KWS_ASR/KWS_ASR.ino b/examples/KWS_ASR/KWS_ASR.ino index 0633850..68792b7 100644 --- a/examples/KWS_ASR/KWS_ASR.ino +++ b/examples/KWS_ASR/KWS_ASR.ino @@ -54,7 +54,9 @@ void setup() /* Setup ASR module and save returned work id */ M5.Display.printf(">> Setup asr..\n"); - asr_work_id = module_llm.asr.setup(); + m5_module_llm::ApiAsrSetupConfig_t asr_config; + asr_config.input = {"sys.pcm", kws_work_id}; + asr_work_id = module_llm.asr.setup(asr_config); M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str()); } diff --git a/src/M5ModuleLLM.h b/src/M5ModuleLLM.h index 9b2a113..8dec292 100644 --- a/src/M5ModuleLLM.h +++ b/src/M5ModuleLLM.h @@ -11,6 +11,7 @@ #include "api/api_llm.h" #include "api/api_audio.h" #include "api/api_tts.h" +#include "api/api_melotts.h" #include "api/api_kws.h" #include "api/api_asr.h" @@ -63,6 +64,12 @@ class M5ModuleLLM { */ m5_module_llm::ApiTts tts; + /** + * @brief MELOTTS module api set + * + */ + m5_module_llm::ApiMelotts melotts; + /** * @brief KWS module api set * @@ -163,6 +170,7 @@ class M5ModuleLLM_VoiceAssistant { String asr; String llm; String tts; + String melotts; }; WorkId_t _work_id; diff --git a/src/api/api_asr.cpp b/src/api/api_asr.cpp index 85283eb..27edd2a 100644 --- a/src/api/api_asr.cpp +++ b/src/api/api_asr.cpp @@ -23,9 +23,12 @@ String 
ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id) doc["object"] = "asr.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; doc["data"]["rule1"] = config.rule1; doc["data"]["rule2"] = config.rule2; doc["data"]["rule3"] = config.rule3; diff --git a/src/api/api_asr.h b/src/api/api_asr.h index 87499f8..7924850 100644 --- a/src/api/api_asr.h +++ b/src/api/api_asr.h @@ -10,14 +10,13 @@ namespace m5_module_llm { struct ApiAsrSetupConfig_t { - String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17"; - String response_format = "asr.utf-8.stream"; - String input = "sys.pcm"; - bool enoutput = true; - bool enkws = true; - float rule1 = 2.4; - float rule2 = 1.2; - float rule3 = 30.0; + String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17"; + String response_format = "asr.utf-8.stream"; + std::vector<String> input = {"sys.pcm"}; + bool enoutput = true; + float rule1 = 2.4; + float rule2 = 1.2; + float rule3 = 30.0; }; class ApiAsr { diff --git a/src/api/api_kws.cpp b/src/api/api_kws.cpp index 47f2d32..4f82361 100644 --- a/src/api/api_kws.cpp +++ b/src/api/api_kws.cpp @@ -23,7 +23,11 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id) doc["object"] = "kws.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; doc["data"]["kws"] = config.kws; serializeJson(doc, cmd); diff --git a/src/api/api_kws.h b/src/api/api_kws.h index 2e09553..b1e21e4 100644 --- a/src/api/api_kws.h +++ b/src/api/api_kws.h @@ -10,11
+10,11 @@ namespace m5_module_llm { struct ApiKwsSetupConfig_t { - String kws = "HELLO"; - String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"; - String response_format = "kws.bool"; - String input = "sys.pcm"; - bool enoutput = true; + String kws = "HELLO"; + String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"; + String response_format = "kws.bool"; + std::vector<String> input = {"sys.pcm"}; + bool enoutput = true; }; class ApiKws { diff --git a/src/api/api_llm.cpp b/src/api/api_llm.cpp index e2f54ac..9005557 100644 --- a/src/api/api_llm.cpp +++ b/src/api/api_llm.cpp @@ -23,9 +23,12 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id) doc["object"] = "llm.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; doc["data"]["max_token_len"] = config.max_token_len; doc["data"]["prompt"] = config.prompt; serializeJson(doc, cmd); diff --git a/src/api/api_llm.h b/src/api/api_llm.h index 32c7eef..28d57d4 100644 --- a/src/api/api_llm.h +++ b/src/api/api_llm.h @@ -11,12 +11,12 @@ namespace m5_module_llm { struct ApiLlmSetupConfig_t { String prompt; - String model = "qwen2.5-0.5b"; - String response_format = "llm.utf-8.stream"; - String input = "llm.utf-8.stream"; - bool enoutput = true; - bool enkws = true; - int max_token_len = 127; + String model = "qwen2.5-0.5B-prefill-20e"; + String response_format = "llm.utf-8.stream"; + std::vector<String> input = {"llm.utf-8.stream"}; + bool enoutput = true; + int max_token_len = 127; + // int max_token_len = 512; }; class ApiLlm { diff --git a/src/api/api_tts.cpp b/src/api/api_tts.cpp index 5dd8877..4b837d1 100644 --- a/src/api/api_tts.cpp +++ b/src/api/api_tts.cpp @@ -23,9 +23,13 @@ String
ApiTts::setup(ApiTtsSetupConfig_t config, String request_id) doc["object"] = "tts.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; + doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); } @@ -50,7 +54,7 @@ int ApiTts::inference(String work_id, String input, uint32_t timeout, String req doc["action"] = "inference"; doc["object"] = "tts.utf-8.stream"; doc["data"]["delta"] = input; - doc["data"]["index"] = 1; + doc["data"]["index"] = 0; doc["data"]["finish"] = true; serializeJson(doc, cmd); } diff --git a/src/api/api_tts.h b/src/api/api_tts.h index c3dd1b9..5ad601a 100644 --- a/src/api/api_tts.h +++ b/src/api/api_tts.h @@ -10,11 +10,11 @@ namespace m5_module_llm { struct ApiTtsSetupConfig_t { - String model = "single_speaker_english_fast"; - String response_format = "tts.base64.wav"; - String input = "tts.utf-8.stream"; - bool enoutput = true; - bool enkws = true; + String model = "single_speaker_english_fast"; + String response_format = "sys.pcm"; + std::vector<String> input = {"sys.pcm"}; + bool enoutput = false; + bool enaudio = true; }; class ApiTts { diff --git a/src/presets/voice_assistant.cpp b/src/presets/voice_assistant.cpp index d34cc0d..7b1913c 100644 --- a/src/presets/voice_assistant.cpp +++ b/src/presets/voice_assistant.cpp @@ -37,7 +37,11 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) } _debug("setup module asr.."); - _work_id.asr = _m5_module_llm->asr.setup(); + { + ApiAsrSetupConfig_t config; + config.input = {"sys.pcm", _work_id.kws}; + _work_id.asr = _m5_module_llm->asr.setup(config); + } if (_work_id.asr.isEmpty()) { return MODULE_LLM_ERROR_NONE; } @@ -45,7 +49,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String
wakeUpKeyword, String prompt) _debug("setup module llm.."); { ApiLlmSetupConfig_t config; - config.input = _work_id.asr; + config.input = {_work_id.asr, _work_id.kws}; config.prompt = prompt; _work_id.llm = _m5_module_llm->llm.setup(config); } @@ -56,7 +60,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) _debug("setup module tts.."); { ApiTtsSetupConfig_t config; - config.input = _work_id.llm; + config.input = {_work_id.llm, _work_id.kws}; _work_id.tts = _m5_module_llm->tts.setup(config); } if (_work_id.tts.isEmpty()) { From 4cb0d648d76fe9ada94ab03aa0492207e0c49be3 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 20 Nov 2024 10:38:21 +0800 Subject: [PATCH 2/4] add melotts --- examples/KWS_ASR/KWS_ASR.ino | 2 +- src/M5ModuleLLM.h | 6 +-- src/api/api_asr.cpp | 11 +++--- src/api/api_kws.cpp | 7 ++-- src/api/api_llm.cpp | 9 ++--- src/api/api_melotts.cpp | 75 ++++++++++++++++++++++++++++++++++++ src/api/api_melotts.h | 48 +++++++++++++++++++++++ src/api/api_sys.cpp | 3 +- src/api/api_tts.cpp | 7 ++-- 9 files changed, 143 insertions(+), 25 deletions(-) create mode 100644 src/api/api_melotts.cpp create mode 100644 src/api/api_melotts.h diff --git a/examples/KWS_ASR/KWS_ASR.ino b/examples/KWS_ASR/KWS_ASR.ino index 68792b7..93d4506 100644 --- a/examples/KWS_ASR/KWS_ASR.ino +++ b/examples/KWS_ASR/KWS_ASR.ino @@ -56,7 +56,7 @@ void setup() M5.Display.printf(">> Setup asr..\n"); m5_module_llm::ApiAsrSetupConfig_t asr_config; asr_config.input = {"sys.pcm", kws_work_id}; - asr_work_id = module_llm.asr.setup(asr_config); + asr_work_id = module_llm.asr.setup(asr_config); M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str()); } diff --git a/src/M5ModuleLLM.h b/src/M5ModuleLLM.h index 8dec292..31f2a72 100644 --- a/src/M5ModuleLLM.h +++ b/src/M5ModuleLLM.h @@ -65,9 +65,9 @@ class M5ModuleLLM { m5_module_llm::ApiTts tts; /** - * @brief MELOTTS module api set - * - */ + * @brief MELOTTS module api set + * + */ 
m5_module_llm::ApiMelotts melotts; /** diff --git a/src/api/api_asr.cpp b/src/api/api_asr.cpp index 27edd2a..43a86e9 100644 --- a/src/api/api_asr.cpp +++ b/src/api/api_asr.cpp @@ -24,14 +24,13 @@ String ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id) doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); - for (const String& str : config.input) - { + for (const String& str : config.input) { inputArray.add(str); } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["rule1"] = config.rule1; - doc["data"]["rule2"] = config.rule2; - doc["data"]["rule3"] = config.rule3; + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["rule1"] = config.rule1; + doc["data"]["rule2"] = config.rule2; + doc["data"]["rule3"] = config.rule3; serializeJson(doc, cmd); } diff --git a/src/api/api_kws.cpp b/src/api/api_kws.cpp index 4f82361..1431732 100644 --- a/src/api/api_kws.cpp +++ b/src/api/api_kws.cpp @@ -24,12 +24,11 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id) doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); - for (const String& str : config.input) - { + for (const String& str : config.input) { inputArray.add(str); } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["kws"] = config.kws; + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["kws"] = config.kws; serializeJson(doc, cmd); } diff --git a/src/api/api_llm.cpp b/src/api/api_llm.cpp index 9005557..af5c539 100644 --- a/src/api/api_llm.cpp +++ b/src/api/api_llm.cpp @@ -24,13 +24,12 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id) doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); - for (const String& str : config.input) - { + for (const String& str : config.input) {
inputArray.add(str); } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["max_token_len"] = config.max_token_len; - doc["data"]["prompt"] = config.prompt; + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["max_token_len"] = config.max_token_len; + doc["data"]["prompt"] = config.prompt; serializeJson(doc, cmd); } diff --git a/src/api/api_melotts.cpp b/src/api/api_melotts.cpp new file mode 100644 index 0000000..ad4a602 --- /dev/null +++ b/src/api/api_melotts.cpp @@ -0,0 +1,75 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include "api_melotts.h" + +using namespace m5_module_llm; + +void ApiMelotts::init(ModuleMsg* moduleMsg) +{ + _module_msg = moduleMsg; +} + +String ApiMelotts::setup(ApiMelottsSetupConfig_t config, String request_id) +{ + String cmd; + { + JsonDocument doc; + doc["request_id"] = request_id; + doc["work_id"] = "melotts"; + doc["action"] = "setup"; + doc["object"] = "melotts.setup"; + doc["data"]["model"] = config.model; + doc["data"]["response_format"] = config.response_format; + JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); + for (const String& str : config.input) { + inputArray.add(str); + } + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enoutput"] = config.enaudio; + serializeJson(doc, cmd); + } + + String work_id; + _module_msg->sendCmdAndWaitToTakeMsg( + cmd.c_str(), request_id, + [&work_id](ResponseMsg_t& msg) { + // Copy work id + work_id = msg.work_id; + }, + 10000); + return work_id; +} + +int ApiMelotts::inference(String work_id, String input, uint32_t timeout, String request_id) +{ + String cmd; + { + JsonDocument doc; + doc["request_id"] = request_id; + doc["work_id"] = work_id; + doc["action"] = "inference"; + doc["object"] = "melotts.utf-8.stream"; + doc["data"]["delta"] = input; + doc["data"]["index"] = 0; + doc["data"]["finish"] = true; + serializeJson(doc, cmd); + } + + if (timeout == 0) { + _module_msg->sendCmd(cmd.c_str()); + return
MODULE_LLM_OK; + } + + int ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; + _module_msg->sendCmdAndWaitToTakeMsg( + cmd.c_str(), request_id, + [&ret](ResponseMsg_t& msg) { + // Copy error code + ret = msg.error.code; + }, + timeout); + return ret; +} diff --git a/src/api/api_melotts.h b/src/api/api_melotts.h new file mode 100644 index 0000000..04aa4dd --- /dev/null +++ b/src/api/api_melotts.h @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#pragma once +#include "../utils/msg.h" +#include + +namespace m5_module_llm { + +struct ApiMelottsSetupConfig_t { + String model = ""; + String response_format = "tts.base64.wav"; + std::vector<String> input = {"sys.pcm"}; + bool enoutput = false; + bool enaudio = false; +}; + +class ApiMelotts { +public: + void init(ModuleMsg* moduleMsg); + + /** + * @brief Setup module TTS, return TTS work_id + * + * @param config + * @param request_id + * @return String + */ + String setup(ApiMelottsSetupConfig_t config = ApiMelottsSetupConfig_t(), String request_id = "tts_setup"); + + /** + * @brief Inference input data by TTS module + * + * @param work_id + * @param input + * @param timeout wait response timeout, default 0 (do not wait response) + * @param request_id + * @return int + */ + int inference(String work_id, String input, uint32_t timeout = 0, String request_id = "tts_inference"); + +private: + ModuleMsg* _module_msg = nullptr; +}; + +} // namespace m5_module_llm diff --git a/src/api/api_sys.cpp b/src/api/api_sys.cpp index 3f90b4c..bf2491c 100644 --- a/src/api/api_sys.cpp +++ b/src/api/api_sys.cpp @@ -44,8 +44,7 @@ int ApiSys::reset(bool waitResetFinish) if (waitResetFinish) { ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; _module_msg->responseMsgList.clear(); - _module_msg->waitAndTakeMsg( - "0", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 15000); + _module_msg->waitAndTakeMsg("0", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 15000); } return ret; } diff
--git a/src/api/api_tts.cpp b/src/api/api_tts.cpp index 4b837d1..5cfa2e6 100644 --- a/src/api/api_tts.cpp +++ b/src/api/api_tts.cpp @@ -24,12 +24,11 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id) doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; JsonArray inputArray = doc["data"]["input"].to<JsonArray>(); - for (const String& str : config.input) - { + for (const String& str : config.input) { inputArray.add(str); } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enaudio"] = config.enaudio; + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); } From 7de1d92930f5daa4a4f4ad8bbba21b14468bfb47 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 20 Nov 2024 10:45:59 +0800 Subject: [PATCH 3/4] update melotts add model name. --- src/api/api_melotts.cpp | 2 +- src/api/api_melotts.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/api/api_melotts.cpp b/src/api/api_melotts.cpp index ad4a602..9fafe4b 100644 --- a/src/api/api_melotts.cpp +++ b/src/api/api_melotts.cpp @@ -28,7 +28,7 @@ String ApiMelotts::setup(ApiMelottsSetupConfig_t config, String request_id) inputArray.add(str); } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enoutput"] = config.enaudio; + doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); } diff --git a/src/api/api_melotts.h b/src/api/api_melotts.h index 04aa4dd..1550b01 100644 --- a/src/api/api_melotts.h +++ b/src/api/api_melotts.h @@ -10,11 +10,11 @@ namespace m5_module_llm { struct ApiMelottsSetupConfig_t { - String model = ""; - String response_format = "tts.base64.wav"; + String model = "melotts_zh-cn"; + String response_format = "sys.pcm"; std::vector<String> input = {"sys.pcm"}; bool enoutput = false; - bool enaudio = false; + bool enaudio = true; }; class ApiMelotts { From 1ad17b4ab0a9086bb1396124f043ab1525901b10 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 20 Nov 2024
11:06:22 +0800 Subject: [PATCH 4/4] fix format --- src/api/api_sys.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/api/api_sys.cpp b/src/api/api_sys.cpp index bf2491c..3f90b4c 100644 --- a/src/api/api_sys.cpp +++ b/src/api/api_sys.cpp @@ -44,7 +44,8 @@ int ApiSys::reset(bool waitResetFinish) if (waitResetFinish) { ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; _module_msg->responseMsgList.clear(); - _module_msg->waitAndTakeMsg("0", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 15000); + _module_msg->waitAndTakeMsg( + "0", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 15000); } return ret; }