Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update supports stackflow. #5

Merged
merged 4 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/KWS_ASR/KWS_ASR.ino
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ void setup()

/* Setup ASR module and save returned work id */
M5.Display.printf(">> Setup asr..\n");
asr_work_id = module_llm.asr.setup();
m5_module_llm::ApiAsrSetupConfig_t asr_config;
asr_config.input = {"sys.pcm", kws_work_id};
asr_work_id = module_llm.asr.setup(asr_config);

M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str());
}
Expand Down
8 changes: 8 additions & 0 deletions src/M5ModuleLLM.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "api/api_llm.h"
#include "api/api_audio.h"
#include "api/api_tts.h"
#include "api/api_melotts.h"
#include "api/api_kws.h"
#include "api/api_asr.h"

Expand Down Expand Up @@ -63,6 +64,12 @@ class M5ModuleLLM {
*/
m5_module_llm::ApiTts tts;

/**
* @brief MELOTTS module api set
*
*/
m5_module_llm::ApiMelotts melotts;

/**
* @brief KWS module api set
*
Expand Down Expand Up @@ -163,6 +170,7 @@ class M5ModuleLLM_VoiceAssistant {
String asr;
String llm;
String tts;
String melotts;
};

WorkId_t _work_id;
Expand Down
14 changes: 8 additions & 6 deletions src/api/api_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ String ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id)
doc["object"] = "asr.setup";
doc["data"]["model"] = config.model;
doc["data"]["response_format"] = config.response_format;
doc["data"]["input"] = config.input;
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["enkws"] = config.enkws;
doc["data"]["rule1"] = config.rule1;
doc["data"]["rule2"] = config.rule2;
doc["data"]["rule3"] = config.rule3;
JsonArray inputArray = doc["data"]["input"].to<JsonArray>();
for (const String& str : config.input) {
inputArray.add(str);
}
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["rule1"] = config.rule1;
doc["data"]["rule2"] = config.rule2;
doc["data"]["rule3"] = config.rule3;
serializeJson(doc, cmd);
}

Expand Down
15 changes: 7 additions & 8 deletions src/api/api_asr.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
namespace m5_module_llm {

struct ApiAsrSetupConfig_t {
String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17";
String response_format = "asr.utf-8.stream";
String input = "sys.pcm";
bool enoutput = true;
bool enkws = true;
float rule1 = 2.4;
float rule2 = 1.2;
float rule3 = 30.0;
String model = "sherpa-ncnn-streaming-zipformer-20M-2023-02-17";
String response_format = "asr.utf-8.stream";
std::vector<String> input = {"sys.pcm"};
bool enoutput = true;
float rule1 = 2.4;
float rule2 = 1.2;
float rule3 = 30.0;
};

class ApiAsr {
Expand Down
9 changes: 6 additions & 3 deletions src/api/api_kws.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id)
doc["object"] = "kws.setup";
doc["data"]["model"] = config.model;
doc["data"]["response_format"] = config.response_format;
doc["data"]["input"] = config.input;
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["kws"] = config.kws;
JsonArray inputArray = doc["data"]["input"].to<JsonArray>();
for (const String& str : config.input) {
inputArray.add(str);
}
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["kws"] = config.kws;
serializeJson(doc, cmd);
}

Expand Down
10 changes: 5 additions & 5 deletions src/api/api_kws.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
namespace m5_module_llm {

struct ApiKwsSetupConfig_t {
String kws = "HELLO";
String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01";
String response_format = "kws.bool";
String input = "sys.pcm";
bool enoutput = true;
String kws = "HELLO";
String model = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01";
String response_format = "kws.bool";
std::vector<String> input = {"sys.pcm"};
bool enoutput = true;
};

class ApiKws {
Expand Down
12 changes: 7 additions & 5 deletions src/api/api_llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id)
doc["object"] = "llm.setup";
doc["data"]["model"] = config.model;
doc["data"]["response_format"] = config.response_format;
doc["data"]["input"] = config.input;
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["enkws"] = config.enkws;
doc["data"]["max_token_len"] = config.max_token_len;
doc["data"]["prompt"] = config.prompt;
JsonArray inputArray = doc["data"]["input"].to<JsonArray>();
for (const String& str : config.input) {
inputArray.add(str);
}
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["max_token_len"] = config.max_token_len;
doc["data"]["prompt"] = config.prompt;
serializeJson(doc, cmd);
}

Expand Down
12 changes: 6 additions & 6 deletions src/api/api_llm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ namespace m5_module_llm {

struct ApiLlmSetupConfig_t {
String prompt;
String model = "qwen2.5-0.5b";
String response_format = "llm.utf-8.stream";
String input = "llm.utf-8.stream";
bool enoutput = true;
bool enkws = true;
int max_token_len = 127;
String model = "qwen2.5-0.5B-prefill-20e";
String response_format = "llm.utf-8.stream";
std::vector<String> input = {"llm.utf-8.stream"};
bool enoutput = true;
int max_token_len = 127;
// int max_token_len = 512;
};

class ApiLlm {
Expand Down
75 changes: 75 additions & 0 deletions src/api/api_melotts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD
*
* SPDX-License-Identifier: MIT
*/
#include "api_melotts.h"

using namespace m5_module_llm;

/**
 * @brief Remember the message channel that later setup/inference calls send through
 *
 * @param moduleMsg stored as-is; this class does not check it for null
 */
void ApiMelotts::init(ModuleMsg* moduleMsg)
{
    this->_module_msg = moduleMsg;
}

/**
 * @brief Build and send the "melotts.setup" command, then wait for the reply
 *
 * @param config values serialized into the command's "data" object
 * @param request_id written to the command's "request_id" field and handed to the wait call
 * @return String work_id copied from the response; stays empty if the response callback never runs
 */
String ApiMelotts::setup(ApiMelottsSetupConfig_t config, String request_id)
{
    // Wait budget handed to sendCmdAndWaitToTakeMsg (unit presumably ms - confirm in ModuleMsg)
    constexpr int kSetupResponseTimeout = 10000;

    String payload;
    {
        // Scoped so the JSON document is released before waiting for the reply
        JsonDocument request;
        request["request_id"] = request_id;
        request["work_id"]    = "melotts";
        request["action"]     = "setup";
        request["object"]     = "melotts.setup";

        JsonObject data         = request["data"].to<JsonObject>();
        data["model"]           = config.model;
        data["response_format"] = config.response_format;

        // "input" is always emitted as a JSON array, one entry per configured source
        JsonArray sources = data["input"].to<JsonArray>();
        for (const auto& source : config.input) {
            sources.add(source);
        }

        data["enoutput"] = config.enoutput;
        data["enaudio"]  = config.enaudio;

        serializeJson(request, payload);
    }

    String assigned_work_id;
    auto on_response = [&assigned_work_id](ResponseMsg_t& msg) {
        // Keep the work id reported by the module
        assigned_work_id = msg.work_id;
    };
    _module_msg->sendCmdAndWaitToTakeMsg(payload.c_str(), request_id, on_response, kSetupResponseTimeout);
    return assigned_work_id;
}

/**
 * @brief Send one "inference" command carrying the given text to a MeloTTS work unit
 *
 * The text is sent as a single chunk: "index" is 0 and "finish" is true.
 *
 * @param work_id written to the command's "work_id" field
 * @param input written to "data.delta"
 * @param timeout 0 sends the command without waiting; any other value is passed to the wait call
 * @param request_id written to the command's "request_id" field and handed to the wait call
 * @return int MODULE_LLM_OK when not waiting; otherwise the error code taken from the response,
 *         or MODULE_LLM_WAIT_RESPONSE_TIMEOUT if the response callback never runs
 */
int ApiMelotts::inference(String work_id, String input, uint32_t timeout, String request_id)
{
    String payload;
    {
        // Scoped so the JSON document is released before the command is sent
        JsonDocument request;
        request["request_id"] = request_id;
        request["work_id"]    = work_id;
        request["action"]     = "inference";
        request["object"]     = "melotts.utf-8.stream";

        JsonObject data = request["data"].to<JsonObject>();
        data["delta"]   = input;
        data["index"]   = 0;
        data["finish"]  = true;

        serializeJson(request, payload);
    }

    int status = MODULE_LLM_OK;
    if (timeout == 0) {
        // Fire and forget: the caller asked not to wait for a response
        _module_msg->sendCmd(payload.c_str());
    } else {
        // Assume a timeout until the response callback overwrites the status
        status = MODULE_LLM_WAIT_RESPONSE_TIMEOUT;
        auto on_response = [&status](ResponseMsg_t& msg) {
            // Report the module's own error code
            status = msg.error.code;
        };
        _module_msg->sendCmdAndWaitToTakeMsg(payload.c_str(), request_id, on_response, timeout);
    }
    return status;
}
48 changes: 48 additions & 0 deletions src/api/api_melotts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD
*
* SPDX-License-Identifier: MIT
*/
#pragma once
#include "../utils/msg.h"
#include <Arduino.h>

namespace m5_module_llm {

/**
 * @brief Options for ApiMelotts::setup
 *
 * Each member is written under the same name into the "data" object of the
 * "melotts.setup" command; `input` is emitted as a JSON array of strings.
 *
 * NOTE(review): std::vector is used here but <vector> is not included directly;
 * presumably it arrives via "../utils/msg.h" or <Arduino.h> - confirm.
 */
struct ApiMelottsSetupConfig_t {
// Model name sent as "data.model"
String model = "melotts_zh-cn";
// Sent as "data.response_format"
String response_format = "sys.pcm";
// Sent as the "data.input" array, one element per entry
std::vector<String> input = {"sys.pcm"};
// Sent as "data.enoutput"
bool enoutput = false;
// Sent as "data.enaudio"
bool enaudio = true;
};

/**
 * @brief MeloTTS api set: builds "melotts" setup and inference commands and
 * sends them through the ModuleMsg instance given to init()
 */
class ApiMelotts {
public:
/**
 * @brief Store the message channel used by setup() and inference()
 *
 * @param moduleMsg pointer kept as-is; must be set before setup()/inference() are called,
 * since they dereference it without a null check
 */
void init(ModuleMsg* moduleMsg);

/**
 * @brief Setup module MeloTTS, return MeloTTS work_id
 *
 * @param config values written into the "data" object of the "melotts.setup" command
 * @param request_id written to the command's "request_id" field and passed to the wait call
 * (NOTE(review): default is "tts_setup" - confirm it cannot clash with other api sets)
 * @return String work_id taken from the response; empty if no response was handled
 */
String setup(ApiMelottsSetupConfig_t config = ApiMelottsSetupConfig_t(), String request_id = "tts_setup");

/**
 * @brief Inference input data by MeloTTS module
 *
 * @param work_id work_id of the MeloTTS unit, as returned by setup()
 * @param input text sent as "data.delta" in a single chunk (index 0, finish true)
 * @param timeout wait response timeout, default 0 (do not wait response)
 * @param request_id written to the command's "request_id" field and passed to the wait call
 * @return int MODULE_LLM_OK when timeout is 0; otherwise the error code from the response,
 * or MODULE_LLM_WAIT_RESPONSE_TIMEOUT if no response was handled
 */
int inference(String work_id, String input, uint32_t timeout = 0, String request_id = "tts_inference");

private:
// Message channel set by init(); nullptr until then
ModuleMsg* _module_msg = nullptr;
};

} // namespace m5_module_llm
11 changes: 7 additions & 4 deletions src/api/api_tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id)
doc["object"] = "tts.setup";
doc["data"]["model"] = config.model;
doc["data"]["response_format"] = config.response_format;
doc["data"]["input"] = config.input;
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["enkws"] = config.enkws;
JsonArray inputArray = doc["data"]["input"].to<JsonArray>();
for (const String& str : config.input) {
inputArray.add(str);
}
doc["data"]["enoutput"] = config.enoutput;
doc["data"]["enaudio"] = config.enaudio;
serializeJson(doc, cmd);
}

