diff --git a/examples/KWS_ASR/KWS_ASR.ino b/examples/KWS_ASR/KWS_ASR.ino index 0633850..68792b7 100644 --- a/examples/KWS_ASR/KWS_ASR.ino +++ b/examples/KWS_ASR/KWS_ASR.ino @@ -54,7 +54,9 @@ void setup() /* Setup ASR module and save returned work id */ M5.Display.printf(">> Setup asr..\n"); - asr_work_id = module_llm.asr.setup(); + m5_module_llm::ApiAsrSetupConfig_t asr_config; + asr_config.input = {"sys.pcm", kws_work_id}; + asr_work_id = module_llm.asr.setup(asr_config); M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str()); } diff --git a/src/M5ModuleLLM.h b/src/M5ModuleLLM.h index 9b2a113..8dec292 100644 --- a/src/M5ModuleLLM.h +++ b/src/M5ModuleLLM.h @@ -11,6 +11,7 @@ #include "api/api_llm.h" #include "api/api_audio.h" #include "api/api_tts.h" +#include "api/api_melotts.h" #include "api/api_kws.h" #include "api/api_asr.h" @@ -63,6 +64,12 @@ class M5ModuleLLM { */ m5_module_llm::ApiTts tts; + /** + * @brief MELOTTS module api set + * + */ + m5_module_llm::ApiMelotts melotts; + /** * @brief KWS module api set * @@ -163,6 +170,7 @@ class M5ModuleLLM_VoiceAssistant { String asr; String llm; String tts; + String melotts; }; WorkId_t _work_id; diff --git a/src/api/api_asr.cpp b/src/api/api_asr.cpp index 85283eb..27edd2a 100644 --- a/src/api/api_asr.cpp +++ b/src/api/api_asr.cpp @@ -23,9 +23,12 @@ String ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id) doc["object"] = "asr.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; doc["data"]["rule1"] = config.rule1; doc["data"]["rule2"] = config.rule2; doc["data"]["rule3"] = config.rule3; diff --git a/src/api/api_asr.h b/src/api/api_asr.h index 87499f8..7924850 100644 --- a/src/api/api_asr.h +++ b/src/api/api_asr.h @@ -10,14 +10,13 @@ namespace m5_module_llm { struct ApiAsrSetupConfig_t { - String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17"; - String response_format = "asr.utf-8.stream"; - String input = "sys.pcm"; - bool enoutput = true; - bool enkws = true; - float rule1 = 2.4; - float rule2 = 1.2; - float rule3 = 30.0; + String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17"; + String response_format = "asr.utf-8.stream"; + std::vector input = {"sys.pcm"}; + bool enoutput = true; + float rule1 = 2.4; + float rule2 = 1.2; + float rule3 = 30.0; }; class ApiAsr { diff --git a/src/api/api_kws.cpp b/src/api/api_kws.cpp index 47f2d32..4f82361 100644 --- a/src/api/api_kws.cpp +++ b/src/api/api_kws.cpp @@ -23,7 +23,11 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id) doc["object"] = "kws.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; doc["data"]["kws"] = config.kws; serializeJson(doc, cmd); diff --git a/src/api/api_kws.h b/src/api/api_kws.h index 2e09553..b1e21e4 100644 --- a/src/api/api_kws.h +++ b/src/api/api_kws.h @@ -10,11 +10,11 @@ namespace m5_module_llm { struct ApiKwsSetupConfig_t { - String kws = "HELLO"; - String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"; - String response_format = "kws.bool"; - String input = "sys.pcm"; - bool enoutput = true; + String kws = "HELLO"; + String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"; + String response_format = "kws.bool"; + std::vector input = {"sys.pcm"}; + bool enoutput = true; }; class ApiKws { diff --git a/src/api/api_llm.cpp b/src/api/api_llm.cpp index e2f54ac..9005557 100644 --- a/src/api/api_llm.cpp +++ b/src/api/api_llm.cpp @@ -23,9 +23,12 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id) doc["object"] = "llm.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; doc["data"]["max_token_len"] = config.max_token_len; doc["data"]["prompt"] = config.prompt; serializeJson(doc, cmd); diff --git a/src/api/api_llm.h b/src/api/api_llm.h index 32c7eef..28d57d4 100644 --- a/src/api/api_llm.h +++ b/src/api/api_llm.h @@ -11,12 +11,12 @@ namespace m5_module_llm { struct ApiLlmSetupConfig_t { String prompt; - String model = "qwen2.5-0.5b"; - String response_format = "llm.utf-8.stream"; - String input = "llm.utf-8.stream"; - bool enoutput = true; - bool enkws = true; - int max_token_len = 127; + String model = "qwen2.5-0.5B-prefill-20e"; + String response_format = "llm.utf-8.stream"; + std::vector input = {"llm.utf-8.stream"}; + bool enoutput = true; + int max_token_len = 127; + // int max_token_len = 512; }; class ApiLlm { diff --git a/src/api/api_tts.cpp b/src/api/api_tts.cpp index 5dd8877..4b837d1 100644 --- a/src/api/api_tts.cpp +++ b/src/api/api_tts.cpp @@ -23,9 +23,13 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id) doc["object"] = "tts.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - doc["data"]["input"] = config.input; + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) + { + inputArray.add(str); + } doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enkws"] = config.enkws; + doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); } @@ -50,7 +54,7 @@ int ApiTts::inference(String work_id, String input, uint32_t timeout, String req doc["action"] = "inference"; doc["object"] = "tts.utf-8.stream"; doc["data"]["delta"] = input; - doc["data"]["index"] = 1; + doc["data"]["index"] = 0; doc["data"]["finish"] = true; serializeJson(doc, cmd); } diff --git a/src/api/api_tts.h b/src/api/api_tts.h index c3dd1b9..5ad601a 100644 --- a/src/api/api_tts.h +++ b/src/api/api_tts.h @@ -10,11 +10,11 @@ namespace m5_module_llm { struct ApiTtsSetupConfig_t { - String model = "single_speaker_english_fast"; - String response_format = "tts.base64.wav"; - String input = "tts.utf-8.stream"; - bool enoutput = true; - bool enkws = true; + String model = "single_speaker_english_fast"; + String response_format = "sys.pcm"; + std::vector input = {"sys.pcm"}; + bool enoutput = false; + bool enaudio = true; }; class ApiTts { diff --git a/src/presets/voice_assistant.cpp b/src/presets/voice_assistant.cpp index d34cc0d..7b1913c 100644 --- a/src/presets/voice_assistant.cpp +++ b/src/presets/voice_assistant.cpp @@ -37,7 +37,11 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) } _debug("setup module asr.."); - _work_id.asr = _m5_module_llm->asr.setup(); + { + ApiAsrSetupConfig_t config; + config.input = {"sys.pcm", _work_id.kws}; + _work_id.asr = _m5_module_llm->asr.setup(config); + } if (_work_id.asr.isEmpty()) { return MODULE_LLM_ERROR_NONE; } @@ -45,7 +49,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) _debug("setup module llm.."); { ApiLlmSetupConfig_t config; - config.input = _work_id.asr; + config.input = {_work_id.asr, _work_id.kws}; config.prompt = prompt; _work_id.llm = _m5_module_llm->llm.setup(config); } @@ -56,7 +60,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) _debug("setup module tts.."); { ApiTtsSetupConfig_t config; - config.input = _work_id.llm; + config.input = {_work_id.llm, _work_id.kws}; _work_id.tts = _m5_module_llm->tts.setup(config); } if (_work_id.tts.isEmpty()) {