diff --git a/src/M5ModuleLLM.cpp b/src/M5ModuleLLM.cpp index b7f4699..f004fe8 100644 --- a/src/M5ModuleLLM.cpp +++ b/src/M5ModuleLLM.cpp @@ -23,6 +23,7 @@ bool M5ModuleLLM::begin(Stream* serialPort) bool M5ModuleLLM::checkConnection() { + llm_version = (sys.version() == MODULE_LLM_OK); return sys.ping() == MODULE_LLM_OK; } diff --git a/src/M5ModuleLLM.h b/src/M5ModuleLLM.h index eaed899..abbc578 100644 --- a/src/M5ModuleLLM.h +++ b/src/M5ModuleLLM.h @@ -15,6 +15,7 @@ #include "api/api_kws.h" #include "api/api_asr.h" #include "api/api_yolo.h" +#include "api/api_version.h" class M5ModuleLLM { public: diff --git a/src/api/api_asr.cpp b/src/api/api_asr.cpp index 43a86e9..633cad1 100644 --- a/src/api/api_asr.cpp +++ b/src/api/api_asr.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: MIT */ #include "api_asr.h" +#include "api_version.h" using namespace m5_module_llm; @@ -23,14 +24,19 @@ String ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id) doc["object"] = "asr.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - JsonArray inputArray = doc["data"]["input"].to(); - for (const String& str : config.input) { - inputArray.add(str); + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enkws"] = config.enkws; + doc["data"]["rule1"] = config.rule1; + doc["data"]["rule2"] = config.rule2; + doc["data"]["rule3"] = config.rule3; + if (!llm_version) { + doc["data"]["input"] = config.input[0]; + } else { + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["rule1"] = config.rule1; - doc["data"]["rule2"] = config.rule2; - doc["data"]["rule3"] = config.rule3; serializeJson(doc, cmd); } diff --git a/src/api/api_asr.h b/src/api/api_asr.h index 7924850..9b2d65d 100644 --- a/src/api/api_asr.h +++ b/src/api/api_asr.h @@ -14,6 +14,7 @@ struct ApiAsrSetupConfig_t { String response_format = "asr.utf-8.stream"; std::vector input = {"sys.pcm"}; bool enoutput = true; + bool enkws = true; float rule1 = 2.4; float rule2 = 1.2; float rule3 = 30.0; diff --git a/src/api/api_kws.cpp b/src/api/api_kws.cpp index 1431732..46db507 100644 --- a/src/api/api_kws.cpp +++ b/src/api/api_kws.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: MIT */ #include "api_kws.h" +#include "api_version.h" using namespace m5_module_llm; @@ -23,12 +24,16 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id) doc["object"] = "kws.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - JsonArray inputArray = doc["data"]["input"].to(); - for (const String& str : config.input) { - inputArray.add(str); + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["kws"] = config.kws; + if (!llm_version) { + doc["data"]["input"] = config.input[0]; + } else { + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["kws"] = config.kws; serializeJson(doc, cmd); } diff --git a/src/api/api_llm.cpp b/src/api/api_llm.cpp index af5c539..57b4ee9 100644 --- a/src/api/api_llm.cpp +++ b/src/api/api_llm.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: MIT */ #include "api_llm.h" +#include "api_version.h" using namespace m5_module_llm; @@ -23,13 +24,19 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id) doc["object"] = "llm.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - JsonArray inputArray = doc["data"]["input"].to(); - for (const String& str : config.input) { - inputArray.add(str); + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enkws"] = config.enkws; + doc["data"]["max_token_len"] = config.max_token_len; + doc["data"]["prompt"] = config.prompt; + if (!llm_version) { + doc["data"]["model"] = "qwen2.5-0.5b"; + doc["data"]["input"] = config.input[0]; + } else { + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["max_token_len"] = config.max_token_len; - doc["data"]["prompt"] = config.prompt; serializeJson(doc, cmd); } diff --git a/src/api/api_llm.h b/src/api/api_llm.h index 28d57d4..27536e9 100644 --- a/src/api/api_llm.h +++ b/src/api/api_llm.h @@ -15,6 +15,7 @@ struct ApiLlmSetupConfig_t { String response_format = "llm.utf-8.stream"; std::vector input = {"llm.utf-8.stream"}; bool enoutput = true; + bool enkws = true; int max_token_len = 127; // int max_token_len = 512; }; diff --git a/src/api/api_melotts.h b/src/api/api_melotts.h index 9f97108..9d22e48 100644 --- a/src/api/api_melotts.h +++ b/src/api/api_melotts.h @@ -12,7 +12,7 @@ namespace m5_module_llm { struct ApiMelottsSetupConfig_t { String model = "melotts_zh-cn"; String response_format = "sys.pcm"; - std::vector input = {"tts.utf-8,stream"}; + std::vector input = {"tts.utf-8.stream"}; bool enoutput = false; bool enaudio = true; }; diff --git a/src/api/api_sys.cpp b/src/api/api_sys.cpp index 3f90b4c..2baaa68 100644 --- a/src/api/api_sys.cpp +++ b/src/api/api_sys.cpp @@ -13,8 +13,8 @@ static const char* _cmd_reset = "{\"request_id\":\"sys_reset\",\"work_id\":\"sys\",\"action\":\"reset\",\"object\":\"None\",\"data\":\"None\"}"; static const char* _cmd_reboot = "{\"request_id\":\"sys_reboot\",\"work_id\":\"sys\",\"action\":\"reboot\",\"object\":\"None\",\"data\":\"None\"}"; -// static const char* _cmd_ls_mode = -// "{\"request_id\":\"sys_lsmode\",\"work_id\":\"sys\",\"action\":\"lsmode\",\"object\":\"None\",\"data\":\"None\"}"; +static const char* _cmd_version = + "{\"request_id\":\"sys_version\",\"work_id\":\"sys\",\"action\":\"version\",\"object\":\"None\",\"data\":\"None\"}"; void ApiSys::init(ModuleMsg* moduleMsg) { @@ -29,6 +29,14 @@ int ApiSys::ping() return ret; } +int ApiSys::version() +{ + int ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; + _module_msg->sendCmdAndWaitToTakeMsg( + _cmd_version, "sys_version", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 2000); + return ret; +} + int ApiSys::reset(bool waitResetFinish) { int ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; diff --git a/src/api/api_sys.h b/src/api/api_sys.h index 2683d1b..796038b 100644 --- a/src/api/api_sys.h +++ b/src/api/api_sys.h @@ -26,6 +26,16 @@ class ApiSys { * @param waitResetFinish * @return int */ + + int version(); + + /** + * @brief Check version + * + * @param waitCheckFinish + * @return int + */ + int reset(bool waitResetFinish = true); /** diff --git a/src/api/api_tts.cpp b/src/api/api_tts.cpp index 5cfa2e6..332ca3a 100644 --- a/src/api/api_tts.cpp +++ b/src/api/api_tts.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: MIT */ #include "api_tts.h" +#include "api_version.h" using namespace m5_module_llm; @@ -23,12 +24,19 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id) doc["object"] = "tts.setup"; doc["data"]["model"] = config.model; doc["data"]["response_format"] = config.response_format; - JsonArray inputArray = doc["data"]["input"].to(); - for (const String& str : config.input) { - inputArray.add(str); + doc["data"]["enoutput"] = config.enoutput; + doc["data"]["enkws"] = config.enkws; + doc["data"]["enaudio"] = config.enaudio; + if (!llm_version) { + doc["data"]["response_format"] = "tts.base64.wav"; + doc["data"]["input"] = config.input[0]; + doc["data"]["enoutput"] = true; + } else { + JsonArray inputArray = doc["data"]["input"].to(); + for (const String& str : config.input) { + inputArray.add(str); + } } - doc["data"]["enoutput"] = config.enoutput; - doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); } diff --git a/src/api/api_tts.h b/src/api/api_tts.h index 8016388..6725618 100644 --- a/src/api/api_tts.h +++ b/src/api/api_tts.h @@ -12,9 +12,10 @@ namespace m5_module_llm { struct ApiTtsSetupConfig_t { String model = "single_speaker_english_fast"; String response_format = "sys.pcm"; - std::vector input = {"tts.utf-8,stream"}; + std::vector input = {"tts.utf-8.stream"}; bool enoutput = false; bool enaudio = true; + bool enkws = true; }; class ApiTts { diff --git a/src/api/api_version.cpp b/src/api/api_version.cpp new file mode 100644 index 0000000..8cf6244 --- /dev/null +++ b/src/api/api_version.cpp @@ -0,0 +1,8 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include "api_version.h" + +int llm_version = 0; diff --git a/src/api/api_version.h b/src/api/api_version.h new file mode 100644 index 0000000..64f5758 --- /dev/null +++ b/src/api/api_version.h @@ -0,0 +1,8 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#pragma once + +extern int llm_version; diff --git a/src/presets/voice_assistant.cpp b/src/presets/voice_assistant.cpp index 7b1913c..52d12f6 100644 --- a/src/presets/voice_assistant.cpp +++ b/src/presets/voice_assistant.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: MIT */ #include "../M5ModuleLLM.h" +#include "../src/api/api_version.h" using namespace m5_module_llm; @@ -59,11 +60,17 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt) _debug("setup module tts.."); { - ApiTtsSetupConfig_t config; - config.input = {_work_id.llm, _work_id.kws}; - _work_id.tts = _m5_module_llm->tts.setup(config); + if (!llm_version) { + ApiTtsSetupConfig_t config; + config.input = {_work_id.llm, _work_id.kws}; + _work_id.tts = _m5_module_llm->tts.setup(config); + } else { + ApiMelottsSetupConfig_t config; + config.input = {_work_id.llm, _work_id.kws}; + _work_id.melotts = _m5_module_llm->melotts.setup(config); + } } - if (_work_id.tts.isEmpty()) { + if (_work_id.tts.isEmpty() && _work_id.melotts.isEmpty()) { return MODULE_LLM_ERROR_NONE; }