Skip to content

Commit

Permalink
Merge pull request #7 from Abandon-ht/dev
Browse files Browse the repository at this point in the history
fix melotts inference & add yolo
  • Loading branch information
Forairaaaaa authored Dec 2, 2024
2 parents 671b87a + f1d6be7 commit 688df93
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/M5ModuleLLM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ bool M5ModuleLLM::begin(Stream* serialPort)
llm.init(&msg);
audio.init(&msg);
tts.init(&msg);
melotts.init(&msg);
kws.init(&msg);
asr.init(&msg);
yolo.init(&msg);
return true;
}

Expand Down
9 changes: 9 additions & 0 deletions src/M5ModuleLLM.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "api/api_melotts.h"
#include "api/api_kws.h"
#include "api/api_asr.h"
#include "api/api_yolo.h"

class M5ModuleLLM {
public:
Expand Down Expand Up @@ -82,6 +83,12 @@ class M5ModuleLLM {
*/
m5_module_llm::ApiAsr asr;

/**
* @brief YOLO module api set
*
*/
m5_module_llm::ApiYolo yolo;

/**
* @brief MSG module to handle module response message
*
Expand All @@ -100,8 +107,10 @@ class M5ModuleLLM {
typedef std::function<void(void)> OnKeywordDetectedCallback_t;
typedef std::function<void(String data, bool isFinish, int index)> OnAsrDataInputCallback_t;
typedef std::function<void(String data, bool isFinish, int index)> OnLlmDataInputCallback_t;
typedef std::function<void(String data, bool isFinish, int index)> OnYoloDataInputCallback_t;
typedef std::function<void(String rawData)> OnAsrDataInputRawCallback_t;
typedef std::function<void(String rawData)> OnLlmDataInputRawCallback_t;
typedef std::function<void(String rawData)> OnYoloDataInputRawCallback_t;

/**
* @brief Voice assistant preset base on class M5ModuleLLM
Expand Down
4 changes: 2 additions & 2 deletions src/api/api_melotts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ String ApiMelotts::setup(ApiMelottsSetupConfig_t config, String request_id)
// Copy work id
work_id = msg.work_id;
},
10000);
15000);
return work_id;
}

Expand All @@ -51,7 +51,7 @@ int ApiMelotts::inference(String work_id, String input, uint32_t timeout, String
doc["request_id"] = request_id;
doc["work_id"] = work_id;
doc["action"] = "inference";
doc["object"] = "melotts.utf-8.stream";
doc["object"] = "tts.utf-8.stream";
doc["data"]["delta"] = input;
doc["data"]["index"] = 0;
doc["data"]["finish"] = true;
Expand Down
106 changes: 106 additions & 0 deletions src/api/api_yolo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD
*
* SPDX-License-Identifier: MIT
*/
#include "api_yolo.h"

using namespace m5_module_llm;

void ApiYolo::init(ModuleMsg* moduleMsg)
{
_module_msg = moduleMsg;
}

String ApiYolo::setup(ApiYoloSetupConfig_t config, String request_id)
{
String cmd;
{
JsonDocument doc;
doc["request_id"] = request_id;
doc["work_id"] = "yolo";
doc["action"] = "setup";
doc["object"] = "yolo.setup";
doc["data"]["model"] = config.model;
doc["data"]["response_format"] = config.response_format;
JsonArray inputArray = doc["data"]["input"].to<JsonArray>();
for (const String& str : config.input) {
inputArray.add(str);
}
doc["data"]["enoutput"] = config.enoutput;
serializeJson(doc, cmd);
}

String work_id;
_module_msg->sendCmdAndWaitToTakeMsg(
cmd.c_str(), request_id,
[&work_id](ResponseMsg_t& msg) {
// Copy work id
work_id = msg.work_id;
},
5000);
return work_id;
}

int ApiYolo::inference(String& work_id, uint8_t* input, size_t& raw_len, String request_id)
{
String cmd;
{
JsonDocument doc;
doc["RAW"] = raw_len;
doc["request_id"] = request_id;
doc["work_id"] = work_id;
doc["action"] = "inference";
doc["object"] = "cv.jpeg.base64";
serializeJson(doc, cmd);
}

_module_msg->sendCmd(cmd.c_str());
_module_msg->sendRaw(input, raw_len);
return MODULE_LLM_OK;
}

int ApiYolo::inferenceAndWaitResult(String& work_id, uint8_t* input, size_t& raw_len,
std::function<void(String&)> onResult, uint32_t timeout, String request_id)
{
inference(work_id, input, raw_len, request_id);

uint32_t time_out_count = millis();
bool is_time_out = false;
bool is_msg_finish = false;
while (1) {
_module_msg->update();
_module_msg->takeMsg(request_id, [&time_out_count, &is_msg_finish, &onResult](ResponseMsg_t& msg) {
String response_msg;
{
JsonDocument doc;
deserializeJson(doc, msg.raw_msg);
response_msg = doc["data"]["delta"].as<String>();
if (!doc["data"]["finish"].isNull()) {
is_msg_finish = doc["data"]["finish"];
if (is_msg_finish) {
response_msg += '\n';
}
}
}
if (onResult) {
onResult(response_msg);
}
time_out_count = millis();
});

if (is_msg_finish) {
break;
}

if (millis() - time_out_count > timeout) {
is_time_out = true;
break;
}
}

if (is_time_out) {
return MODULE_LLM_WAIT_RESPONSE_TIMEOUT;
}
return MODULE_LLM_OK;
}
59 changes: 59 additions & 0 deletions src/api/api_yolo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD
*
* SPDX-License-Identifier: MIT
*/
#pragma once
#include "../utils/msg.h"
#include <Arduino.h>

namespace m5_module_llm {
struct ApiYoloSetupConfig_t {
String model = "yolo11n";
String response_format = "yolo.yolobox.stream";
std::vector<String> input = {"yolo.jpeg.base64"};
bool enoutput = true;
};

class ApiYolo {
public:
void init(ModuleMsg* moduleMsg);

/**
* @brief Setup module YOLO, return YOLO work_id
*
* @param config
* @param request_id
* @return String
*/
String setup(ApiYoloSetupConfig_t config = ApiYoloSetupConfig_t(), String request_id = "yolo_setup");

/**
* @brief Inference input data by module LLM
*
* @param raw_len
* @param work_id
* @param input
* @param request_id
* @return int
*/
int inference(String& work_id, uint8_t* input, size_t& raw_len, String request_id = "yolo_inference");

/**
* @brief Inference input data by module LLM, and wait inference result
*
* @param raw_len
* @param work_id
* @param input
* @param onResult On inference result callback
* @param timeout
* @param request_id
* @return int
*/
int inferenceAndWaitResult(String& work_id, uint8_t* input, size_t& raw_len, std::function<void(String&)> onResult,
uint32_t timeout = 5000, String request_id = "yolo_inference");

private:
ModuleMsg* _module_msg = nullptr;
};
} // namespace m5_module_llm
5 changes: 5 additions & 0 deletions src/utils/comm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ void ModuleComm::sendCmd(const char* cmd)
_serial->print(cmd);
}

void ModuleComm::sendRaw(const uint8_t* data, size_t& raw_len)
{
_serial->write(data, raw_len);
}

ModuleComm::Respond_t ModuleComm::getResponse(uint32_t timeout)
{
Respond_t ret;
Expand Down
1 change: 1 addition & 0 deletions src/utils/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ModuleComm {

bool init(Stream* serialPort);
void sendCmd(const char* cmd);
void sendRaw(const uint8_t* data, size_t& raw_len);
Respond_t getResponse(uint32_t timeout = 0xFFFFFFFF);

private:
Expand Down
4 changes: 4 additions & 0 deletions src/utils/msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class ModuleMsg {
_module_comm->sendCmd(cmd);
}

inline void sendRaw(const uint8_t* data, size_t& raw_len)
{
_module_comm->sendRaw(data, raw_len);
}
/**
* @brief Module response message list
*
Expand Down

0 comments on commit 688df93

Please sign in to comment.