feat: image input and session optimization
zhayujie committed Nov 27, 2023
1 parent 061d8a3 commit 4e675b8
Showing 6 changed files with 134 additions and 16 deletions.
124 changes: 117 additions & 7 deletions bot/linkai/link_ai_bot.py
@@ -13,6 +13,9 @@
from common.log import logger
from config import conf, pconf
import threading
from common import memory, utils
import base64


class LinkAIBot(Bot):
# authentication failed
@@ -21,7 +24,7 @@ class LinkAIBot(Bot):

def __init__(self):
super().__init__()
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {}

def reply(self, query, context: Context = None) -> Reply:
@@ -61,17 +64,25 @@ def _chat(self, query, context, retry_count=0) -> Reply:
linkai_api_key = conf().get("linkai_api_key")

session_id = context["session_id"]
session_message = self.sessions.session_msg_query(query, session_id)
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")

# image process
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
if img_cache:
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
if messages:
session_message = messages

session = self.sessions.session_query(query, session_id)
model = conf().get("model")
# remove system message
if session.messages[0].get("role") == "system":
if session_message[0].get("role") == "system":
if app_code or model == "wenxin":
session.messages.pop(0)
session_message.pop(0)

body = {
"app_code": app_code,
"messages": session.messages,
"messages": session_message,
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"temperature": conf().get("temperature"),
"top_p": conf().get("top_p", 1),
@@ -94,7 +105,7 @@ def _chat(self, query, context, retry_count=0) -> Reply:
reply_content = response["choices"][0]["message"]["content"]
total_tokens = response["usage"]["total_tokens"]
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
self.sessions.session_reply(reply_content, session_id, total_tokens)
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)

agent_suffix = self._fetch_agent_suffix(response)
if agent_suffix:
@@ -130,6 +141,54 @@ def _chat(self, query, context, retry_count=0) -> Reply:
logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)

def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
try:
enable_image_input = False
app_info = self._fetch_app_info(app_code)
if not app_info:
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
return None
plugins = app_info.get("data").get("plugins")
for plugin in plugins:
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
enable_image_input = True
if not enable_image_input:
return
msg = img_cache.get("msg")
path = img_cache.get("path")
msg.prepare()
logger.info(f"[LinkAI] query with images, path={path}")
messages = self._build_vision_msg(query, path)
memory.USER_IMAGE_CACHE[session_id] = None
return messages
except Exception as e:
logger.exception(e)


def _build_vision_msg(self, query: str, path: str):
try:
suffix = utils.get_path_suffix(path)
with open(path, "rb") as file:
base64_str = base64.b64encode(file.read()).decode('utf-8')
messages = [{
"role": "user",
"content": [
{
"type": "text",
"text": query
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/{suffix};base64,{base64_str}"
}
}
]
}]
return messages
except Exception as e:
logger.exception(e)

def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
if retry_count >= 2:
# exit from retry 2 times
@@ -195,6 +254,16 @@ def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self.reply_text(session, app_code, retry_count + 1)

def _fetch_app_info(self, app_code: str):
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
# do http request
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
params = {"app_code": app_code}
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
if res.status_code == 200:
return res.json()
else:
logger.warning(f"[LinkAI] find app info exception, res={res}")

def create_img(self, query, retry_count=0, api_key=None):
try:
@@ -239,6 +308,7 @@ def _fetch_knowledge_search_suffix(self, response) -> str:
except Exception as e:
logger.exception(e)


def _fetch_agent_suffix(self, response):
try:
plugin_list = []
@@ -275,4 +345,44 @@ def _send_image(self, channel, context, image_urls):
reply = Reply(ReplyType.IMAGE_URL, url)
channel.send(reply, context)
except Exception as e:
logger.error(e)
logger.error(e)


class LinkAISessionManager(SessionManager):
def session_msg_query(self, query, session_id):
session = self.build_session(session_id)
messages = session.messages + [{"role": "user", "content": query}]
return messages

def session_reply(self, reply, session_id, total_tokens=None, query=None):
session = self.build_session(session_id)
if query:
session.add_query(query)
session.add_reply(reply)
try:
max_tokens = conf().get("conversation_max_tokens", 2500)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}")
except Exception as e:
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session


class LinkAISession(ChatGPTSession):
def calc_tokens(self):
try:
cur_tokens = super().calc_tokens()
except Exception as e:
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
cur_tokens = len(str(self.messages))
return cur_tokens

def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = self.calc_tokens()
if cur_tokens > max_tokens:
for i in range(0, len(self.messages)):
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
self.messages.pop(i)
self.messages.pop(i - 1)
return self.calc_tokens()
return cur_tokens
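
When an image is cached for the session and the target app exposes an IMAGE-capable plugin, _process_image_msg swaps the normal chat history for a single vision-style message built by _build_vision_msg. As a rough illustration, the body posted to the completions endpoint then looks like the sketch below; the query text, suffix and base64 string are placeholders, and only the request fields visible in this diff are shown.

# Placeholder values; the structure mirrors _build_vision_msg and _chat above.
suffix = "png"                                   # utils.get_path_suffix(path)
base64_str = "iVBORw0KGgo..."                    # base64 of the image bytes (truncated here)
body = {
    "app_code": "",                              # app code resolved from the channel config
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url",
             "image_url": {"url": f"data:image/{suffix};base64,{base64_str}"}},
        ],
    }],
    "model": "gpt-3.5-turbo",                    # conf().get("model")
    "temperature": 0.7,                          # conf().get("temperature"), example value
    "top_p": 1,                                  # conf().get("top_p", 1)
}
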
4 changes: 2 additions & 2 deletions bot/session_manager.py
@@ -69,7 +69,7 @@ def session_query(self, query, session_id):
total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e:
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session

def session_reply(self, reply, session_id, total_tokens=None):
@@ -80,7 +80,7 @@ def session_reply(self, reply, session_id, total_tokens=None):
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e:
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session

def clear_session(self, session_id):
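Beyond the log-level changes above, the LinkAISessionManager subclass in bot/linkai/link_ai_bot.py splits a turn into two steps: session_msg_query only assembles the outgoing message list, and the user query is persisted together with the reply via session_reply(..., query=query) once the request succeeds, so a failed call no longer leaves an unanswered user message in the history. A minimal usage sketch; the session id and message texts are invented for illustration.

from bot.linkai.link_ai_bot import LinkAISessionManager, LinkAISession

manager = LinkAISessionManager(LinkAISession, model="gpt-3.5-turbo")

# Step 1: build the candidate message list; nothing is written to the session yet.
messages = manager.session_msg_query("describe this picture", "session-123")

# Step 2: after the API call succeeds, store both sides of the turn and trim history.
manager.session_reply("It shows a cat.", "session-123", total_tokens=42,
                      query="describe this picture")
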
11 changes: 6 additions & 5 deletions channel/chat_channel.py
@@ -9,8 +9,7 @@
from bridge.reply import *
from channel.channel import Channel
from common.dequeue import Dequeue
from common.log import logger
from config import conf
from common import memory
from plugins import *

try:
@@ -205,14 +204,16 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
else:
return
elif context.type == ContextType.IMAGE: # image message; for now it is only downloaded and saved locally
cmsg = context["msg"]
cmsg.prepare()
memory.USER_IMAGE_CACHE[context["session_id"]] = {
"path": context.content,
"msg": context.get("msg")
}
elif context.type == ContextType.SHARING: # shared message; no default handling for now
pass
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # file messages, function calls, etc.; no default handling for now
pass
else:
logger.error("[WX] unknown context type: {}".format(context.type))
logger.warning("[WX] unknown context type: {}".format(context.type))
return
return reply

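The IMAGE branch above no longer calls cmsg.prepare() right away; it only parks the expected local path and the message object in the shared cache, and LinkAIBot._process_image_msg triggers the actual download later, only when the target app accepts image input. A condensed sketch of the hand-off; the session id, path and the stub message class are illustrative.

from common import memory

class FakeMsg:                              # stand-in for the channel message object
    def prepare(self):                      # the real prepare() downloads the image file
        pass

# Producer (channel side): remember where the image for this session will live.
memory.USER_IMAGE_CACHE["session-123"] = {
    "path": "/tmp/wx_images/abc.png",       # context.content in the code above
    "msg": FakeMsg(),                       # context.get("msg") in the code above
}

# Consumer (bot side, see _process_image_msg): fetch lazily, then clear the entry.
img_cache = memory.USER_IMAGE_CACHE.get("session-123")
if img_cache:
    img_cache["msg"].prepare()              # the image is only downloaded here
    memory.USER_IMAGE_CACHE["session-123"] = None
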
3 changes: 3 additions & 0 deletions common/memory.py
@@ -0,0 +1,3 @@
from common.expired_dict import ExpiredDict

USER_IMAGE_CACHE = ExpiredDict(60 * 3)
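
USER_IMAGE_CACHE is the shared hand-off point between the channel and the bot. Assuming ExpiredDict behaves like an ordinary dict whose entries disappear after the given number of seconds, the 60 * 3 here means a cached image is only used if a text query for the same session arrives within three minutes:

from common.memory import USER_IMAGE_CACHE

USER_IMAGE_CACHE["session-123"] = {"path": "/tmp/img.png", "msg": None}
USER_IMAGE_CACHE.get("session-123")   # returns the entry within the 180-second window
# Once the entry expires, get() no longer returns it and the image is silently ignored.
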
7 changes: 6 additions & 1 deletion common/utils.py
@@ -1,6 +1,6 @@
import io
import os

from urllib.parse import urlparse
from PIL import Image


@@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
result.append(encoded[start:end].decode("utf-8"))
start = end
return result


def get_path_suffix(path):
path = urlparse(path).path
return os.path.splitext(path)[-1].lstrip('.')
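
get_path_suffix runs its argument through urlparse first, so it handles both plain file paths and URLs with query strings. A few illustrative calls; the inputs are made up.

from common.utils import get_path_suffix

get_path_suffix("/tmp/cache/photo.PNG")                       # -> "PNG"
get_path_suffix("https://example.com/img/cat.jpg?size=big")   # -> "jpg"
get_path_suffix("no_extension")                               # -> ""
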
1 change: 0 additions & 1 deletion plugins/linkai/summary.py
@@ -91,5 +91,4 @@ def check_url(self, url: str):
for support_url in support_list:
if url.strip().startswith(support_url):
return True
logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}")
return False