Expand All @@ -50,7 +53,7 @@ int ApiTts::inference(String work_id, String input, uint32_t timeout, String req
doc["action"] = "inference";
doc["object"] = "tts.utf-8.stream";
doc["data"]["delta"] = input;
doc["data"]["index"] = 1;
doc["data"]["index"] = 0;
doc["data"]["finish"] = true;
serializeJson(doc, cmd);
}
Expand Down
10 changes: 5 additions & 5 deletions src/api/api_tts.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
namespace m5_module_llm {

struct ApiTtsSetupConfig_t {
String model = "single_speaker_english_fast";
String response_format = "tts.base64.wav";
String input = "tts.utf-8.stream";
bool enoutput = true;
bool enkws = true;
String model = "single_speaker_english_fast";
String response_format = "sys.pcm";
std::vector<String> input = {"sys.pcm"};
bool enoutput = false;
bool enaudio = true;
};

class ApiTts {
Expand Down
10 changes: 7 additions & 3 deletions src/presets/voice_assistant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt)
}

_debug("setup module asr..");
_work_id.asr = _m5_module_llm->asr.setup();
{
ApiAsrSetupConfig_t config;
config.input = {"sys.pcm", _work_id.kws};
_work_id.asr = _m5_module_llm->asr.setup(config);
}
if (_work_id.asr.isEmpty()) {
return MODULE_LLM_ERROR_NONE;
}

_debug("setup module llm..");
{
ApiLlmSetupConfig_t config;
config.input = _work_id.asr;
config.input = {_work_id.asr, _work_id.kws};
config.prompt = prompt;
_work_id.llm = _m5_module_llm->llm.setup(config);
}
Expand All @@ -56,7 +60,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt)
_debug("setup module tts..");
{
ApiTtsSetupConfig_t config;
config.input = _work_id.llm;
config.input = {_work_id.llm, _work_id.kws};
_work_id.tts = _m5_module_llm->tts.setup(config);
}
if (_work_id.tts.isEmpty()) {
Expand Down