diff --git a/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts b/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts index 0ce731091..6625f1983 100644 --- a/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts +++ b/source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts @@ -84,7 +84,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac this.dynamodbStatement = createKnowledgeBaseTablesAndPoliciesResult.dynamodbStatement; this.sfnOutput = this.createKnowledgeBaseJob(props); - + } private createKnowledgeBaseTablesAndPolicies(props: any) { @@ -153,7 +153,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac // If this.region is cn-north-1 or cn-northwest-1, use the glue-job-script-cn.py const glueJobScript = "glue-job-script.py"; - + const extraPythonFiles = new s3deploy.BucketDeployment( this, @@ -172,9 +172,9 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac const extraPythonFilesList = [ this.glueLibS3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl"), ].join(","); - - + + const glueRole = new iam.Role(this, "ETLGlueJobRole", { assumedBy: new iam.ServicePrincipal("glue.amazonaws.com"), // The role is used by the glue job to access AOS and by default it has 1 hour session duration which is not enough for the glue job to finish the embedding injection diff --git a/source/lambda/online/common_logic/common_utils/constant.py b/source/lambda/online/common_logic/common_utils/constant.py index 1ed416130..330bf2e8e 100644 --- a/source/lambda/online/common_logic/common_utils/constant.py +++ b/source/lambda/online/common_logic/common_utils/constant.py @@ -94,7 +94,6 @@ class LLMTaskType(ConstantBase): AUTO_EVALUATION = "auto_evaluation" - class MessageType(ConstantBase): HUMAN_MESSAGE_TYPE = 'human' AI_MESSAGE_TYPE = 'ai' @@ -149,7 +148,6 @@ class LLMModelType(ConstantBase): COHERE_COMMAND_R_PLUS = "cohere.command-r-plus-v1:0" - class EmbeddingModelType(ConstantBase): BEDROCK_TITAN_V1 = "amazon.titan-embed-text-v1" @@ -179,7 +177,8 @@ class IndexTag(Enum): @unique class KBType(Enum): - AOS = "aos" + AOS = "aos" + GUIDE_INTENTION_NOT_FOUND = "Intention not found, please add intentions first when using agent mode, refer to https://amzn-chn.feishu.cn/docx/HlxvduJYgoOz8CxITxXc43XWn8e" INDEX_DESC = "Answer question based on search result" @@ -188,4 +187,3 @@ class KBType(Enum): class Threshold(ConstantBase): QQ_IN_RAG_CONTEXT = 0.5 INTENTION_ALL_KNOWLEDGE_RETRIEVAL = 0.4 - diff --git a/source/lambda/online/common_logic/common_utils/lambda_invoke_utils.py b/source/lambda/online/common_logic/common_utils/lambda_invoke_utils.py index 923f3ddde..8989d6e35 100644 --- a/source/lambda/online/common_logic/common_utils/lambda_invoke_utils.py +++ b/source/lambda/online/common_logic/common_utils/lambda_invoke_utils.py @@ -3,7 +3,7 @@ import importlib import json import time -import os +import os from typing import Any, Dict, Optional, Callable, Union import threading @@ -34,23 +34,23 @@ class StateContext: - def __init__(self,state): - self.state=state - + def __init__(self, state): + self.state = state + @classmethod def get_current_state(cls): # print("thread id",threading.get_ident(),'parent id',threading.) 
# state = getattr(thread_local,'state',None) state = CURRENT_STATE - assert state is not None,"There is not a valid state in current context" + assert state is not None, "There is not a valid state in current context" return state @classmethod def set_current_state(cls, state): - global CURRENT_STATE + global CURRENT_STATE assert CURRENT_STATE is None, "Parallel node executions are not alowed" CURRENT_STATE = state - + @classmethod def clear_state(cls): global CURRENT_STATE @@ -108,7 +108,8 @@ def validate_environment(cls, values: Dict): session = boto3.Session() values["client"] = session.client( "lambda", - region_name=values.get("region_name",os.environ['AWS_REGION']) + region_name=values.get( + "region_name", os.environ['AWS_REGION']) ) except Exception as e: raise ValueError( @@ -320,8 +321,8 @@ def wrapper(state: Dict[str, Any]) -> Dict[str, Any]: current_stream_use, ws_connection_id, enable_trace) state['trace_infos'].append( f"Enter: {func.__name__}, time: {time.time()}") - - with StateContext(state): + + with StateContext(state): output = func(state) current_monitor_infos = output.get(monitor_key, None) @@ -329,7 +330,8 @@ def wrapper(state: Dict[str, Any]) -> Dict[str, Any]: send_trace(f"\n\n {current_monitor_infos}", current_stream_use, ws_connection_id, enable_trace) exit_time = time.time() - state['trace_infos'].append(f"Exit: {func.__name__}, time: {time.time()}") + state['trace_infos'].append( + f"Exit: {func.__name__}, time: {time.time()}") send_trace(f"\n\n Elapsed time: {round((exit_time-enter_time)*100)/100} s", current_stream_use, ws_connection_id, enable_trace) return output diff --git a/source/lambda/online/common_logic/common_utils/langchain_utils.py b/source/lambda/online/common_logic/common_utils/langchain_utils.py index 4de56b9af..87e8c44e8 100644 --- a/source/lambda/online/common_logic/common_utils/langchain_utils.py +++ b/source/lambda/online/common_logic/common_utils/langchain_utils.py @@ -239,4 +239,4 @@ def format_trace_infos(trace_infos: list, use_pretty_table=True): class NestUpdateState(TypedDict): - keys: Annotated[dict, update_nest_dict] \ No newline at end of file + keys: Annotated[dict, update_nest_dict] diff --git a/source/lambda/online/common_logic/common_utils/logger_utils.py b/source/lambda/online/common_logic/common_utils/logger_utils.py index 118459421..366f909f1 100644 --- a/source/lambda/online/common_logic/common_utils/logger_utils.py +++ b/source/lambda/online/common_logic/common_utils/logger_utils.py @@ -75,4 +75,3 @@ def _inner(*args, **kwargs): print_llm_messages(kwargs) return fn(*args, **kwargs) return _inner - diff --git a/source/lambda/online/common_logic/common_utils/monitor_utils.py b/source/lambda/online/common_logic/common_utils/monitor_utils.py index 25088c104..045097398 100644 --- a/source/lambda/online/common_logic/common_utils/monitor_utils.py +++ b/source/lambda/online/common_logic/common_utils/monitor_utils.py @@ -29,7 +29,7 @@ def format_qq_data(data) -> str: """ if is_null_or_empty(data): return "" - + markdown_table = "**QQ Match Result**\n" markdown_table += "| Source | Score | Question | Answer |\n" markdown_table += "|-----|-----|-----|-----|\n" @@ -66,7 +66,7 @@ def format_rag_data(data, qq_result) -> str: score = item.get("score", -1) page_content = item.get("retrieval_content", "").replace("\n", "
") markdown_table += f"| {source} | {raw_source} | {score} | {page_content} |\n" - + if not is_null_or_empty(qq_result): markdown_table += "\n**QQ Match Result**\n" markdown_table += "| Source File Name | Source URI | Score | Question | Answer |\n" diff --git a/source/lambda/online/common_logic/common_utils/prompt_utils.py b/source/lambda/online/common_logic/common_utils/prompt_utils.py index 0ad6160f5..25f81c0a3 100644 --- a/source/lambda/online/common_logic/common_utils/prompt_utils.py +++ b/source/lambda/online/common_logic/common_utils/prompt_utils.py @@ -181,7 +181,7 @@ def prompt_template_render(self, prompt_template: dict): ) -################ +################ # query rewrite prompt template from paper https://arxiv.org/pdf/2401.10225 ################### CQR_SYSTEM_PROMPT = """You are a helpful, pattern-following assistant.""" @@ -273,7 +273,7 @@ def prompt_template_render(self, prompt_template: dict): LLMModelType.LLAMA3_2_90B_INSTRUCT, LLMModelType.MISTRAL_LARGE_2407, LLMModelType.COHERE_COMMAND_R_PLUS, - + ], task_type=LLMTaskType.CONVERSATION_SUMMARY_TYPE, prompt_template=CQR_SYSTEM_PROMPT, @@ -330,7 +330,6 @@ def prompt_template_render(self, prompt_template: dict): ) - ############## xml agent prompt ############# # AGENT_USER_PROMPT = "你是一个AI助理。今天是{date},{weekday}. " # register_prompt_templates( @@ -397,7 +396,7 @@ def prompt_template_render(self, prompt_template: dict): """ # AGENT_SYSTEM_PROMPT = """\ -# You are a helpful and honest AI assistant. Today is {date},{weekday}. +# You are a helpful and honest AI assistant. Today is {date},{weekday}. # Here are some guidelines for you: # # - Output your step by step thinking in one pair of and tags, here are steps for you to think about deciding to use which tool: @@ -405,7 +404,7 @@ def prompt_template_render(self, prompt_template: dict): # 2. Determine whether the current context is sufficient to answer the user's question. # 3. If the current context is sufficient to answer the user's question, call the `give_final_response` tool. # 4. If the current context is not sufficient to answer the user's question, you can consider calling one of the provided tools. -# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information. +# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information. # - Always output with the same language as the content from user. If the content is English, use English to output. If the content is Chinese, use Chinese to output. # - Always invoke one tool. # - Before invoking any tool, be sure to first output your thought process in one pair of and tag. @@ -413,7 +412,7 @@ def prompt_template_render(self, prompt_template: dict): # AGENT_SYSTEM_PROMPT = """\ -# You are a helpful and honest AI assistant. Today is {date},{weekday}. +# You are a helpful and honest AI assistant. Today is {date},{weekday}. # Here are some guidelines for you: # # - Output your step by step thinking in one pair of and tags, here are steps for you to think about deciding to use which tool: @@ -421,7 +420,7 @@ def prompt_template_render(self, prompt_template: dict): # 2. Determine whether the current context is sufficient to answer the user's question. # 3. If the current context is sufficient to answer the user's question, call the `give_final_response` tool. # 4. 
If the current context is not sufficient to answer the user's question, you can consider calling one of the provided tools. -# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information. +# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information. # - Always output with the same language as the content from user. If the content is English, use English to output. If the content is Chinese, use Chinese to output. # - Always invoke one tool. # @@ -505,6 +504,5 @@ def prompt_template_render(self, prompt_template: dict): ) - if __name__ == "__main__": print(get_all_templates()) diff --git a/source/lambda/online/common_logic/common_utils/pydantic_models.py b/source/lambda/online/common_logic/common_utils/pydantic_models.py index 2cfc90f96..e61493823 100644 --- a/source/lambda/online/common_logic/common_utils/pydantic_models.py +++ b/source/lambda/online/common_logic/common_utils/pydantic_models.py @@ -35,9 +35,12 @@ class LLMConfig(AllowBaseModel): model_kwargs: dict = {"temperature": 0.01, "max_tokens": 4096} +class QueryRewriteConfig(LLMConfig): + rewrite_first_message: bool = False + class QueryProcessConfig(ForbidBaseModel): - conversation_query_rewrite_config: LLMConfig = Field( - default_factory=LLMConfig) + conversation_query_rewrite_config: QueryRewriteConfig = Field( + default_factory=QueryRewriteConfig) class RetrieverConfigBase(AllowBaseModel): @@ -88,7 +91,7 @@ class RagToolConfig(AllowBaseModel): class AgentConfig(ForbidBaseModel): llm_config: LLMConfig = Field(default_factory=LLMConfig) - tools: list[Union[str,dict]] = Field(default_factory=list) + tools: list[Union[str, dict]] = Field(default_factory=list) only_use_rag_tool: bool = False @@ -178,8 +181,7 @@ def get_index_infos_from_ddb(cls, group_name, chatbot_id): cls.format_index_info(info) for info in list(index_info["value"].values()) ] - infos[index_type] = {info["index_name"] - : info for info in info_list} + infos[index_type] = {info["index_name"] : info for info in info_list} for index_type in IndexType.all_values(): if index_type not in infos: @@ -194,17 +196,18 @@ def update_retrievers( ): index_infos = self.get_index_infos_from_ddb( self.group_name, self.chatbot_id) - print(f"index_infos: {index_infos}") - print(f"default_index_names: {default_index_names}") - print(f"default_retriever_config: {default_retriever_config}") + logger.info(f"index_infos: {index_infos}") + logger.info(f"default_index_names: {default_index_names}") + logger.info(f"default_retriever_config: {default_retriever_config}") for task_name, index_names in default_index_names.items(): - assert task_name in ("qq_match", "intention", "private_knowledge") if task_name == "qq_match": index_type = IndexType.QQ elif task_name == "intention": index_type = IndexType.INTENTION elif task_name == "private_knowledge": index_type = IndexType.QD + else: + raise ValueError(f"Invalid task_name: {task_name}") # default to use all index if not index_names: diff --git a/source/lambda/online/common_logic/common_utils/python_utils.py b/source/lambda/online/common_logic/common_utils/python_utils.py index 7ef3930aa..17fa7d47b 100644 --- a/source/lambda/online/common_logic/common_utils/python_utils.py +++ b/source/lambda/online/common_logic/common_utils/python_utils.py @@ -1,6 +1,7 @@ import collections.abc -def update_nest_dict(d:dict, u:dict): + +def 
update_nest_dict(d: dict, u: dict): for k, v in u.items(): if isinstance(v, collections.abc.Mapping): d[k] = update_nest_dict(d.get(k, {}), v) diff --git a/source/lambda/online/common_logic/common_utils/response_utils.py b/source/lambda/online/common_logic/common_utils/response_utils.py index fe54fe083..257414bee 100644 --- a/source/lambda/online/common_logic/common_utils/response_utils.py +++ b/source/lambda/online/common_logic/common_utils/response_utils.py @@ -8,22 +8,23 @@ from common_logic.common_utils.logger_utils import get_logger logger = get_logger("response_utils") + class WebsocketClientError(Exception): pass def write_chat_history_to_ddb( - query:str, - answer:str, - ddb_obj:DynamoDBChatMessageHistory, + query: str, + answer: str, + ddb_obj: DynamoDBChatMessageHistory, message_id, custom_message_id, entry_type, additional_kwargs=None, - ): +): ddb_obj.add_user_message( - f"user_{message_id}", custom_message_id, entry_type, query, additional_kwargs - ) + f"user_{message_id}", custom_message_id, entry_type, query, additional_kwargs + ) ddb_obj.add_ai_message( f"ai_{message_id}", custom_message_id, @@ -51,23 +52,23 @@ def api_response(event_body: dict, response: dict): ) return { - "session_id": event_body['session_id'], - "entry_type": event_body['entry_type'], - "created": time.time(), - "total_time": time.time()-event_body["request_timestamp"], - "message": { - "role": "assistant", - "content": answer - }, - **response['extra_response'] + "session_id": event_body['session_id'], + "entry_type": event_body['entry_type'], + "created": time.time(), + "total_time": time.time()-event_body["request_timestamp"], + "message": { + "role": "assistant", + "content": answer + }, + **response['extra_response'] } -def stream_response(event_body:dict, response:dict): +def stream_response(event_body: dict, response: dict): request_timestamp = event_body["request_timestamp"] entry_type = event_body["entry_type"] message_id = event_body["message_id"] - log_first_token_time = True + log_first_token_time = True ws_connection_id = event_body["ws_connection_id"] custom_message_id = event_body["custom_message_id"] answer = response["answer"] @@ -79,10 +80,10 @@ def stream_response(event_body:dict, response:dict): try: send_to_ws_client(message={ - "message_type": StreamMessageType.START, - "message_id": f"ai_{message_id}", - "custom_message_id": custom_message_id, - }, + "message_type": StreamMessageType.START, + "message_id": f"ai_{message_id}", + "custom_message_id": custom_message_id, + }, ws_connection_id=ws_connection_id ) answer_str = "" @@ -90,25 +91,25 @@ def stream_response(event_body:dict, response:dict): for i, chunk in enumerate(answer): if i == 0 and log_first_token_time: first_token_time = time.time() - + logger.info( f"{custom_message_id} running time of first token whole {entry_type} entry: {first_token_time-request_timestamp}s" ) send_to_ws_client(message={ - "message_type": StreamMessageType.CHUNK, - "message_id": f"ai_{message_id}", - "custom_message_id": custom_message_id, - "message": { - "role": "assistant", - "content": chunk, - # "knowledge_sources": sources, - }, - "chunk_id": i, + "message_type": StreamMessageType.CHUNK, + "message_id": f"ai_{message_id}", + "custom_message_id": custom_message_id, + "message": { + "role": "assistant", + "content": chunk, + # "knowledge_sources": sources, }, + "chunk_id": i, + }, ws_connection_id=ws_connection_id ) answer_str += chunk - + if log_first_token_time: logger.info( f"{custom_message_id} running time of last token whole {entry_type} 
entry: {time.time()-request_timestamp}s" @@ -123,7 +124,7 @@ def stream_response(event_body:dict, response:dict): message_id=message_id, custom_message_id=custom_message_id, entry_type=entry_type, - additional_kwargs=response.get("ddb_additional_kwargs",{}) + additional_kwargs=response.get("ddb_additional_kwargs", {}) ) # Send source and contexts @@ -180,8 +181,9 @@ def __call__(self, answer, contexts): return stream_response(**kwargs) -def process_response(event_body,response): - stream = event_body["stream"] + +def process_response(event_body, response): + stream = event_body.get("stream", True) if stream: - return stream_response(event_body,response) - return api_response(event_body,response) + return stream_response(event_body, response) + return api_response(event_body, response) diff --git a/source/lambda/online/common_logic/common_utils/s3_utils.py b/source/lambda/online/common_logic/common_utils/s3_utils.py index 658a556e4..343eb55a6 100644 --- a/source/lambda/online/common_logic/common_utils/s3_utils.py +++ b/source/lambda/online/common_logic/common_utils/s3_utils.py @@ -1,6 +1,7 @@ import os import boto3 + def download_dir_from_s3(bucket_name, s3_dir_path, local_dir_path): s3 = boto3.client('s3') paginator = s3.get_paginator('list_objects_v2') @@ -8,29 +9,37 @@ def download_dir_from_s3(bucket_name, s3_dir_path, local_dir_path): if result.get('Contents') is not None: for file in result.get('Contents'): if not os.path.exists(os.path.dirname(local_dir_path + os.sep + file.get('Key'))): - os.makedirs(os.path.dirname(local_dir_path + os.sep + file.get('Key'))) - s3.download_file(bucket_name, file.get('Key'), local_dir_path + os.sep + file.get('Key')) + os.makedirs(os.path.dirname( + local_dir_path + os.sep + file.get('Key'))) + s3.download_file(bucket_name, file.get('Key'), + local_dir_path + os.sep + file.get('Key')) + def download_file_from_s3(bucket_name, s3_file_path, local_file_path): s3 = boto3.client('s3') s3.download_file(bucket_name, s3_file_path, local_file_path) + def delete_s3_file(bucket_name, s3_file_path): s3 = boto3.client('s3') s3.delete_object(Bucket=bucket_name, Key=s3_file_path) + def upload_file_to_s3(bucket_name, s3_file_path, local_file_path): s3 = boto3.client('s3') s3.upload_file(local_file_path, bucket_name, s3_file_path) + def upload_dir_to_s3(bucket_name, s3_dir_path, local_dir_path): for root, dirs, files in os.walk(local_dir_path): for file in files: local_file_path = os.path.join(root, file) - s3_file_path = os.path.join(s3_dir_path, local_file_path[len(local_dir_path)+1:]) + s3_file_path = os.path.join( + s3_dir_path, local_file_path[len(local_dir_path)+1:]) print(f"Uploading {local_file_path} to {s3_file_path}") upload_file_to_s3(bucket_name, s3_file_path, local_file_path) + def check_local_folder(file_path): folder = '/'.join(file_path.split('/')[:-1]) if not os.path.exists(folder): diff --git a/source/lambda/online/common_logic/common_utils/time_utils.py b/source/lambda/online/common_logic/common_utils/time_utils.py index f89422999..f2d22e085 100644 --- a/source/lambda/online/common_logic/common_utils/time_utils.py +++ b/source/lambda/online/common_logic/common_utils/time_utils.py @@ -18,7 +18,8 @@ def timeit_wrapper(*args, **kwargs): end_time = time.perf_counter() total_time = end_time - start_time # first item in the args, ie `args[0]` is `self` - logger.info(f'Function {func.__name__} {str(args)[:32]} {str(kwargs)[:32]} Took {total_time:.4f} seconds\n') + logger.info( + f'Function {func.__name__} {str(args)[:32]} {str(kwargs)[:32]} Took 
{total_time:.4f} seconds\n') return result return timeit_wrapper @@ -27,7 +28,7 @@ def get_china_now(): SHA_TZ = timezone( timedelta(hours=8), name='Asia/Shanghai' - ) + ) # 协调世界时 utc_now = datetime.utcnow().replace(tzinfo=timezone.utc) - return utc_now.date() \ No newline at end of file + return utc_now.date() diff --git a/source/lambda/online/common_logic/common_utils/websocket_utils.py b/source/lambda/online/common_logic/common_utils/websocket_utils.py index 65a68e3ae..1fe904773 100644 --- a/source/lambda/online/common_logic/common_utils/websocket_utils.py +++ b/source/lambda/online/common_logic/common_utils/websocket_utils.py @@ -27,10 +27,12 @@ def is_websocket_request(event): else: return False + def load_ws_client(websocket_url): global ws_client if ws_client is None: - ws_client = boto3.client("apigatewaymanagementapi", endpoint_url=websocket_url) + ws_client = boto3.client( + "apigatewaymanagementapi", endpoint_url=websocket_url) return ws_client @@ -39,4 +41,3 @@ def send_to_ws_client(message: dict, ws_connection_id): ConnectionId=ws_connection_id, Data=json.dumps(message).encode("utf-8"), ) - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/__llm_chain_base.py b/source/lambda/online/common_logic/langchain_integration/chains/__llm_chain_base.py index 98ae93d34..f30dc1f12 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/__llm_chain_base.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/__llm_chain_base.py @@ -23,4 +23,3 @@ def get_chain(cls, model_id, intent_type, model_kwargs=None, **kwargs): return cls.model_map[cls._get_chain_id(model_id, intent_type)].create_chain( model_kwargs=model_kwargs, **kwargs ) - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/chat_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/chat_chain.py index 325c55fdb..230279054 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/chat_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/chat_chain.py @@ -1,8 +1,8 @@ # chat llm chains from langchain.schema.runnable import RunnableLambda, RunnablePassthrough -from langchain_core.messages import AIMessage,SystemMessage -from langchain.prompts import ChatPromptTemplate,HumanMessagePromptTemplate +from langchain_core.messages import AIMessage, SystemMessage +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.messages import convert_to_messages from langchain_core.output_parsers import StrOutputParser @@ -27,14 +27,14 @@ class Claude2ChatChain(LLMChain): model_id = LLMModelType.CLAUDE_2 intent_type = LLMTaskType.CHAT - @classmethod - def get_common_system_prompt(cls,system_prompt_template:str): + def get_common_system_prompt(cls, system_prompt_template: str): now = get_china_now() date_str = now.strftime("%Y年%m月%d日") weekdays = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'] weekday = weekdays[now.weekday()] - system_prompt = system_prompt_template.format(date=date_str,weekday=weekday) + system_prompt = system_prompt_template.format( + date=date_str, weekday=weekday) return system_prompt @classmethod @@ -43,24 +43,26 @@ def create_chain(cls, model_kwargs=None, **kwargs): system_prompt_template = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template - system_prompt = kwargs.get('system_prompt',system_prompt_template) or "" + system_prompt = kwargs.get( 
+ 'system_prompt', system_prompt_template) or "" system_prompt = cls.get_common_system_prompt(system_prompt) - prefill = kwargs.get('prefill',None) + prefill = kwargs.get('prefill', None) messages = [ ("placeholder", "{chat_history}"), HumanMessagePromptTemplate.from_template("{query}") ] if system_prompt: - messages.insert(0,SystemMessage(content=system_prompt)) - + messages.insert(0, SystemMessage(content=system_prompt)) + if prefill is not None: messages.append(AIMessage(content=prefill)) messages_template = ChatPromptTemplate.from_messages(messages) - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) chain = messages_template | RunnableLambda(lambda x: x.messages) chain = chain | llm | StrOutputParser() @@ -108,6 +110,7 @@ class Mixtral8x7bChatChain(Claude2ChatChain): class Llama31Instruct70BChatChain(Claude2ChatChain): model_id = LLMModelType.LLAMA3_1_70B_INSTRUCT + class Llama32Instruct90BChatChain(Claude2ChatChain): model_id = LLMModelType.LLAMA3_2_90B_INSTRUCT @@ -120,7 +123,6 @@ class CohereCommandRPlusChatChain(Claude2ChatChain): model_id = LLMModelType.COHERE_COMMAND_R_PLUS - class Baichuan2Chat13B4BitsChatChain(LLMChain): model_id = LLMModelType.BAICHUAN2_13B_CHAT intent_type = LLMTaskType.CHAT @@ -139,7 +141,8 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} model_kwargs.update({"stream": stream}) model_kwargs = {**cls.default_model_kwargs, **model_kwargs} - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) llm_chain = RunnableLambda(lambda x: llm.invoke(x, stream=stream)) return llm_chain @@ -184,14 +187,14 @@ def create_history(cls, x): return history @classmethod - def create_prompt(cls, x,system_prompt=None): + def create_prompt(cls, x, system_prompt=None): history = cls.create_history(x) if system_prompt is None: system_prompt = get_prompt_template( - model_id=cls.model_id, - task_type=cls.intent_type, - prompt_name="system_prompt" - ).prompt_template + model_id=cls.model_id, + task_type=cls.intent_type, + prompt_name="system_prompt" + ).prompt_template prompt = cls.build_prompt( query=x["query"], @@ -205,11 +208,13 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} stream = kwargs.get("stream", False) - system_prompt = kwargs.get("system_prompt",None) - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) + system_prompt = kwargs.get("system_prompt", None) + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) prompt_template = RunnablePassthrough.assign( - prompt=RunnableLambda(lambda x: cls.create_prompt(x,system_prompt=system_prompt)) + prompt=RunnableLambda(lambda x: cls.create_prompt( + x, system_prompt=system_prompt)) ) llm_chain = prompt_template | RunnableLambda( lambda x: llm.invoke(x, stream=stream) @@ -229,28 +234,31 @@ class GLM4Chat9BChatChain(LLMChain): "timeout": 60, "temperature": 0.1, } + @classmethod - def create_chat_history(cls,x, system_prompt=None): + def create_chat_history(cls, x, system_prompt=None): if system_prompt is None: system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template chat_history = x['chat_history'] if system_prompt is not None: - chat_history = 
[{"role":"system","content": system_prompt}] + chat_history - chat_history = chat_history + [{"role":MessageType.HUMAN_MESSAGE_TYPE,"content":x['query']}] + chat_history = [ + {"role": "system", "content": system_prompt}] + chat_history + chat_history = chat_history + \ + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] return chat_history - + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} - system_prompt = kwargs.get("system_prompt",None) + system_prompt = kwargs.get("system_prompt", None) llm = Model.get_model( model_id=cls.model_id, model_kwargs=model_kwargs, @@ -258,9 +266,10 @@ def create_chain(cls, model_kwargs=None, **kwargs): ) chain = RunnablePassthrough.assign( - chat_history = RunnableLambda(lambda x: cls.create_chat_history(x,system_prompt=system_prompt)) + chat_history=RunnableLambda( + lambda x: cls.create_chat_history(x, system_prompt=system_prompt)) ) | RunnableLambda(lambda x: llm.invoke(x)) - + return chain @@ -273,34 +282,34 @@ class Qwen2Instruct7BChatChain(LLMChain): } @classmethod - def create_chat_history(cls,x, system_prompt=None): + def create_chat_history(cls, x, system_prompt=None): if system_prompt is None: system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template chat_history = x['chat_history'] if system_prompt is not None: - chat_history = [{"role":"system", "content": system_prompt}] + chat_history - - chat_history = chat_history + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content":x['query']}] - return chat_history + chat_history = [ + {"role": "system", "content": system_prompt}] + chat_history + chat_history = chat_history + \ + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] + return chat_history @classmethod - def parse_function_calls_from_ai_message(cls,message:dict): + def parse_function_calls_from_ai_message(cls, message: dict): return message['text'] - @classmethod def create_chain(cls, model_kwargs=None, **kwargs): stream = kwargs.get("stream", False) model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} - system_prompt = kwargs.get("system_prompt",None) + system_prompt = kwargs.get("system_prompt", None) llm = Model.get_model( model_id=cls.model_id, @@ -309,18 +318,20 @@ def create_chain(cls, model_kwargs=None, **kwargs): ) chain = RunnablePassthrough.assign( - chat_history = RunnableLambda(lambda x: cls.create_chat_history(x,system_prompt=system_prompt)) + chat_history=RunnableLambda( + lambda x: cls.create_chat_history(x, system_prompt=system_prompt)) ) | RunnableLambda(lambda x: llm.invoke(x)) | RunnableLambda(lambda x: cls.parse_function_calls_from_ai_message(x)) - + return chain + class Qwen2Instruct72BChatChain(Qwen2Instruct7BChatChain): model_id = LLMModelType.QWEN2INSTRUCT72B class Qwen2Instruct72BChatChain(Qwen2Instruct7BChatChain): model_id = LLMModelType.QWEN15INSTRUCT32B - + class ChatGPT35ChatChain(LLMChain): model_id = LLMModelType.CHATGPT_35_TURBO_0125 @@ -329,20 +340,21 @@ class ChatGPT35ChatChain(LLMChain): @classmethod def create_chain(cls, model_kwargs=None, **kwargs): stream = kwargs.get("stream", False) - system_prompt = kwargs.get('system_prompt',None) - prefill = kwargs.get('prefill',None) + system_prompt = kwargs.get('system_prompt', None) + prefill = kwargs.get('prefill', None) messages = [ ("placeholder", "{chat_history}"), 
HumanMessagePromptTemplate.from_template("{query}") ] if system_prompt is not None: - messages.insert(SystemMessage(content=system_prompt),0) - + messages.insert(SystemMessage(content=system_prompt), 0) + if prefill is not None: messages.append(AIMessage(content=prefill)) messages_template = ChatPromptTemplate.from_messages(messages) - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) chain = messages_template | RunnableLambda(lambda x: x.messages) chain = chain | llm | StrOutputParser() @@ -353,8 +365,10 @@ def create_chain(cls, model_kwargs=None, **kwargs): return final_chain + class ChatGPT4ChatChain(ChatGPT35ChatChain): model_id = LLMModelType.CHATGPT_4_TURBO + class ChatGPT4oChatChain(ChatGPT35ChatChain): model_id = LLMModelType.CHATGPT_4O diff --git a/source/lambda/online/common_logic/langchain_integration/chains/conversation_summary_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/conversation_summary_chain.py index 8b7dc1009..d94a1ab7e 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/conversation_summary_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/conversation_summary_chain.py @@ -1,6 +1,6 @@ # conversation summary chain -from typing import List -import json +from typing import List +import json from langchain.schema.runnable import ( RunnableLambda ) @@ -15,20 +15,20 @@ LLMModelType ) -from langchain_core.messages import( +from langchain_core.messages import ( AIMessage, BaseMessage, HumanMessage, SystemMessage, convert_to_messages -) +) from langchain.prompts import ( HumanMessagePromptTemplate, ChatPromptTemplate ) from common_logic.common_utils.prompt_utils import get_prompt_template -from common_logic.common_utils.logger_utils import get_logger,print_llm_messages +from common_logic.common_utils.logger_utils import get_logger, print_llm_messages logger = get_logger("conversation_summary") @@ -47,7 +47,7 @@ class Internlm2Chat20BConversationSummaryChain(Internlm2Chat7BChatChain): } @classmethod - def create_prompt(cls, x,system_prompt=None): + def create_prompt(cls, x, system_prompt=None): chat_history = x["chat_history"] conversational_contexts = [] for his in chat_history: @@ -58,11 +58,11 @@ def create_prompt(cls, x,system_prompt=None): else: conversational_contexts.append(f"AI: {his['content']}") if system_prompt is None: - system_prompt = get_prompt_template( - model_id=cls.model_id, - task_type=cls.intent_type, - prompt_name="system_prompt" - ).prompt_template + system_prompt = get_prompt_template( + model_id=cls.model_id, + task_type=cls.intent_type, + prompt_name="system_prompt" + ).prompt_template conversational_context = "\n".join(conversational_contexts) prompt = cls.build_prompt( @@ -73,6 +73,7 @@ def create_prompt(cls, x,system_prompt=None): prompt = prompt + "Standalone Question: " return prompt + class Internlm2Chat7BConversationSummaryChain(Internlm2Chat20BConversationSummaryChain): model_id = LLMModelType.INTERNLM2_CHAT_7B @@ -81,16 +82,17 @@ class Claude2ConversationSummaryChain(LLMChain): model_id = LLMModelType.CLAUDE_2 intent_type = LLMTaskType.CONVERSATION_SUMMARY_TYPE - default_model_kwargs = {"max_tokens": 2000, "temperature": 0.1, "top_p": 0.9} + default_model_kwargs = {"max_tokens": 2000, + "temperature": 0.1, "top_p": 0.9} prefill = "From PersonU's point of view, here is the single standalone sentence:" @staticmethod - def 
create_conversational_context(chat_history:List[BaseMessage]): + def create_conversational_context(chat_history: List[BaseMessage]): conversational_contexts = [] for his in chat_history: - assert isinstance(his,(AIMessage,HumanMessage)), his + assert isinstance(his, (AIMessage, HumanMessage)), his content = his.content - if isinstance(his,HumanMessage): + if isinstance(his, HumanMessage): conversational_contexts.append(f"USER: {content}") else: conversational_contexts.append(f"AI: {content}") @@ -98,79 +100,80 @@ def create_conversational_context(chat_history:List[BaseMessage]): return conversational_context @classmethod - def format_conversation(cls,conversation:list[BaseMessage]): + def format_conversation(cls, conversation: list[BaseMessage]): conversation_strs = [] - for message in conversation: - assert isinstance(message,(AIMessage,HumanMessage)), message + for message in conversation: + assert isinstance(message, (AIMessage, HumanMessage)), message content = message.content if isinstance(message, HumanMessage): conversation_strs.append(f"PersonU: {content}") elif isinstance(message, AIMessage): conversation_strs.append(f"PersonA: {content}") return "\n".join(conversation_strs) - + @classmethod - def create_messages_inputs(cls,x:dict,user_prompt,few_shots:list[dict]): + def create_messages_inputs(cls, x: dict, user_prompt, few_shots: list[dict]): # create few_shots few_shot_messages = [] for few_shot in few_shots: - conversation=cls.format_conversation( - convert_to_messages(few_shot['conversation']) + conversation = cls.format_conversation( + convert_to_messages(few_shot['conversation']) ) few_shot_messages.append(HumanMessage(content=user_prompt.format( - conversation=conversation, - current_query=few_shot['conversation'][-1]['content'] - ))) - few_shot_messages.append(AIMessage(content=f"{cls.prefill} {few_shot['rewrite_query']}")) + conversation=conversation, + current_query=few_shot['conversation'][-1]['content'] + ))) + few_shot_messages.append( + AIMessage(content=f"{cls.prefill} {few_shot['rewrite_query']}")) # create current cocnversation cur_messages = convert_to_messages( - x['chat_history'] + [{"role":MessageType.HUMAN_MESSAGE_TYPE,"content":x['query']}] + x['chat_history'] + + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] ) - + conversation = cls.format_conversation(cur_messages) return { - "conversation":conversation, - "few_shots":few_shot_messages, + "conversation": conversation, + "few_shots": few_shot_messages, "current_query": x['query'] } @classmethod - def create_messages_chain(cls,**kwargs): + def create_messages_chain(cls, **kwargs): enable_prefill = kwargs['enable_prefill'] system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template user_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="user_prompt" + prompt_name="user_prompt" ).prompt_template few_shots = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="few_shots" + prompt_name="few_shots" ).prompt_template system_prompt = kwargs.get("system_prompt", system_prompt) user_prompt = kwargs.get('user_prompt', user_prompt) - messages = [ SystemMessage(content=system_prompt), - ('placeholder','{few_shots}'), + ('placeholder', '{few_shots}'), HumanMessagePromptTemplate.from_template(user_prompt) ] if enable_prefill: messages.append(AIMessage(content=cls.prefill)) cqr_template = 
ChatPromptTemplate.from_messages(messages) - return RunnableLambda(lambda x: cls.create_messages_inputs(x,user_prompt=user_prompt,few_shots=json.loads(few_shots))) | cqr_template - + return RunnableLambda(lambda x: cls.create_messages_inputs(x, user_prompt=user_prompt, few_shots=json.loads(few_shots))) | cqr_template + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} @@ -179,9 +182,10 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_id=cls.model_id, model_kwargs=model_kwargs, ) - messages_chain = cls.create_messages_chain(**kwargs,enable_prefill=llm.enable_prefill) + messages_chain = cls.create_messages_chain( + **kwargs, enable_prefill=llm.enable_prefill) chain = messages_chain | RunnableLambda(lambda x: print_llm_messages(f"conversation summary messages: {x.messages}") or x.messages) \ - | llm | RunnableLambda(lambda x: x.content.replace(cls.prefill,"").strip()) + | llm | RunnableLambda(lambda x: x.content.replace(cls.prefill, "").strip()) return chain @@ -208,6 +212,7 @@ class Claude35HaikuConversationSummaryChain(Claude2ConversationSummaryChain): class Claude35SonnetConversationSummaryChain(Claude2ConversationSummaryChain): model_id = LLMModelType.CLAUDE_3_5_SONNET + class Claude35SonnetV2ConversationSummaryChain(Claude2ConversationSummaryChain): model_id = LLMModelType.CLAUDE_3_5_SONNET_V2 @@ -233,7 +238,6 @@ class CohereCommandRPlusConversationSummaryChain(Claude2ConversationSummaryChain model_id = LLMModelType.COHERE_COMMAND_R_PLUS - class Qwen2Instruct72BConversationSummaryChain(Claude2ConversationSummaryChain): model_id = LLMModelType.QWEN2INSTRUCT72B @@ -248,5 +252,3 @@ class Qwen2Instruct7BConversationSummaryChain(Claude2ConversationSummaryChain): class GLM4Chat9BConversationSummaryChain(Claude2ConversationSummaryChain): model_id = LLMModelType.GLM_4_9B_CHAT - - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/intention_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/intention_chain.py index bc2602beb..b5483ef9a 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/intention_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/intention_chain.py @@ -12,7 +12,7 @@ RunnablePassthrough, ) -from common_logic.common_utils.constant import LLMTaskType,LLMModelType +from common_logic.common_utils.constant import LLMTaskType, LLMModelType from ..chat_models import Model from .chat_chain import Internlm2Chat7BChatChain from . 
import LLMChain @@ -50,7 +50,7 @@ def load_intention_file(intent_save_path=intent_save_path, seed=42): class Internlm2Chat7BIntentRecognitionChain(Internlm2Chat7BChatChain): model_id = LLMModelType.INTERNLM2_CHAT_7B - intent_type =LLMTaskType.INTENT_RECOGNITION_TYPE + intent_type = LLMTaskType.INTENT_RECOGNITION_TYPE default_model_kwargs = { "temperature": 0.1, @@ -67,7 +67,8 @@ def create_prompt(cls, x): example_strs = [] for example in few_shot_examples: example_strs.append( - exmaple_template.format(query=example["query"], label=example["intent"]) + exmaple_template.format( + query=example["query"], label=example["intent"]) ) example_str = "\n\n".join(example_strs) @@ -148,7 +149,8 @@ def create_few_shot_example_string( for example in cls.few_shot_examples: example_strs.append( example_template.format( - label=intent_indexs[example["intent"]], query=example["query"] + label=intent_indexs[example["intent"] + ], query=example["query"] ) ) return "\n\n".join(example_strs) diff --git a/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_conversation_summary.py b/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_conversation_summary.py index 87cc8b584..82690c2db 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_conversation_summary.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_conversation_summary.py @@ -20,6 +20,7 @@ CHIT_CHAT_SYSTEM_TEMPLATE = """You are a helpful AI Assistant""" + class Internlm2Chat7BMKTConversationSummaryChain(Internlm2Chat7BChatChain): model_id = LLMModelType.INTERNLM2_CHAT_7B intent_type = MKT_CONVERSATION_SUMMARY_TYPE @@ -39,7 +40,8 @@ def _create_prompt(cls, x): assert chat_history[i].type == HUMAN_MESSAGE_TYPE, chat_history assert chat_history[i + 1].type == AI_MESSAGE_TYPE, chat_history questions.append(chat_history[i].content) - history.append((chat_history[i].content, chat_history[i + 1].content)) + history.append( + (chat_history[i].content, chat_history[i + 1].content)) questions_str = "" for i, question in enumerate(questions): @@ -70,7 +72,8 @@ def create_chain(cls, model_kwargs=None, **kwargs): stream = kwargs.get("stream", False) llm_chain = super().create_chain(model_kwargs=model_kwargs, **kwargs) chain = ( - RunnablePassthrough.assign(prompt_dict=lambda x: cls._create_prompt(x)) + RunnablePassthrough.assign( + prompt_dict=lambda x: cls._create_prompt(x)) | RunnablePassthrough.assign( prompt=lambda x: x["prompt_dict"]["prompt"], prefix=lambda x: x["prompt_dict"]["prefix"], @@ -78,9 +81,11 @@ def create_chain(cls, model_kwargs=None, **kwargs): | RunnablePassthrough.assign(llm_output=llm_chain) ) if stream: - chain = chain | RunnableLambda(lambda x: cls.stream_postprocess_fn(x)) + chain = chain | RunnableLambda( + lambda x: cls.stream_postprocess_fn(x)) else: - chain = chain | RunnableLambda(lambda x: x["prefix"] + x["llm_output"]) + chain = chain | RunnableLambda( + lambda x: x["prefix"] + x["llm_output"]) return chain @@ -94,7 +99,8 @@ class Claude2MKTConversationSummaryChain(Claude2ChatChain): model_id = LLMModelType.CLAUDE_2 intent_type = MKT_CONVERSATION_SUMMARY_TYPE - default_model_kwargs = {"max_tokens": 2000, "temperature": 0.1, "top_p": 0.9} + default_model_kwargs = {"max_tokens": 2000, + "temperature": 0.1, "top_p": 0.9} @classmethod def create_chain(cls, model_kwargs=None, **kwargs): diff --git 
a/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_rag_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_rag_chain.py index 3ce5f2631..1a49dfb31 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_rag_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/marketing_chains/mkt_rag_chain.py @@ -3,17 +3,19 @@ LLMModelType ) from ..chat_chain import Internlm2Chat7BChatChain -from common_logic.common_utils.prompt_utils import register_prompt_templates,get_prompt_template +from common_logic.common_utils.prompt_utils import register_prompt_templates, get_prompt_template INTERLM2_RAG_PROMPT_TEMPLATE = "你是一个Amazon AWS的客服助理小Q,帮助的用户回答使用AWS过程中的各种问题。\n面对用户的问题,你需要给出中文回答,注意不要在回答中重复输出内容。\n下面给出相关问题的背景知识, 需要注意的是如果你认为当前的问题不能在背景知识中找到答案, 你需要拒答。\n背景知识:\n{context}\n\n" register_prompt_templates( - model_ids=[LLMModelType.INTERNLM2_CHAT_7B,LLMModelType.INTERNLM2_CHAT_20B], + model_ids=[LLMModelType.INTERNLM2_CHAT_7B, + LLMModelType.INTERNLM2_CHAT_20B], task_type=LLMTaskType.MTK_RAG, prompt_template=INTERLM2_RAG_PROMPT_TEMPLATE, prompt_name="system_prompt" ) + class Internlm2Chat7BKnowledgeQaChain(Internlm2Chat7BChatChain): model_id = LLMModelType.INTERNLM2_CHAT_7B intent_type = LLMTaskType.MTK_RAG @@ -26,9 +28,9 @@ def create_prompt(cls, x): history = cls.create_history(x) context = "\n".join(contexts) prompt_template = get_prompt_template( - model_id = cls.model_id, - task_type = cls.task_type, - prompt_name = "system_prompt" + model_id=cls.model_id, + task_type=cls.task_type, + prompt_name="system_prompt" ).prompt_template meta_instruction = prompt_template.format(context) # meta_instruction = f"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use simplified Chinese to response the qustion. I’m going to tip $300K for a better answer! 
" @@ -52,4 +54,4 @@ def create_prompt(cls, x): class Internlm2Chat20BKnowledgeQaChain(Internlm2Chat7BKnowledgeQaChain): - model_id = LLMModelType.INTERNLM2_CHAT_20B \ No newline at end of file + model_id = LLMModelType.INTERNLM2_CHAT_20B diff --git a/source/lambda/online/common_logic/langchain_integration/chains/query_rewrite_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/query_rewrite_chain.py index 480902b83..e00007459 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/query_rewrite_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/query_rewrite_chain.py @@ -59,7 +59,8 @@ def create_chain(cls, model_kwargs=None, **kwargs): query_key = kwargs.pop("query_key", "query") model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} - llm = LLM_Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) + llm = LLM_Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) chain = ( RunnablePassthrough.assign(question=lambda x: x[query_key]) | query_expansion_template_claude @@ -142,7 +143,8 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} chain = super().create_chain(model_kwargs=model_kwargs, **kwargs) - chain = chain | RunnableLambda(lambda x: cls.query_rewrite_postprocess(x)) + chain = chain | RunnableLambda( + lambda x: cls.query_rewrite_postprocess(x)) return chain diff --git a/source/lambda/online/common_logic/langchain_integration/chains/rag_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/rag_chain.py index bfacac8f5..0a2ed0dc2 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/rag_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/rag_chain.py @@ -1,4 +1,7 @@ # rag llm chains +from .chat_chain import Baichuan2Chat13B4BitsChatChain +from .chat_chain import Qwen2Instruct7BChatChain +from .chat_chain import GLM4Chat9BChatChain from langchain.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, @@ -41,10 +44,11 @@ def create_chain(cls, model_kwargs=None, **kwargs): system_prompt_template = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template - system_prompt_template = kwargs.get("system_prompt",system_prompt_template) + system_prompt_template = kwargs.get( + "system_prompt", system_prompt_template) chat_messages = [ SystemMessagePromptTemplate.from_template(system_prompt_template), @@ -52,11 +56,14 @@ def create_chain(cls, model_kwargs=None, **kwargs): HumanMessagePromptTemplate.from_template("{query}") ] context_chain = RunnablePassthrough.assign( - context=RunnableLambda(lambda x: get_claude_rag_context(x["contexts"])) + context=RunnableLambda( + lambda x: get_claude_rag_context(x["contexts"])) ) - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) - chain = context_chain | ChatPromptTemplate.from_messages(chat_messages) | RunnableLambda(lambda x: print_llm_messages(f"rag messages: {x.messages}") or x) - + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) + chain = context_chain | ChatPromptTemplate.from_messages(chat_messages) | RunnableLambda( + lambda x: print_llm_messages(f"rag messages: {x.messages}") or x) + chain = chain | llm | StrOutputParser() if stream: @@ -82,6 +89,7 @@ class 
Claude3SonnetRAGLLMChain(Claude2RagLLMChain): class Claude3HaikuRAGLLMChain(Claude2RagLLMChain): model_id = LLMModelType.CLAUDE_3_HAIKU + class Claude35SonnetRAGLLMChain(Claude2RagLLMChain): model_id = LLMModelType.CLAUDE_3_5_SONNET @@ -89,6 +97,7 @@ class Claude35SonnetRAGLLMChain(Claude2RagLLMChain): class Claude35SonnetV2RAGLLMChain(Claude2RagLLMChain): model_id = LLMModelType.CLAUDE_3_5_SONNET_V2 + class Claude35HaikuRAGLLMChain(Claude2RagLLMChain): model_id = LLMModelType.CLAUDE_3_5_HAIKU @@ -96,6 +105,7 @@ class Claude35HaikuRAGLLMChain(Claude2RagLLMChain): class Llama31Instruct70B(Claude2RagLLMChain): model_id = LLMModelType.LLAMA3_1_70B_INSTRUCT + class Llama32Instruct90B(Claude2RagLLMChain): model_id = LLMModelType.LLAMA3_2_90B_INSTRUCT @@ -112,44 +122,40 @@ class Mixtral8x7bChatChain(Claude2RagLLMChain): model_id = LLMModelType.MIXTRAL_8X7B_INSTRUCT -from .chat_chain import GLM4Chat9BChatChain - class GLM4Chat9BRagChain(GLM4Chat9BChatChain): model_id = LLMModelType.GLM_4_9B_CHAT intent_type = LLMTaskType.RAG @classmethod - def create_chat_history(cls,x, system_prompt=None): + def create_chat_history(cls, x, system_prompt=None): if system_prompt is None: system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template - context = ("\n" + "="*50+ "\n").join(x['contexts']) + context = ("\n" + "="*50 + "\n").join(x['contexts']) system_prompt = system_prompt.format(context=context) - return super().create_chat_history(x,system_prompt=system_prompt) - + return super().create_chat_history(x, system_prompt=system_prompt) -from .chat_chain import Qwen2Instruct7BChatChain class Qwen2Instruct7BRagChain(Qwen2Instruct7BChatChain): model_id = LLMModelType.QWEN2INSTRUCT7B intent_type = LLMTaskType.RAG @classmethod - def create_chat_history(cls,x, system_prompt=None): + def create_chat_history(cls, x, system_prompt=None): if system_prompt is None: system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="system_prompt" + prompt_name="system_prompt" ).prompt_template - + context = ("\n\n").join(x['contexts']) system_prompt = system_prompt.format(context=context) - return super().create_chat_history(x,system_prompt=system_prompt) + return super().create_chat_history(x, system_prompt=system_prompt) class Qwen2Instruct72BRagChain(Qwen2Instruct7BRagChain): @@ -160,8 +166,6 @@ class Qwen2Instruct72BRagChain(Qwen2Instruct7BRagChain): model_id = LLMModelType.QWEN15INSTRUCT32B -from .chat_chain import Baichuan2Chat13B4BitsChatChain - class Baichuan2Chat13B4BitsKnowledgeQaChain(Baichuan2Chat13B4BitsChatChain): model_id = LLMModelType.BAICHUAN2_13B_CHAT intent_type = LLMTaskType.RAG @@ -182,7 +186,3 @@ def add_system_prompt(x): ) llm_chain = chat_history_chain | llm_chain return llm_chain - - - - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/auto_evaluation_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/auto_evaluation_chain.py index 28d4b22c0..1b0178e4e 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/auto_evaluation_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/auto_evaluation_chain.py @@ -2,9 +2,9 @@ import re from langchain.schema.runnable import RunnableLambda, RunnablePassthrough -from langchain_core.messages import AIMessage,SystemMessage,HumanMessage +from langchain_core.messages import 
AIMessage, SystemMessage, HumanMessage from common_logic.common_utils.logger_utils import get_logger -from langchain.prompts import ChatPromptTemplate,HumanMessagePromptTemplate +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.messages import convert_to_messages from common_logic.common_utils.constant import ( MessageType, @@ -49,11 +49,11 @@ class Claude2AutoEvaluationChain(Claude2ChatChain): - intent_type = LLMTaskType.AUTO_EVALUATION + intent_type = LLMTaskType.AUTO_EVALUATION model_id = LLMModelType.CLAUDE_2 @classmethod - def create_messages(cls,x:dict,examples=""): + def create_messages(cls, x: dict, examples=""): prompt = AUTO_EVALUATION_TEMPLATE.format( ref_answer=x['ref_answer'], model_answer=x['model_answer'], @@ -62,38 +62,36 @@ def create_messages(cls,x:dict,examples=""): messages = [ HumanMessage(content=prompt), AIMessage(content="") - ] + ] return messages @classmethod - def postprocess(cls,content): + def postprocess(cls, content): logger.info(f"auto eval content: {content}") try: - score = float(re.findall("(.*?)",content)[0].strip()) + score = float(re.findall( + "(.*?)", content)[0].strip()) return score except Exception as e: logger.error(f"error: {e}, content: {content}") raise e - @classmethod def create_chain(cls, model_kwargs=None, **kwargs): - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) - chain = RunnableLambda(lambda x: cls.create_messages(x)) | llm | RunnableLambda(lambda x: cls.postprocess(x.content)) - return chain - + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) + chain = RunnableLambda(lambda x: cls.create_messages( + x)) | llm | RunnableLambda(lambda x: cls.postprocess(x.content)) + return chain + class Claude21AutoEvaluationChain(Claude2AutoEvaluationChain): model_id = LLMModelType.CLAUDE_21 - class Claude3HaikuAutoEvaluationChain(Claude2AutoEvaluationChain): model_id = LLMModelType.CLAUDE_3_HAIKU class Claude3SonnetAutoEvaluationChain(Claude2AutoEvaluationChain): model_id = LLMModelType.CLAUDE_3_SONNET - - - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_conversation_summary_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_conversation_summary_chain.py index eae0716d6..5e240afc6 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_conversation_summary_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_conversation_summary_chain.py @@ -1,5 +1,5 @@ # conversation summary chain -from typing import List +from typing import List from langchain.schema.runnable import ( RunnableLambda, @@ -15,12 +15,12 @@ LLMModelType ) -from langchain_core.messages import( +from langchain_core.messages import ( AIMessage, HumanMessage, BaseMessage, convert_to_messages -) +) from langchain.prompts import ( HumanMessagePromptTemplate, ChatPromptTemplate @@ -64,22 +64,25 @@ class Claude2RetailConversationSummaryChain(LLMChain): model_id = LLMModelType.CLAUDE_2 intent_type = LLMTaskType.RETAIL_CONVERSATION_SUMMARY_TYPE - default_model_kwargs = {"max_tokens": 2000, "temperature": 0.1, "top_p": 0.9} + default_model_kwargs = {"max_tokens": 2000, + "temperature": 0.1, "top_p": 0.9} CQR_TEMPLATE = CQR_TEMPLATE + @staticmethod - def create_conversational_context(chat_history:List[BaseMessage]): + def create_conversational_context(chat_history: List[BaseMessage]): conversational_contexts = [] for 
his in chat_history: - role = his.type + role = his.type content = his.content - assert role in [HUMAN_MESSAGE_TYPE, AI_MESSAGE_TYPE],(role,[HUMAN_MESSAGE_TYPE, AI_MESSAGE_TYPE]) + assert role in [HUMAN_MESSAGE_TYPE, AI_MESSAGE_TYPE], (role, [ + HUMAN_MESSAGE_TYPE, AI_MESSAGE_TYPE]) if role == HUMAN_MESSAGE_TYPE: conversational_contexts.append(f"客户: {content}") else: conversational_contexts.append(f"客服: {content}") conversational_context = "\n".join(conversational_contexts) return conversational_context - + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} @@ -95,14 +98,14 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs=model_kwargs, ) cqr_chain = RunnablePassthrough.assign( - conversational_context=RunnableLambda( + conversational_context=RunnableLambda( lambda x: cls.create_conversational_context( convert_to_messages(x["chat_history"]) ) )) \ - | RunnableLambda(lambda x: cqr_template.format(chat_history=x['conversational_context'],query=x['query'])) \ + | RunnableLambda(lambda x: cqr_template.format(chat_history=x['conversational_context'], query=x['query'])) \ | llm | RunnableLambda(lambda x: x.content) - + return cqr_chain @@ -144,15 +147,15 @@ class Mixtral8x7bRetailConversationSummaryChain(Claude2RetailConversationSummary CQR_TEMPLATE = MIXTRAL_CQR_TEMPLATE -class GLM4Chat9BRetailConversationSummaryChain(GLM4Chat9BChatChain,Claude2RetailConversationSummaryChain): +class GLM4Chat9BRetailConversationSummaryChain(GLM4Chat9BChatChain, Claude2RetailConversationSummaryChain): model_id = LLMModelType.GLM_4_9B_CHAT intent_type = LLMTaskType.RETAIL_CONVERSATION_SUMMARY_TYPE CQR_TEMPLATE = MIXTRAL_CQR_TEMPLATE @classmethod - def create_chat_history(cls,x): + def create_chat_history(cls, x): conversational_context = cls.create_conversational_context( - convert_to_messages(x["chat_history"]) + convert_to_messages(x["chat_history"]) ) prompt = cls.CQR_TEMPLATE.format( chat_history=conversational_context, @@ -161,12 +164,12 @@ def create_chat_history(cls,x): chat_history = [ {"role": MessageType.HUMAN_MESSAGE_TYPE, "content": prompt - }, + }, { - "role":MessageType.AI_MESSAGE_TYPE, + "role": MessageType.AI_MESSAGE_TYPE, "content": "好的,站在客户的角度,我将当前用户的回复内容改写为: " } - ] + ] return chat_history @@ -174,7 +177,7 @@ def create_chat_history(cls,x): def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} - + llm = Model.get_model( model_id=cls.model_id, model_kwargs=model_kwargs, @@ -182,11 +185,11 @@ def create_chain(cls, model_kwargs=None, **kwargs): ) cqr_chain = RunnablePassthrough.assign( - chat_history = RunnableLambda(lambda x: cls.create_chat_history(x)) + chat_history=RunnableLambda(lambda x: cls.create_chat_history(x)) ) | RunnableLambda(lambda x: llm.invoke(x)) - + return cqr_chain - + class Qwen2Instruct7BRetailConversationSummaryChain(GLM4Chat9BRetailConversationSummaryChain): model_id = LLMModelType.QWEN2INSTRUCT7B @@ -194,10 +197,11 @@ class Qwen2Instruct7BRetailConversationSummaryChain(GLM4Chat9BRetailConversation "max_tokens": 1024, "temperature": 0.1, } + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): - chain = super().create_chain(model_kwargs=model_kwargs,**kwargs) - return chain | RunnableLambda(lambda x:x['text']) + chain = super().create_chain(model_kwargs=model_kwargs, **kwargs) + return chain | RunnableLambda(lambda x: x['text']) class 
Qwen2Instruct72BRetailConversationSummaryChain(Qwen2Instruct7BRetailConversationSummaryChain): @@ -205,4 +209,4 @@ class Qwen2Instruct72BRetailConversationSummaryChain(Qwen2Instruct7BRetailConver class Qwen2Instruct72BRetailConversationSummaryChain(Qwen2Instruct7BRetailConversationSummaryChain): - model_id = LLMModelType.QWEN15INSTRUCT32B \ No newline at end of file + model_id = LLMModelType.QWEN15INSTRUCT32B diff --git a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_claude_xml.py b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_claude_xml.py index 71a953c5a..f480e15ba 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_claude_xml.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_claude_xml.py @@ -1,20 +1,20 @@ # tool calling chain import json -from typing import List,Dict,Any +from typing import List, Dict, Any import re -from datetime import datetime +from datetime import datetime from langchain.schema.runnable import ( RunnableLambda, ) -from langchain_core.messages import( +from langchain_core.messages import ( AIMessage, SystemMessage -) +) from langchain.prompts import ChatPromptTemplate -from langchain_core.messages import AIMessage,SystemMessage,HumanMessage +from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from common_logic.common_utils.constant import ( LLMTaskType, @@ -38,27 +38,27 @@ """ -SYSTEM_MESSAGE_PROMPT=("你是安踏的客服助理小安, 主要职责是处理用户售前和售后的问题。下面是当前用户正在浏览的商品信息:\n\n{goods_info}\n" - "In this environment you have access to a set of tools you can use to answer the customer's question." - "\n" - "You may call them like this:\n" - "\n" - "\n" - "$TOOL_NAME\n" - "\n" - "<$PARAMETER_NAME>$PARAMETER_VALUE\n" - "...\n" - "\n" - "\n" - "\n" - "\n" - "Here are the tools available:\n" - "\n" - "{tools}" - "\n" - "\nAnswer the user's request using relevant tools (if they are available). Before calling a tool, do some analysis within tags. First, think about which of the provided tools is the relevant tool to answer the user's request. Second, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool call. BUT, if one of the values for a required parameter is missing, DO NOT invoke the function (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters. DO NOT ask for more information on optional parameters if it is not provided." - f"\nHere are some guidelines for you:\n{tool_call_guidelines}" - ) +SYSTEM_MESSAGE_PROMPT = ("你是安踏的客服助理小安, 主要职责是处理用户售前和售后的问题。下面是当前用户正在浏览的商品信息:\n\n{goods_info}\n" + "In this environment you have access to a set of tools you can use to answer the customer's question." + "\n" + "You may call them like this:\n" + "\n" + "\n" + "$TOOL_NAME\n" + "\n" + "<$PARAMETER_NAME>$PARAMETER_VALUE\n" + "...\n" + "\n" + "\n" + "\n" + "\n" + "Here are the tools available:\n" + "\n" + "{tools}" + "\n" + "\nAnswer the user's request using relevant tools (if they are available). Before calling a tool, do some analysis within tags. 
First, think about which of the provided tools is the relevant tool to answer the user's request. Second, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool call. BUT, if one of the values for a required parameter is missing, DO NOT invoke the function (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters. DO NOT ask for more information on optional parameters if it is not provided." + f"\nHere are some guidelines for you:\n{tool_call_guidelines}" + ) SYSTEM_MESSAGE_PROMPT_WITH_FEWSHOT_EXAMPLES = SYSTEM_MESSAGE_PROMPT + ( "Some examples of tool calls are given below, where the content within represents the most recent reply in the dialog." @@ -112,7 +112,7 @@ def _get_type(parameter: Dict[str, Any]) -> str: return json.dumps(parameter) -def convert_openai_tool_to_anthropic(tools:list[dict])->str: +def convert_openai_tool_to_anthropic(tools: list[dict]) -> str: formatted_tools = tools tools_data = [ { @@ -162,15 +162,15 @@ class Claude2RetailToolCallingChain(LLMChain): "max_tokens": 2000, "temperature": 0.1, "top_p": 0.9, - "stop_sequences": ["\n\nHuman:", "\n\nAssistant",""], - } + "stop_sequences": ["\n\nHuman:", "\n\nAssistant", ""], + } @staticmethod - def format_fewshot_examples(fewshot_examples:list[dict]): + def format_fewshot_examples(fewshot_examples: list[dict]): fewshot_example_strs = [] for fewshot_example in fewshot_examples: param_strs = [] - for p,v in fewshot_example['kwargs'].items(): + for p, v in fewshot_example['kwargs'].items(): param_strs.append(f"<{p}>{v}\n{fewshot_example_str}\n" - + @classmethod - def parse_function_calls_from_ai_message(cls,message:AIMessage): + def parse_function_calls_from_ai_message(cls, message: AIMessage): content = "" + message.content + "" - function_calls:List[str] = re.findall("(.*?)", content,re.S) + function_calls: List[str] = re.findall( + "(.*?)", content, re.S) if not function_calls: - content = "" + message.content + content = "" + message.content return { - "function_calls": function_calls, - "content": content - } - + "function_calls": function_calls, + "content": content + } - @staticmethod - def generate_chat_history(state:dict): + @staticmethod + def generate_chat_history(state: dict): chat_history = state['chat_history'] \ - + [{"role": "user","content":state['query']}] \ + + [{"role": "user", "content": state['query']}] \ + state['agent_tool_history'] - return {"chat_history":chat_history} + return {"chat_history": chat_history} - @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} - tools:list[dict] = kwargs['tools'] + tools: list[dict] = kwargs['tools'] tool_names = [tool['name'] for tool in tools] # add two extral tools if "give_rhetorical_question" not in tool_names: - tools.append(get_tool_by_name("give_rhetorical_question",scene=SceneType.RETAIL).tool_def) + tools.append(get_tool_by_name( + "give_rhetorical_question", scene=SceneType.RETAIL).tool_def) if "give_final_response" not in tool_names: - tools.append(get_tool_by_name("give_final_response",scene=SceneType.RETAIL).tool_def) + tools.append(get_tool_by_name("give_final_response", + scene=SceneType.RETAIL).tool_def) - 
fewshot_examples = kwargs.get('fewshot_examples',[]) + fewshot_examples = kwargs.get('fewshot_examples', []) if fewshot_examples: fewshot_examples.append({ "name": "give_rhetorical_question", "query": "今天天气怎么样?", "kwargs": {"question": "请问你想了解哪个城市的天气?"} }) - + model_kwargs = {**cls.default_model_kwargs, **model_kwargs} tools_formatted = convert_openai_tool_to_anthropic(tools) @@ -248,21 +249,21 @@ def create_chain(cls, model_kwargs=None, **kwargs): tools=tools_formatted, fewshot_examples=cls.format_fewshot_examples( fewshot_examples - ), - goods_info = goods_info + ), + goods_info=goods_info ) else: system_prompt = SYSTEM_MESSAGE_PROMPT.format( tools=tools_formatted, goods_info=goods_info ) - + tool_calling_template = ChatPromptTemplate.from_messages( [ - SystemMessage(content=system_prompt), - ("placeholder", "{chat_history}"), - AIMessage(content="") - ]) + SystemMessage(content=system_prompt), + ("placeholder", "{chat_history}"), + AIMessage(content="") + ]) llm = Model.get_model( model_id=cls.model_id, @@ -270,10 +271,10 @@ def create_chain(cls, model_kwargs=None, **kwargs): ) chain = RunnableLambda(cls.generate_chat_history) | tool_calling_template \ | RunnableLambda(lambda x: x.messages) \ - | llm | RunnableLambda(lambda message:cls.parse_function_calls_from_ai_message( + | llm | RunnableLambda(lambda message: cls.parse_function_calls_from_ai_message( message )) - + return chain @@ -303,52 +304,47 @@ class Claude3HaikuRetailToolCallingChain(Claude2RetailToolCallingChain): class Mixtral8x7bRetailToolCallingChain(Claude2RetailToolCallingChain): model_id = LLMModelType.MIXTRAL_8X7B_INSTRUCT - default_model_kwargs = {"max_tokens": 1000, "temperature": 0.01,"stop":[""]} + default_model_kwargs = {"max_tokens": 1000, + "temperature": 0.01, "stop": [""]} @classmethod - def parse_function_calls_from_ai_message(cls,message:AIMessage): - content = message.content.replace("\_","_") - function_calls:List[str] = re.findall("(.*?)", content + "",re.S) + def parse_function_calls_from_ai_message(cls, message: AIMessage): + content = message.content.replace("\_", "_") + function_calls: List[str] = re.findall( + "(.*?)", content + "", re.S) if function_calls: function_calls = [function_calls[0]] if not function_calls: content = message.content return { - "function_calls": function_calls, - "content": content - } - - @staticmethod - def chat_history_to_string(chat_history:list[dict]): + "function_calls": function_calls, + "content": content + } + + @staticmethod + def chat_history_to_string(chat_history: list[dict]): chat_history_lc = ChatPromptTemplate.from_messages([ - ("placeholder", "{chat_history}") - ]).invoke({"chat_history":chat_history}).messages + ("placeholder", "{chat_history}") + ]).invoke({"chat_history": chat_history}).messages chat_history_strs = [] for message in chat_history_lc: - assert isinstance(message,(HumanMessage,AIMessage)),message - if isinstance(message,HumanMessage): + assert isinstance(message, (HumanMessage, AIMessage)), message + if isinstance(message, HumanMessage): chat_history_strs.append(f"客户: {message.content}") else: chat_history_strs.append(f"客服: {message.content}") - return "\n".join(chat_history_strs) + return "\n".join(chat_history_strs) - @classmethod - def generate_chat_history(cls,state:dict): + def generate_chat_history(cls, state: dict): chat_history_str = cls.chat_history_to_string(state['chat_history']) chat_history = [{ "role": "user", "content": MIXTRAL8X7B_QUERY_TEMPLATE.format( chat_history=chat_history_str, - query = state['query'] + 
query=state['query'] ) - }] + state['agent_tool_history'] + }] + state['agent_tool_history'] return {"chat_history": chat_history} - - - - - - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_json.py b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_json.py index f1bc5d8b0..0a9847f1a 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_json.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/retail_chains/retail_tool_calling_chain_json.py @@ -1,21 +1,22 @@ # tool calling chain +from ..chat_chain import Qwen2Instruct7BChatChain import json -from typing import List,Dict,Any +from typing import List, Dict, Any import re -from datetime import datetime +from datetime import datetime import copy from langchain.schema.runnable import ( RunnableLambda, ) -from langchain_core.messages import( +from langchain_core.messages import ( AIMessage, SystemMessage -) +) from langchain.prompts import ChatPromptTemplate -from langchain_core.messages import AIMessage,SystemMessage,HumanMessage +from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from common_logic.common_utils.constant import ( LLMTaskType, @@ -46,7 +47,6 @@ """ - class GLM4Chat9BRetailToolCallingChain(GLM4Chat9BChatChain): model_id = LLMModelType.GLM_4_9B_CHAT intent_type = LLMTaskType.RETAIL_TOOL_CALLING @@ -56,22 +56,22 @@ class GLM4Chat9BRetailToolCallingChain(GLM4Chat9BChatChain): "temperature": 0.1, } DATE_PROMPT = "当前日期: %Y-%m-%d" - + @staticmethod - def convert_openai_function_to_glm(tools:list[dict]): + def convert_openai_function_to_glm(tools: list[dict]): glm_tools = [] for tool_def in tools: tool_name = tool_def['name'] description = tool_def['description'] params = [] - required = tool_def['parameters'].get("required",[]) - for param_name,param in tool_def['parameters'].get('properties',{}).items(): + required = tool_def['parameters'].get("required", []) + for param_name, param in tool_def['parameters'].get('properties', {}).items(): params.append({ "name": param_name, "description": param["description"], "type": param["type"], - "required": param_name in required, - }) + "required": param_name in required, + }) glm_tools.append({ "name": tool_name, "description": description, @@ -80,9 +80,9 @@ def convert_openai_function_to_glm(tools:list[dict]): return glm_tools @staticmethod - def format_fewshot_examples(fewshot_examples:list[dict]): + def format_fewshot_examples(fewshot_examples: list[dict]): fewshot_example_strs = [] - for i,example in enumerate(fewshot_examples): + for i, example in enumerate(fewshot_examples): query = example['query'] name = example['name'] kwargs = example['kwargs'] @@ -90,9 +90,8 @@ def format_fewshot_examples(fewshot_examples:list[dict]): fewshot_example_strs.append(fewshot_example_str) return "\n\n".join(fewshot_example_strs) - @classmethod - def create_system_prompt(cls,goods_info:str,tools:list,fewshot_examples:list) -> str: + def create_system_prompt(cls, goods_info: str, tools: list, fewshot_examples: list) -> str: value = GLM4_SYSTEM_PROMPT.format( goods_info=goods_info, date_prompt=datetime.now().strftime(cls.DATE_PROMPT) @@ -115,46 +114,44 @@ def create_system_prompt(cls,goods_info:str,tools:list,fewshot_examples:list) -> return value @classmethod - def create_chat_history(cls,x,system_prompt=None): + def create_chat_history(cls, x, system_prompt=None): _chat_history = 
x['chat_history'] + \ - [{"role":MessageType.HUMAN_MESSAGE_TYPE,"content": x['query']}] + \ + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] + \ x['agent_tool_history'] - + chat_history = [] for message in _chat_history: - new_message = message + new_message = message if message['role'] == MessageType.AI_MESSAGE_TYPE: new_message = { **message } - tool_calls = message.get('additional_kwargs',{}).get("tool_calls",[]) + tool_calls = message.get( + 'additional_kwargs', {}).get("tool_calls", []) if tool_calls: new_message['metadata'] = tool_calls[0]['name'] chat_history.append(new_message) - chat_history = [{"role": "system", "content": system_prompt}] + chat_history + chat_history = [ + {"role": "system", "content": system_prompt}] + chat_history return chat_history @classmethod def create_chain(cls, model_kwargs=None, **kwargs): - tools:list = kwargs.get('tools',[]) - fewshot_examples = kwargs.get('fewshot_examples',[]) + tools: list = kwargs.get('tools', []) + fewshot_examples = kwargs.get('fewshot_examples', []) glm_tools = cls.convert_openai_function_to_glm(tools) system_prompt = cls.create_system_prompt( - goods_info=kwargs['goods_info'], + goods_info=kwargs['goods_info'], tools=glm_tools, fewshot_examples=fewshot_examples - ) + ) kwargs['system_prompt'] = system_prompt - return super().create_chain(model_kwargs=model_kwargs,**kwargs) - - -from ..chat_chain import Qwen2Instruct7BChatChain - + return super().create_chain(model_kwargs=model_kwargs, **kwargs) class Qwen2Instruct72BRetailToolCallingChain(Qwen2Instruct7BChatChain): model_id = LLMModelType.QWEN2INSTRUCT72B - intent_type = LLMTaskType.RETAIL_TOOL_CALLING + intent_type = LLMTaskType.RETAIL_TOOL_CALLING default_model_kwargs = { "max_tokens": 1024, "temperature": 0.1, @@ -173,28 +170,26 @@ class Qwen2Instruct72BRetailToolCallingChain(Qwen2Instruct7BChatChain): prefill_after_second_thinking = "" prefill = prefill_after_thinking - - FN_CALL_TEMPLATE_INFO_ZH="""# 工具 + FN_CALL_TEMPLATE_INFO_ZH = """# 工具 ## 你拥有如下工具: {tool_descs}""" - - FN_CALL_TEMPLATE_FMT_ZH="""## 你可以在回复中插入零次或者一次以下命令以调用工具: + FN_CALL_TEMPLATE_FMT_ZH = """## 你可以在回复中插入零次或者一次以下命令以调用工具: %s: 工具名称,必须是[{tool_names}]之一。 %s: 工具输入 %s: 工具结果 %s: 根据工具结果进行回复""" % ( - FN_NAME, - FN_ARGS, - FN_RESULT, - FN_EXIT, -) - TOOL_DESC_TEMPLATE="""### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}""" - - FN_CALL_TEMPLATE=FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_ZH + FN_NAME, + FN_ARGS, + FN_RESULT, + FN_EXIT, + ) + TOOL_DESC_TEMPLATE = """### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}""" + + FN_CALL_TEMPLATE = FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_ZH # SYSTEM_PROMPT=f"""你是安踏天猫的客服助理小安, 主要职责是处理用户售前和售后的问题。{{date_prompt}} @@ -237,7 +232,7 @@ class Qwen2Instruct72BRetailToolCallingChain(Qwen2Instruct7BChatChain): # - 当前主要服务天猫平台的客户,如果客户询问其他平台的问题,直接回复 “不好意思,亲亲,这里是天猫店铺,只能为您解答天猫的问题。建议您联系其他平台的客服或售后人员给您提供相关的帮助和支持。谢谢!“ # - 如果客户的回复里面包含订单号,则直接回复 “您好,亲亲,这就帮您去查相关订单信息。请问还有什么问题吗?“{{non_ask_rules}}""" - SYSTEM_PROMPT=f"""你是安踏天猫的客服助理小安, 主要职责是处理用户售前和售后的问题。{{date_prompt}} + SYSTEM_PROMPT = f"""你是安踏天猫的客服助理小安, 主要职责是处理用户售前和售后的问题。{{date_prompt}} {{tools}} {{fewshot_examples}} @@ -257,11 +252,12 @@ class Qwen2Instruct72BRetailToolCallingChain(Qwen2Instruct7BChatChain): ## Tips - 如果客户没有明确指出在哪里购买的商品,则默认都是在天猫平台购买的。 - 回答必须简洁,不允许出现超过2句话的回复。{{non_ask_rules}}""" + @classmethod - def get_function_description(cls,tool:dict): + def get_function_description(cls, tool: dict): 
tool_name = tool['name'] description = tool['description'] - params_str = json.dumps(tool.get('parameters',{}),ensure_ascii=False) + params_str = json.dumps(tool.get('parameters', {}), ensure_ascii=False) args_format = '此工具的输入应为JSON对象。' return cls.TOOL_DESC_TEMPLATE.format( name_for_human=tool_name, @@ -271,22 +267,21 @@ def get_function_description(cls,tool:dict): args_format=args_format ).rstrip() - @classmethod - def format_fewshot_examples(cls,fewshot_examples:list[dict]): + def format_fewshot_examples(cls, fewshot_examples: list[dict]): fewshot_example_strs = [] - for i,example in enumerate(fewshot_examples): + for i, example in enumerate(fewshot_examples): query = example['query'] name = example['name'] kwargs = example['kwargs'] fewshot_example_str = f"""## 工具调用例子{i+1}\nInput:\n{query}\nOutput:\n{cls.FN_NAME}: {name}\n{cls.FN_ARGS}: {json.dumps(kwargs,ensure_ascii=False)}\n{cls.FN_RESULT}""" fewshot_example_strs.append(fewshot_example_str) return "\n\n".join(fewshot_example_strs) - - + @classmethod - def create_system_prompt(cls,goods_info:str,tools:list[dict],fewshot_examples:list,create_time=None) -> str: - tool_descs = '\n\n'.join(cls.get_function_description(tool) for tool in tools) + def create_system_prompt(cls, goods_info: str, tools: list[dict], fewshot_examples: list, create_time=None) -> str: + tool_descs = '\n\n'.join( + cls.get_function_description(tool) for tool in tools) tool_names = ','.join(tool['name'] for tool in tools) tool_system = cls.FN_CALL_TEMPLATE.format( tool_descs=tool_descs, @@ -297,7 +292,7 @@ def create_system_prompt(cls,goods_info:str,tools:list[dict],fewshot_examples:li fewshot_examples_str = "\n\n# 下面给出不同客户回复下调用不同工具的例子。" fewshot_examples_str += f"\n\n{cls.format_fewshot_examples(fewshot_examples)}" fewshot_examples_str += "\n\n请参考上述例子进行工具调用。" - + non_ask_tool_list = [] # for tool in tools: # should_ask_parameter = get_tool_by_name(tool['name']).should_ask_parameter @@ -310,44 +305,46 @@ def create_system_prompt(cls,goods_info:str,tools:list[dict],fewshot_examples:li non_ask_rules = "\n - " + ','.join(non_ask_tool_list) if create_time: - datetime_object = datetime.strptime(create_time, '%Y-%m-%d %H:%M:%S.%f') + datetime_object = datetime.strptime( + create_time, '%Y-%m-%d %H:%M:%S.%f') else: datetime_object = datetime.now() - logger.info(f"create_time: {create_time} is not valid, use current time instead.") + logger.info( + f"create_time: {create_time} is not valid, use current time instead.") return cls.SYSTEM_PROMPT.format( - goods_info=goods_info, - tools=tool_system, - fewshot_examples=fewshot_examples_str, - non_ask_rules=non_ask_rules, - date_prompt=datetime_object.strftime(cls.DATE_PROMPT) - ) + goods_info=goods_info, + tools=tool_system, + fewshot_examples=fewshot_examples_str, + non_ask_rules=non_ask_rules, + date_prompt=datetime_object.strftime(cls.DATE_PROMPT) + ) @classmethod - def create_chat_history(cls,x,system_prompt=None): + def create_chat_history(cls, x, system_prompt=None): # deal with function _chat_history = x['chat_history'] + \ - [{"role": MessageType.HUMAN_MESSAGE_TYPE,"content": x['query']}] + \ + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] + \ x['agent_tool_history'] - + # print(f'chat_history_before create: {_chat_history}') # merge chat_history chat_history = [] if system_prompt is not None: chat_history.append({ "role": MessageType.SYSTEM_MESSAGE_TYPE, - "content":system_prompt + "content": system_prompt }) - + # move tool call results to assistant - for i,message in 
enumerate(copy.deepcopy(_chat_history)): + for i, message in enumerate(copy.deepcopy(_chat_history)): role = message['role'] - if i==0: + if i == 0: assert role == MessageType.HUMAN_MESSAGE_TYPE, f"The first message should comes from human role" - + if role == MessageType.TOOL_MESSAGE_TYPE: - assert chat_history[-1]['role'] == MessageType.AI_MESSAGE_TYPE,_chat_history + assert chat_history[-1]['role'] == MessageType.AI_MESSAGE_TYPE, _chat_history chat_history[-1]['content'] += message['content'] - continue + continue elif role == MessageType.AI_MESSAGE_TYPE: # continue ai message if chat_history[-1]['role'] == MessageType.AI_MESSAGE_TYPE: @@ -355,71 +352,74 @@ def create_chat_history(cls,x,system_prompt=None): continue chat_history.append(message) - - # move the last tool call message to user + + # move the last tool call message to user if chat_history[-1]['role'] == MessageType.AI_MESSAGE_TYPE: - assert chat_history[-2]['role'] == MessageType.HUMAN_MESSAGE_TYPE,chat_history - tool_calls = chat_history[-1].get("additional_kwargs",{}).get("tool_calls",[]) + assert chat_history[-2]['role'] == MessageType.HUMAN_MESSAGE_TYPE, chat_history + tool_calls = chat_history[-1].get("additional_kwargs", + {}).get("tool_calls", []) if tool_calls: - chat_history[-2]['content'] += ("\n\n" + chat_history[-1]['content']) + chat_history[-2]['content'] += ("\n\n" + + chat_history[-1]['content']) chat_history = chat_history[:-1] return chat_history - @classmethod - def parse_function_calls_from_ai_message(cls,message:dict): + def parse_function_calls_from_ai_message(cls, message: dict): stop_reason = message['stop_reason'] - content = f"{cls.prefill}" + message['text'] + content = f"{cls.prefill}" + message['text'] content = content.strip() stop_reason = stop_reason or "" - - function_calls = re.findall(f"{cls.FN_NAME}.*?{cls.FN_RESULT}", content + stop_reason,re.S) + function_calls = re.findall( + f"{cls.FN_NAME}.*?{cls.FN_RESULT}", content + stop_reason, re.S) return { - "function_calls":function_calls, - "content":content + "function_calls": function_calls, + "content": content } - + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): - tools:list = kwargs.get('tools',[]) + tools: list = kwargs.get('tools', []) # add extral tools if "give_rhetorical_question" not in tools: - tools.append(get_tool_by_name("give_rhetorical_question",scene=SceneType.RETAIL).tool_def) - fewshot_examples = kwargs.get('fewshot_examples',[]) + tools.append(get_tool_by_name( + "give_rhetorical_question", scene=SceneType.RETAIL).tool_def) + fewshot_examples = kwargs.get('fewshot_examples', []) system_prompt = cls.create_system_prompt( - goods_info=kwargs['goods_info'], - create_time=kwargs.get('create_time',None), + goods_info=kwargs['goods_info'], + create_time=kwargs.get('create_time', None), tools=tools, fewshot_examples=fewshot_examples - ) + ) agent_current_call_number = kwargs['agent_current_call_number'] - + # give different prefill if agent_current_call_number == 0: cls.prefill = cls.prefill_after_thinking else: cls.prefill = cls.prefill_after_second_thinking - + # cls.prefill = '' model_kwargs = model_kwargs or {} kwargs['system_prompt'] = system_prompt model_kwargs = {**model_kwargs} # model_kwargs["stop"] = model_kwargs.get("stop",[]) + ['✿RESULT✿', '✿RESULT✿:', '✿RESULT✿:\n','✿RETURN✿',f'<{cls.thinking_tag}>',f'<{cls.thinking_tag}/>'] - model_kwargs["stop"] = model_kwargs.get("stop",[]) + ['✿RESULT✿', '✿RESULT✿:', '✿RESULT✿:\n','✿RETURN✿',f'<{cls.thinking_tag}/>'] + model_kwargs["stop"] = model_kwargs.get( 
+ "stop", []) + ['✿RESULT✿', '✿RESULT✿:', '✿RESULT✿:\n', '✿RETURN✿', f'<{cls.thinking_tag}/>'] # model_kwargs["prefill"] = "我先看看调用哪个工具,下面是我的思考过程:\n\nstep 1." if "prefill" not in model_kwargs: - model_kwargs["prefill"] = f'{cls.prefill}' - return super().create_chain(model_kwargs=model_kwargs,**kwargs) - + model_kwargs["prefill"] = f'{cls.prefill}' + return super().create_chain(model_kwargs=model_kwargs, **kwargs) + class Qwen2Instruct7BRetailToolCallingChain(Qwen2Instruct72BRetailToolCallingChain): model_id = LLMModelType.QWEN2INSTRUCT7B goods_info_tag = "商品信息" - SYSTEM_PROMPT=f"""你是安踏天猫的客服助理小安, 主要职责是处理用户售前和售后的问题。{{date_prompt}} + SYSTEM_PROMPT = f"""你是安踏天猫的客服助理小安, 主要职责是处理用户售前和售后的问题。{{date_prompt}} {{tools}} {{fewshot_examples}} @@ -443,13 +443,10 @@ class Qwen2Instruct7BRetailToolCallingChain(Qwen2Instruct72BRetailToolCallingCha @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs["prefill"] = "" - res = super().create_chain(model_kwargs=model_kwargs,**kwargs) + res = super().create_chain(model_kwargs=model_kwargs, **kwargs) cls.prefill = "" return res + class Qwen15Instruct32BRetailToolCallingChain(Qwen2Instruct7BRetailToolCallingChain): model_id = LLMModelType.QWEN15INSTRUCT32B - - - - diff --git a/source/lambda/online/common_logic/langchain_integration/chains/stepback_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/stepback_chain.py index 4fb49a410..64653fcff 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/stepback_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/stepback_chain.py @@ -14,6 +14,7 @@ STEPBACK_PROMPTING_TYPE = LLMTaskType.STEPBACK_PROMPTING_TYPE + class Internlm2Chat7BStepBackChain(Internlm2Chat7BChatChain): model_id = LLMModelType.INTERNLM2_CHAT_7B intent_type = STEPBACK_PROMPTING_TYPE @@ -108,7 +109,8 @@ def create_chain(cls, model_kwargs=None, **kwargs): ] ) - llm = Model.get_model(cls.model_id, model_kwargs=model_kwargs, **kwargs) + llm = Model.get_model( + cls.model_id, model_kwargs=model_kwargs, **kwargs) chain = prompt | llm if stream: chain = ( diff --git a/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_api.py b/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_api.py index 62d982a33..7d0fbc3a1 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_api.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_api.py @@ -1,15 +1,15 @@ # tool calling chain import json -from typing import List,Dict,Any +from typing import List, Dict, Any from collections import defaultdict from common_logic.common_utils.prompt_utils import get_prompt_template -from langchain_core.messages import( +from langchain_core.messages import ( AIMessage, SystemMessage -) +) from langchain.prompts import ChatPromptTemplate -from langchain_core.messages import AIMessage,SystemMessage +from langchain_core.messages import AIMessage, SystemMessage from langchain.tools.base import BaseTool from langchain_core.language_models import BaseChatModel @@ -34,86 +34,88 @@ class Claude2ToolCallingChain(LLMChain): } @classmethod - def create_chat_history(cls,x): + def create_chat_history(cls, x): chat_history = x['chat_history'] + \ - [{"role": MessageType.HUMAN_MESSAGE_TYPE,"content": x['query']}] + \ + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] + \ x['agent_tool_history'] return chat_history @classmethod - def 
get_common_system_prompt(cls,system_prompt_template:str,all_knowledge_retrieved_list=None): + def get_common_system_prompt(cls, system_prompt_template: str, all_knowledge_retrieved_list=None): all_knowledge_retrieved_list = all_knowledge_retrieved_list or [] all_knowledge_retrieved = "\n\n".join(all_knowledge_retrieved_list) now = get_china_now() date_str = now.strftime("%Y年%m月%d日") weekdays = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'] weekday = weekdays[now.weekday()] - system_prompt = system_prompt_template.format(date=date_str,weekday=weekday,context=all_knowledge_retrieved) + system_prompt = system_prompt_template.format( + date=date_str, weekday=weekday, context=all_knowledge_retrieved) return system_prompt - @classmethod - def bind_tools(cls,llm:BaseChatModel,tools:List[BaseTool], fewshot_examples=None, fewshot_template=None,tool_choice='any'): + def bind_tools(cls, llm: BaseChatModel, tools: List[BaseTool], fewshot_examples=None, fewshot_template=None, tool_choice='any'): tools = [tool.model_copy() for tool in tools] if not fewshot_examples: - if getattr(llm,"enable_auto_tool_choice",True): - return llm.bind_tools(tools,tool_choice=tool_choice) + if getattr(llm, "enable_auto_tool_choice", True): + return llm.bind_tools(tools, tool_choice=tool_choice) return llm.bind_tools(tools) # add fewshot examples to tool description - tools_map = {tool.name:tool for tool in tools} + tools_map = {tool.name: tool for tool in tools} # group fewshot examples fewshot_examples_grouped = defaultdict(list) for example in fewshot_examples: fewshot_examples_grouped[example['name']].append(example) - for tool_name,examples in fewshot_examples_grouped.items(): + for tool_name, examples in fewshot_examples_grouped.items(): tool = tools_map[tool_name] tool.description += "\n\nHere are some examples where this tool are called:\n" examples_strs = [] for example in examples: - params_str = json.dumps(example['kwargs'],ensure_ascii=False) + params_str = json.dumps(example['kwargs'], ensure_ascii=False) examples_strs.append( fewshot_template.format( query=example['query'], args=params_str ) ) - + tool.description += "\n\n".join(examples_strs) - - if getattr(llm,"enable_auto_tool_choice",True): - return llm.bind_tools(tools,tool_choice=tool_choice) + + if getattr(llm, "enable_auto_tool_choice", True): + return llm.bind_tools(tools, tool_choice=tool_choice) return llm.bind_tools(tools) - - + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} - tools:list = kwargs['tools'] - assert all(isinstance(tool,BaseTool) for tool in tools),tools - fewshot_examples = kwargs.get('fewshot_examples',[]) + tools: list = kwargs['tools'] + assert all(isinstance(tool, BaseTool) for tool in tools), tools + fewshot_examples = kwargs.get('fewshot_examples', []) agent_system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="agent_system_prompt" + prompt_name="agent_system_prompt" ).prompt_template - agent_system_prompt = kwargs.get("agent_system_prompt",None) or agent_system_prompt - - all_knowledge_retrieved_list = kwargs.get('all_knowledge_retrieved_list',[]) + agent_system_prompt = kwargs.get( + "agent_system_prompt", None) or agent_system_prompt + + all_knowledge_retrieved_list = kwargs.get( + 'all_knowledge_retrieved_list', []) agent_system_prompt = cls.get_common_system_prompt( - agent_system_prompt,all_knowledge_retrieved_list + agent_system_prompt, all_knowledge_retrieved_list ) - - # tool fewshot prompt + + # tool fewshot prompt 
tool_fewshot_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, prompt_name="tool_fewshot_prompt" ).prompt_template - tool_fewshot_prompt = kwargs.get('tool_fewshot_prompt',None) or tool_fewshot_prompt + tool_fewshot_prompt = kwargs.get( + 'tool_fewshot_prompt', None) or tool_fewshot_prompt model_kwargs = {**cls.default_model_kwargs, **model_kwargs} @@ -122,18 +124,19 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs=model_kwargs, ) - llm = cls.bind_tools(llm,tools,fewshot_examples,fewshot_template=tool_fewshot_prompt) - + llm = cls.bind_tools(llm, tools, fewshot_examples, + fewshot_template=tool_fewshot_prompt) + tool_calling_template = ChatPromptTemplate.from_messages( [ - SystemMessage(content=agent_system_prompt), - ("placeholder", "{chat_history}"), - ("human", "{query}"), - ("placeholder", "{agent_tool_history}"), + SystemMessage(content=agent_system_prompt), + ("placeholder", "{chat_history}"), + ("human", "{query}"), + ("placeholder", "{agent_tool_history}"), ] ) - chain = tool_calling_template | llm + chain = tool_calling_template | llm return chain diff --git a/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_claude_xml.py b/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_claude_xml.py index 114139f84..49a187f3b 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_claude_xml.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/tool_calling_chain_claude_xml.py @@ -1,6 +1,6 @@ # tool calling chain import json -from typing import List,Dict,Any +from typing import List, Dict, Any import re from langchain.schema.runnable import ( @@ -9,13 +9,13 @@ ) from common_logic.common_utils.prompt_utils import get_prompt_template from common_logic.common_utils.logger_utils import print_llm_messages -from langchain_core.messages import( +from langchain_core.messages import ( AIMessage, SystemMessage -) +) from langchain.prompts import ChatPromptTemplate -from langchain_core.messages import AIMessage,SystemMessage +from langchain_core.messages import AIMessage, SystemMessage from common_logic.common_utils.constant import ( LLMTaskType, @@ -47,27 +47,27 @@ """ -SYSTEM_MESSAGE_PROMPT =(f"In this environment you have access to a set of tools you can use to answer the user's question.\n" - "\n" - "You may call them like this:\n" - "\n" - "\n" - "$TOOL_NAME\n" - "\n" - "<$PARAMETER_NAME>$PARAMETER_VALUE\n" - "...\n" - "\n" - "\n" - "\n" - "\n" - "Here are the tools available:\n" - "\n" - "{tools}" - "\n" - "\nAnswer the user's request using relevant tools (if they are available). Before calling a tool, do some analysis within tags. First, think about which of the provided tools is the relevant tool to answer the user's request. Second, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool call. BUT, if one of the values for a required parameter is missing, DO NOT invoke the function (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters. DO NOT ask for more information on optional parameters if it is not provided." 
- "\nHere are some guidelines for you:\n{tool_call_guidelines}." - f"\n{incorrect_tool_call_example}" - ) +SYSTEM_MESSAGE_PROMPT = (f"In this environment you have access to a set of tools you can use to answer the user's question.\n" + "\n" + "You may call them like this:\n" + "\n" + "\n" + "$TOOL_NAME\n" + "\n" + "<$PARAMETER_NAME>$PARAMETER_VALUE\n" + "...\n" + "\n" + "\n" + "\n" + "\n" + "Here are the tools available:\n" + "\n" + "{tools}" + "\n" + "\nAnswer the user's request using relevant tools (if they are available). Before calling a tool, do some analysis within tags. First, think about which of the provided tools is the relevant tool to answer the user's request. Second, go through each of the required parameters of the relevant tool and determine if the user has directly provided or given enough information to infer a value. When deciding if the parameter can be inferred, carefully consider all the context to see if it supports a specific value. If all of the required parameters are present or can be reasonably inferred, close the thinking tag and proceed with the tool call. BUT, if one of the values for a required parameter is missing, DO NOT invoke the function (not even with fillers for the missing params) and instead, ask the user to provide the missing parameters. DO NOT ask for more information on optional parameters if it is not provided." + "\nHere are some guidelines for you:\n{tool_call_guidelines}." + f"\n{incorrect_tool_call_example}" + ) SYSTEM_MESSAGE_PROMPT_WITH_FEWSHOT_EXAMPLES = SYSTEM_MESSAGE_PROMPT + ( "Some examples of tool calls are given below, where the content within represents the most recent reply in the dialog." @@ -123,7 +123,7 @@ def _get_type(parameter: Dict[str, Any]) -> str: return json.dumps(parameter) -def convert_openai_tool_to_anthropic(tools:list[dict])->str: +def convert_openai_tool_to_anthropic(tools: list[dict]) -> str: formatted_tools = tools tools_data = [ { @@ -173,15 +173,15 @@ class Claude2ToolCallingChain(LLMChain): "max_tokens": 2000, "temperature": 0.1, "top_p": 0.9, - "stop_sequences": ["\n\nHuman:", "\n\nAssistant",""], + "stop_sequences": ["\n\nHuman:", "\n\nAssistant", ""], } @staticmethod - def format_fewshot_examples(fewshot_examples:list[dict]): + def format_fewshot_examples(fewshot_examples: list[dict]): fewshot_example_strs = [] for fewshot_example in fewshot_examples: param_strs = [] - for p,v in fewshot_example['kwargs'].items(): + for p, v in fewshot_example['kwargs'].items(): param_strs.append(f"<{p}>{v}\n{fewshot_example_str}\n" - + @classmethod - def parse_function_calls_from_ai_message(cls,message:AIMessage): + def parse_function_calls_from_ai_message(cls, message: AIMessage): content = "" + message.content + "" - function_calls:List[str] = re.findall("(.*?)", content,re.S) + function_calls: List[str] = re.findall( + "(.*?)", content, re.S) if not function_calls: - content = "" + message.content + content = "" + message.content return { - "function_calls": function_calls, - "content": content - } + "function_calls": function_calls, + "content": content + } @classmethod - def create_chat_history(cls,x): + def create_chat_history(cls, x): chat_history = x['chat_history'] + \ - [{"role": MessageType.HUMAN_MESSAGE_TYPE,"content": x['query']}] + \ + [{"role": MessageType.HUMAN_MESSAGE_TYPE, "content": x['query']}] + \ x['agent_tool_history'] return chat_history @classmethod - def get_common_system_prompt(cls,system_prompt_template:str): + def get_common_system_prompt(cls, system_prompt_template: str): now = 
get_china_now() date_str = now.strftime("%Y年%m月%d日") weekdays = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'] weekday = weekdays[now.weekday()] - system_prompt = system_prompt_template.format(date=date_str,weekday=weekday) + system_prompt = system_prompt_template.format( + date=date_str, weekday=weekday) return system_prompt - - + @classmethod def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} - tools:list = kwargs['tools'] - fewshot_examples = kwargs.get('fewshot_examples',[]) + tools: list = kwargs['tools'] + fewshot_examples = kwargs.get('fewshot_examples', []) if fewshot_examples: fewshot_examples.append({ "name": "give_rhetorical_question", @@ -249,10 +250,11 @@ def create_chain(cls, model_kwargs=None, **kwargs): user_system_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="user_prompt" + prompt_name="user_prompt" ).prompt_template - user_system_prompt = kwargs.get("user_prompt",None) or user_system_prompt + user_system_prompt = kwargs.get( + "user_prompt", None) or user_system_prompt user_system_prompt = cls.get_common_system_prompt( user_system_prompt @@ -260,10 +262,11 @@ def create_chain(cls, model_kwargs=None, **kwargs): guidelines_prompt = get_prompt_template( model_id=cls.model_id, task_type=cls.intent_type, - prompt_name="guidelines_prompt" + prompt_name="guidelines_prompt" ).prompt_template - guidelines_prompt = kwargs.get("guidelines_prompt",None) or guidelines_prompt + guidelines_prompt = kwargs.get( + "guidelines_prompt", None) or guidelines_prompt model_kwargs = {**cls.default_model_kwargs, **model_kwargs} tools_formatted = convert_openai_tool_to_anthropic(tools) @@ -279,22 +282,22 @@ def create_chain(cls, model_kwargs=None, **kwargs): tools=tools_formatted, tool_call_guidelines=guidelines_prompt ) - - system_prompt = user_system_prompt + system_prompt + + system_prompt = user_system_prompt + system_prompt tool_calling_template = ChatPromptTemplate.from_messages( [ - SystemMessage(content=system_prompt), - ("placeholder", "{chat_history}"), - AIMessage(content="") - ]) + SystemMessage(content=system_prompt), + ("placeholder", "{chat_history}"), + AIMessage(content="") + ]) llm = Model.get_model( model_id=cls.model_id, model_kwargs=model_kwargs, ) chain = RunnablePassthrough.assign(chat_history=lambda x: cls.create_chat_history(x)) | tool_calling_template \ - | RunnableLambda(lambda x: print_llm_messages(f"Agent messages: {x.messages}") or x.messages ) \ - | llm | RunnableLambda(lambda message:cls.parse_function_calls_from_ai_message( + | RunnableLambda(lambda x: print_llm_messages(f"Agent messages: {x.messages}") or x.messages) \ + | llm | RunnableLambda(lambda message: cls.parse_function_calls_from_ai_message( message )) return chain diff --git a/source/lambda/online/common_logic/langchain_integration/chains/translate_chain.py b/source/lambda/online/common_logic/langchain_integration/chains/translate_chain.py index 638128b75..2be6a37d8 100644 --- a/source/lambda/online/common_logic/langchain_integration/chains/translate_chain.py +++ b/source/lambda/online/common_logic/langchain_integration/chains/translate_chain.py @@ -32,7 +32,8 @@ def create_chain(cls, model_kwargs=None, **kwargs): model_kwargs = model_kwargs or {} model_kwargs = {**cls.default_model_kwargs, **model_kwargs} llm_chain = super().create_chain(model_kwargs=model_kwargs, **kwargs) - llm_chain = llm_chain | RunnableLambda(lambda x: x.strip('"')) # postprocess + llm_chain = llm_chain | RunnableLambda( + lambda x: x.strip('"')) 
# postprocess return llm_chain diff --git a/source/lambda/online/common_logic/langchain_integration/chat_models/bedrock_models.py b/source/lambda/online/common_logic/langchain_integration/chat_models/bedrock_models.py index 8282392f0..9a8c6a8de 100644 --- a/source/lambda/online/common_logic/langchain_integration/chat_models/bedrock_models.py +++ b/source/lambda/online/common_logic/langchain_integration/chat_models/bedrock_models.py @@ -4,7 +4,7 @@ MessageType, LLMModelType ) -from common_logic.common_utils.logger_utils import get_logger,llm_messages_print_decorator +from common_logic.common_utils.logger_utils import get_logger, llm_messages_print_decorator from . import Model logger = get_logger("bedrock_model") @@ -18,7 +18,8 @@ class ChatBedrockConverse(_ChatBedrockConverse): # Bedrock model type class Claude2(Model): model_id = LLMModelType.CLAUDE_2 - default_model_kwargs = {"max_tokens": 2000, "temperature": 0.7, "top_p": 0.9} + default_model_kwargs = {"max_tokens": 2000, + "temperature": 0.7, "top_p": 0.9} enable_auto_tool_choice = False @classmethod @@ -44,7 +45,8 @@ def create_model(cls, model_kwargs=None, **kwargs): enable_prefill=cls.enable_prefill, **model_kwargs, ) - llm.client.converse_stream = llm_messages_print_decorator(llm.client.converse_stream) + llm.client.converse_stream = llm_messages_print_decorator( + llm.client.converse_stream) llm.client.converse = llm_messages_print_decorator(llm.client.converse) return llm @@ -84,19 +86,17 @@ class MistralLarge2407(Claude2): class Llama3d1Instruct70B(Claude2): model_id = LLMModelType.LLAMA3_1_70B_INSTRUCT - enable_auto_tool_choice = False + enable_auto_tool_choice = False enable_prefill = False + class Llama3d2Instruct90B(Claude2): model_id = LLMModelType.LLAMA3_2_90B_INSTRUCT - enable_auto_tool_choice = False + enable_auto_tool_choice = False enable_prefill = False class CohereCommandRPlus(Claude2): model_id = LLMModelType.COHERE_COMMAND_R_PLUS - enable_auto_tool_choice = False + enable_auto_tool_choice = False enable_prefill = False - - - diff --git a/source/lambda/online/common_logic/langchain_integration/chat_models/openai_models.py b/source/lambda/online/common_logic/langchain_integration/chat_models/openai_models.py index fdddeb454..7afe025be 100644 --- a/source/lambda/online/common_logic/langchain_integration/chat_models/openai_models.py +++ b/source/lambda/online/common_logic/langchain_integration/chat_models/openai_models.py @@ -5,9 +5,11 @@ logger = get_logger("openai_model") + class ChatGPT35(Model): model_id = LLMModelType.CHATGPT_35_TURBO_0125 - default_model_kwargs = {"max_tokens": 2000, "temperature": 0.7, "top_p": 0.9} + default_model_kwargs = {"max_tokens": 2000, + "temperature": 0.7, "top_p": 0.9} @classmethod def create_model(cls, model_kwargs=None, **kwargs): @@ -25,4 +27,4 @@ class ChatGPT4Turbo(ChatGPT35): class ChatGPT4o(ChatGPT35): - model_id = LLMModelType.CHATGPT_4O \ No newline at end of file + model_id = LLMModelType.CHATGPT_4O diff --git a/source/lambda/online/common_logic/langchain_integration/retrievers/retriever.py b/source/lambda/online/common_logic/langchain_integration/retrievers/retriever.py index d1c9884c8..f117b7413 100644 --- a/source/lambda/online/common_logic/langchain_integration/retrievers/retriever.py +++ b/source/lambda/online/common_logic/langchain_integration/retrievers/retriever.py @@ -1,33 +1,32 @@ -import json -import os -os.environ["PYTHONUNBUFFERED"] = "1" -import logging -import sys - -import boto3 -from common_logic.common_utils.chatbot_utils import ChatbotManager -from 
common_logic.langchain_integration.retrievers.utils.aos_retrievers import ( - QueryDocumentBM25Retriever, - QueryDocumentKNNRetriever, - QueryQuestionRetriever, +from langchain.schema.runnable import RunnableLambda, RunnablePassthrough +from langchain.retrievers.merger_retriever import MergerRetriever +from langchain_community.retrievers import AmazonKnowledgeBasesRetriever +from langchain.retrievers import ( + ContextualCompressionRetriever, ) -from common_logic.langchain_integration.retrievers.utils.context_utils import ( - retriever_results_format, +from common_logic.langchain_integration.retrievers.utils.websearch_retrievers import ( + GoogleRetriever, ) from common_logic.langchain_integration.retrievers.utils.reranker import ( BGEReranker, MergeReranker, ) -from common_logic.langchain_integration.retrievers.utils.websearch_retrievers import ( - GoogleRetriever, +from common_logic.langchain_integration.retrievers.utils.context_utils import ( + retriever_results_format, ) -from langchain.retrievers import ( - ContextualCompressionRetriever, +from common_logic.langchain_integration.retrievers.utils.aos_retrievers import ( + QueryDocumentBM25Retriever, + QueryDocumentKNNRetriever, + QueryQuestionRetriever, ) -from langchain_community.retrievers import AmazonKnowledgeBasesRetriever -from langchain.retrievers.merger_retriever import MergerRetriever -from langchain.schema.runnable import RunnableLambda, RunnablePassthrough -from langchain_community.retrievers import AmazonKnowledgeBasesRetriever +from common_logic.common_utils.chatbot_utils import ChatbotManager +import boto3 +import sys +import logging +import json +import os +os.environ["PYTHONUNBUFFERED"] = "1" + logger = logging.getLogger("retriever") logger.setLevel(logging.INFO) @@ -56,7 +55,8 @@ def get_bedrock_kb_retrievers(knowledge_base_id_list, top_k: int): retriever_list = [ AmazonKnowledgeBasesRetriever( knowledge_base_id=knowledge_base_id, - retrieval_config={"vectorSearchConfiguration": {"numberOfResults": top_k}}, + retrieval_config={"vectorSearchConfiguration": { + "numberOfResults": top_k}}, ) for knowledge_base_id in knowledge_base_id_list ] diff --git a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/aos_retrievers.py b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/aos_retrievers.py index 5fb9ff4d5..1517038ca 100644 --- a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/aos_retrievers.py +++ b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/aos_retrievers.py @@ -21,7 +21,8 @@ # region = os.environ["AWS_REGION"] kb_enabled = os.environ["KNOWLEDGE_BASE_ENABLED"].lower() == "true" kb_type = json.loads(os.environ["KNOWLEDGE_BASE_TYPE"]) -intelli_agent_kb_enabled = kb_type.get("intelliAgentKb", {}).get("enabled", False) +intelli_agent_kb_enabled = kb_type.get( + "intelliAgentKb", {}).get("enabled", False) aos_endpoint = os.environ.get("AOS_ENDPOINT", "") aos_domain_name = os.environ.get("AOS_DOMAIN_NAME", "smartsearch") aos_secret = os.environ.get("AOS_SECRET_NAME", "opensearch-master-user") @@ -72,7 +73,8 @@ def get_similarity_embedding( model_type: str = "vector", ) -> List[List[float]]: if model_type.lower() == "bedrock": - embeddings = BedrockEmbeddings(model_id=embedding_model_endpoint, region_name=bedrock_region) + embeddings = BedrockEmbeddings( + model_id=embedding_model_endpoint, region_name=bedrock_region) response = embeddings.embed_query(query) else: query_similarity_embedding_prompt = query @@ -96,7 +98,8 @@ def 
get_relevance_embedding( model_type: str = "vector", ): if model_type == "bedrock": - embeddings = BedrockEmbeddings(model_id=embedding_model_endpoint, region_name=bedrock_region) + embeddings = BedrockEmbeddings( + model_id=embedding_model_endpoint, region_name=bedrock_region) response = embeddings.embed_query(query) else: if model_type == "vector": @@ -343,7 +346,8 @@ def get_context(aos_hit, index_name, window_size): next_chunk_id = aos_hit["_source"]["metadata"]["heading_hierarchy"]["next"] next_pos = 0 while ( - next_chunk_id and next_chunk_id.startswith("$") and next_pos < window_size + next_chunk_id and next_chunk_id.startswith( + "$") and next_pos < window_size ): opensearch_query_response = aos_client.search( index_name=index_name, @@ -581,9 +585,11 @@ def organize_results( if doc: result["doc"] = doc else: - response_list = asyncio.run(self.__spawn_task(aos_hits, context_size)) + response_list = asyncio.run( + self.__spawn_task(aos_hits, context_size)) for context, result in zip(response_list, results): - result["doc"] = "\n".join(context[0] + [result["content"]] + context[1]) + result["doc"] = "\n".join( + context[0] + [result["content"]] + context[1]) return results @timeit @@ -742,9 +748,11 @@ def organize_results( if doc: result["doc"] = doc else: - response_list = asyncio.run(self.__spawn_task(aos_hits, context_size)) + response_list = asyncio.run( + self.__spawn_task(aos_hits, context_size)) for context, result in zip(response_list, results): - result["doc"] = "\n".join(context[0] + [result["doc"]] + context[1]) + result["doc"] = "\n".join( + context[0] + [result["doc"]] + context[1]) # context = get_context(aos_hit['_source']["metadata"]["heading_hierarchy"]["previous"], # aos_hit['_source']["metadata"]["heading_hierarchy"]["next"], # aos_index, diff --git a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/context_utils.py b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/context_utils.py index cada844c0..f2d330cb0 100644 --- a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/context_utils.py +++ b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/context_utils.py @@ -24,7 +24,8 @@ def contexts_trunc(docs: list[dict], context_num=2): context_strs.append(content) s.add(content) context_docs.append( - {"doc": content, "source": doc["source"], "score": doc["score"]} + {"doc": content, + "source": doc["source"], "score": doc["score"]} ) context_sources.append(doc["source"]) return { diff --git a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/reranker.py b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/reranker.py index 7405d59d5..4ac5cb094 100644 --- a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/reranker.py +++ b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/reranker.py @@ -1,3 +1,8 @@ +from sm_utils import SagemakerEndpointVectorOrCross +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor +from langchain.schema import Document +from langchain.callbacks.manager import Callbacks +from typing import Dict, Optional, Sequence, Any import json import os import time @@ -7,20 +12,16 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) -from typing import Dict, Optional, Sequence, Any - -from langchain.callbacks.manager import Callbacks -from langchain.schema import Document -from langchain.retrievers.document_compressors.base import 
BaseDocumentCompressor - -from sm_utils import SagemakerEndpointVectorOrCross rerank_model_endpoint = os.environ.get("RERANK_ENDPOINT", "") """Document compressor that uses BGE reranker model.""" + + class BGEM3Reranker(BaseDocumentCompressor): """Number of documents to return.""" + def _colbert_score_np(self, q_reps, p_reps): token_scores = np.einsum('nik,njk->nij', q_reps, p_reps) scores = token_scores.max(-1) @@ -74,8 +75,10 @@ def compress_documents( query_colbert_list.append(query["colbert"][:rerank_text_length]) doc_colbert_list.append(doc[:rerank_text_length]) score_list = [] - logger.info(f'rerank pair num {len(query_colbert_list)}, m3 method: colbert score') - score_list = asyncio.run(self.__spawn_task(query_colbert_list, doc_colbert_list)) + logger.info( + f'rerank pair num {len(query_colbert_list)}, m3 method: colbert score') + score_list = asyncio.run(self.__spawn_task( + query_colbert_list, doc_colbert_list)) final_results = [] debug_info = query["debug_info"] debug_info["knowledge_qa_rerank"] = [] @@ -84,25 +87,31 @@ def compress_documents( # set common score for llm. doc.metadata["score"] = doc.metadata["rerank_score"] final_results.append(doc) - debug_info["knowledge_qa_rerank"].append((doc.page_content, doc.metadata["retrieval_content"], doc.metadata["source"], score)) - final_results.sort(key=lambda x: x.metadata["rerank_score"], reverse=True) - debug_info["knowledge_qa_rerank"].sort(key=lambda x: x[-1], reverse=True) + debug_info["knowledge_qa_rerank"].append( + (doc.page_content, doc.metadata["retrieval_content"], doc.metadata["source"], score)) + final_results.sort( + key=lambda x: x.metadata["rerank_score"], reverse=True) + debug_info["knowledge_qa_rerank"].sort( + key=lambda x: x[-1], reverse=True) recall_end_time = time.time() elpase_time = recall_end_time - start logger.info(f"runing time of rerank: {elpase_time}s seconds") return final_results + """Document compressor that uses BGE reranker model.""" + + class BGEReranker(BaseDocumentCompressor): """Number of documents to return.""" - config: Dict={"run_name": "BGEReranker"} + config: Dict = {"run_name": "BGEReranker"} enable_debug: Any target_model: Any - rerank_model_endpoint: str=rerank_model_endpoint - top_k: int=10 + rerank_model_endpoint: str = rerank_model_endpoint + top_k: int = 10 - def __init__(self,enable_debug=False, rerank_model_endpoint=rerank_model_endpoint, target_model=None, top_k=10): + def __init__(self, enable_debug=False, rerank_model_endpoint=rerank_model_endpoint, target_model=None, top_k=10): super().__init__() self.enable_debug = enable_debug self.rerank_model_endpoint = rerank_model_endpoint @@ -125,7 +134,8 @@ async def __spawn_task(self, rerank_pair): task_list = [] loop = asyncio.get_event_loop() for batch_start in range(0, len(rerank_pair), batch_size): - task = asyncio.create_task(self.__ainvoke_rerank_model(rerank_pair[batch_start:batch_start + batch_size], loop)) + task = asyncio.create_task(self.__ainvoke_rerank_model( + rerank_pair[batch_start:batch_start + batch_size], loop)) task_list.append(task) return await asyncio.gather(*task_list) @@ -157,7 +167,8 @@ def compress_documents( for doc in _docs: rerank_pair.append([query["query"], doc[:rerank_text_length]]) score_list = [] - logger.info(f'rerank pair num {len(rerank_pair)}, endpoint_name: {self.rerank_model_endpoint}') + logger.info( + f'rerank pair num {len(rerank_pair)}, endpoint_name: {self.rerank_model_endpoint}') response_list = asyncio.run(self.__spawn_task(rerank_pair)) for response in response_list: 
score_list.extend(json.loads(response)) @@ -171,15 +182,21 @@ def compress_documents( doc.metadata["score"] = doc.metadata["rerank_score"] final_results.append(doc) if self.enable_debug: - debug_info["knowledge_qa_rerank"].append((doc.page_content, doc.metadata["retrieval_content"], doc.metadata["source"], score)) - final_results.sort(key=lambda x: x.metadata["rerank_score"], reverse=True) - debug_info["knowledge_qa_rerank"].sort(key=lambda x: x[-1], reverse=True) + debug_info["knowledge_qa_rerank"].append( + (doc.page_content, doc.metadata["retrieval_content"], doc.metadata["source"], score)) + final_results.sort( + key=lambda x: x.metadata["rerank_score"], reverse=True) + debug_info["knowledge_qa_rerank"].sort( + key=lambda x: x[-1], reverse=True) recall_end_time = time.time() elpase_time = recall_end_time - start logger.info(f"runing time of rerank: {elpase_time}s seconds") return final_results[:self.top_k] + """Document compressor that uses retriever score.""" + + class MergeReranker(BaseDocumentCompressor): """Number of documents to return.""" @@ -214,4 +231,4 @@ def compress_documents( recall_end_time = time.time() elpase_time = recall_end_time - start logger.info(f"runing time of rerank: {elpase_time}s seconds") - return final_results \ No newline at end of file + return final_results diff --git a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/test.py b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/test.py index 2c7daa753..3e6ac3ea3 100644 --- a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/test.py +++ b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/test.py @@ -1,26 +1,26 @@ -import json -import os - -os.environ["PYTHONUNBUFFERED"] = "1" -import logging -import sys - -import boto3 -from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper +from langchain_community.retrievers import AmazonKnowledgeBasesRetriever +from langchain.schema.runnable import RunnableLambda, RunnablePassthrough +from langchain.retrievers.merger_retriever import MergerRetriever +from langchain.retrievers import ( + AmazonKnowledgeBasesRetriever, + ContextualCompressionRetriever, +) +from lambda_retriever.utils.reranker import MergeReranker +from lambda_retriever.utils.context_utils import retriever_results_format from lambda_retriever.utils.aos_retrievers import ( QueryDocumentBM25Retriever, QueryDocumentKNNRetriever, QueryQuestionRetriever, ) -from lambda_retriever.utils.context_utils import retriever_results_format -from lambda_retriever.utils.reranker import MergeReranker -from langchain.retrievers import ( - AmazonKnowledgeBasesRetriever, - ContextualCompressionRetriever, -) -from langchain.retrievers.merger_retriever import MergerRetriever -from langchain.schema.runnable import RunnableLambda, RunnablePassthrough -from langchain_community.retrievers import AmazonKnowledgeBasesRetriever +from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper +import boto3 +import sys +import logging +import json +import os + +os.environ["PYTHONUNBUFFERED"] = "1" + logger = logging.getLogger("retriever") logger.setLevel(logging.INFO) @@ -38,7 +38,8 @@ def get_bedrock_kb_retrievers(knowledge_base_id_list, top_k: int): retriever_list = [ AmazonKnowledgeBasesRetriever( knowledge_base_id=knowledge_base_id, - retrieval_config={"vectorSearchConfiguration": {"numberOfResults": top_k}}, + retrieval_config={"vectorSearchConfiguration": { + "numberOfResults": top_k}}, ) 
for knowledge_base_id in knowledge_base_id_list ] @@ -110,7 +111,8 @@ def lambda_handler(event, context=None): retriever_type = event_body["type"] for retriever_config in event_body["retrievers"]: # retriever_type = retriever_config["type"] - retriever_list.extend(get_custom_retrievers(retriever_config, retriever_type)) + retriever_list.extend(get_custom_retrievers( + retriever_config, retriever_type)) # Re-rank not used. # rerankers = event_body.get("rerankers", None) diff --git a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/websearch_retrievers.py b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/websearch_retrievers.py index babdeb9b3..b58b27f2a 100644 --- a/source/lambda/online/common_logic/langchain_integration/retrievers/utils/websearch_retrievers.py +++ b/source/lambda/online/common_logic/langchain_integration/retrievers/utils/websearch_retrievers.py @@ -1,3 +1,8 @@ +from langchain.agents import Tool +from langchain.schema.retriever import BaseRetriever +from langchain.docstore.document import Document +from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain_community.utilities import GoogleSearchAPIWrapper import asyncio import aiohttp import time @@ -9,23 +14,19 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) -from langchain_community.utilities import GoogleSearchAPIWrapper -from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.docstore.document import Document -from langchain.schema.retriever import BaseRetriever -from langchain.agents import Tool -GOOGLE_API_KEY=os.environ.get('GOOGLE_API_KEY',None) -GOOGLE_CSE_ID=os.environ.get('GOOGLE_CSE_ID',None) +GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY', None) +GOOGLE_CSE_ID = os.environ.get('GOOGLE_CSE_ID', None) class GoogleSearchTool(): - tool:Tool - topk:int = 5 - - def __init__(self,top_k=5): + tool: Tool + topk: int = 5 + + def __init__(self, top_k=5): self.topk = top_k search = GoogleSearchAPIWrapper() + def top_results(query): return search.results(query, self.topk) self.tool = Tool( @@ -33,18 +34,20 @@ def top_results(query): description="Search Google for recent results.", func=top_results, ) - - def run(self,query): + + def run(self, query): return self.tool.run(query) + def remove_html_tags(text): soup = BeautifulSoup(text, 'html.parser') text = soup.get_text() - text = re.sub(r'\r{1,}',"\n\n",text) - text = re.sub(r'\t{1,}',"\t",text) - text = re.sub(r'\n{2,}',"\n\n",text) + text = re.sub(r'\r{1,}', "\n\n", text) + text = re.sub(r'\t{1,}', "\t", text) + text = re.sub(r'\n{2,}', "\n\n", text) return text + async def fetch(session, url, timeout): try: async with session.get(url) as response: @@ -56,7 +59,7 @@ async def fetch(session, url, timeout): print(f"ClientError:{url}", str(e)) return '' - + async def fetch_all(urls, timeout): async with aiohttp.ClientSession() as session: tasks = [] @@ -66,7 +69,8 @@ async def fetch_all(urls, timeout): results = await asyncio.gather(*tasks) return results - + + def web_search(**args): if not GOOGLE_API_KEY or not GOOGLE_CSE_ID: logger.info('Missing google API key') @@ -75,13 +79,13 @@ def web_search(**args): result = tool.run(args['query']) return [item for item in result if 'title' in item and 'link' in item and 'snippet' in item] - + def add_webpage_content(snippet_results): t1 = time.time() urls = [item['doc_author'] for item in snippet_results] loop = asyncio.get_event_loop() - fetch_results = loop.run_until_complete(fetch_all(urls,5)) - t2= 
time.time() + fetch_results = loop.run_until_complete(fetch_all(urls, 5)) + t2 = time.time() logger.info(f'deep web search time:{t2-t1:1f}s') final_results = [] for i, result in enumerate(fetch_results): @@ -89,10 +93,11 @@ def add_webpage_content(snippet_results): continue page_content = remove_html_tags(result) final_results.append({**snippet_results[i], - 'doc':snippet_results[i]['doc']+'\n'+page_content[:10000] + 'doc': snippet_results[i]['doc']+'\n'+page_content[:10000] }) return final_results + class GoogleRetriever(BaseRetriever): search: Any result_num: Any @@ -121,4 +126,4 @@ def _get_relevant_documents( return doc_list def get_whole_doc(self, results) -> Dict: - return add_webpage_content(self._get_relevant_documents(results)) \ No newline at end of file + return add_webpage_content(self._get_relevant_documents(results)) diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/chat.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/chat.py index c007c3534..93f053b1e 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/chat.py +++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/chat.py @@ -1,5 +1,4 @@ # give chat response -def chat(response:str): +def chat(response: str): return response - \ No newline at end of file diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/comparison_rag.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/comparison_rag.py index 3bf573967..f9d94dc6d 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/comparison_rag.py +++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/comparison_rag.py @@ -4,6 +4,7 @@ LLMTaskType ) + def knowledge_base_retrieve(retriever_params, context=None): output: str = invoke_lambda( event_body=retriever_params, @@ -14,6 +15,7 @@ def knowledge_base_retrieve(retriever_params, context=None): contexts = [doc["page_content"] for doc in output["result"]["docs"]] return contexts + def lambda_handler(event_body, context=None): state = event_body['state'] retriever_params = state["chatbot_config"]["comparison_rag_config"]["retriever_config"] @@ -26,26 +28,26 @@ def lambda_handler(event_body, context=None): # llm generate system_prompt = (f"请根据context内的信息回答问题:\n" - "\n" - " - 回复内容需要展现出礼貌。回答内容为一句话,言简意赅。\n" - " - 使用中文回答。\n" - "\n" - "\n" - f"{context}\n" - "" - ) - - output:str = invoke_lambda( + "\n" + " - 回复内容需要展现出礼貌。回答内容为一句话,言简意赅。\n" + " - 使用中文回答。\n" + "\n" + "\n" + f"{context}\n" + "" + ) + + output: str = invoke_lambda( lambda_name='Online_LLM_Generate', lambda_module_path="lambda_llm_generate.llm_generate", handler_name='lambda_handler', event_body={ "llm_config": { - **state['chatbot_config']['rag_daily_reception_config']['llm_config'], + **state['chatbot_config']['rag_daily_reception_config']['llm_config'], "system_prompt": system_prompt, "intent_type": LLMTaskType.CHAT}, "llm_input": {"query": state['query'], "chat_history": state['chat_history']} - } - ) + } + ) - return {"code":0, "result":output} \ No newline at end of file + return {"code": 0, "result": output} diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/get_weather.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/get_weather.py index ccecb204c..f1bd72c41 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/get_weather.py 
+++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/get_weather.py @@ -1,7 +1,8 @@ # get weather tool import requests -def get_weather(city_name:str): + +def get_weather(city_name: str): if not isinstance(city_name, str): raise TypeError("City name must be a string") @@ -14,21 +15,22 @@ def get_weather(city_name:str): "observation_time", ], } - + try: resp = requests.get(f"https://wttr.in/{city_name}?format=j1") resp.raise_for_status() resp = resp.json() - ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} + ret = {k: {_v: resp[k][0][_v] for _v in v} + for k, v in key_selection.items()} except: import traceback ret = ("Error encountered while fetching weather data!\n" + traceback.format_exc() - ) + ) return str(ret) -def lambda_handler(event_body,context=None): +def lambda_handler(event_body, context=None): result = get_weather(**event_body['kwargs']) - return {"code":0, "result": result} \ No newline at end of file + return {"code": 0, "result": result} diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_final_response.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_final_response.py index 82146d9b0..eb3f9bbc2 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_final_response.py +++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_final_response.py @@ -1,4 +1,4 @@ # give final response tool -def give_final_response(response:str): - return response \ No newline at end of file +def give_final_response(response: str): + return response diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_rhetorical_question.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_rhetorical_question.py index ac78268af..085713e73 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_rhetorical_question.py +++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/give_rhetorical_question.py @@ -1,4 +1,4 @@ # give rhetorical question -def give_rhetorical_question(question:str): - return question \ No newline at end of file +def give_rhetorical_question(question: str): + return question diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/rag.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/rag.py index 1d35e7103..d10af2a21 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/rag.py +++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/rag.py @@ -1,4 +1,4 @@ -from common_logic.common_utils.lambda_invoke_utils import invoke_lambda,StateContext +from common_logic.common_utils.lambda_invoke_utils import invoke_lambda, StateContext from common_logic.common_utils.prompt_utils import get_prompt_templates_from_ddb from common_logic.common_utils.constant import ( LLMTaskType @@ -6,9 +6,10 @@ from common_logic.common_utils.lambda_invoke_utils import send_trace from common_logic.langchain_integration.retrievers.retriever import lambda_handler as retrieve_fn from common_logic.langchain_integration.chains import LLMChain -import threading +import threading -def rag_tool(retriever_config:dict,query=None): + +def rag_tool(retriever_config: dict, query=None): state = StateContext.get_current_state() # state = event_body['state'] context_list = [] @@ -16,20 +17,22 @@ def 
rag_tool(retriever_config:dict,query=None): context_list.extend(state['qq_match_results']) figure_list = [] retriever_params = retriever_config - retriever_params["query"] = query or state[retriever_config.get("query_key","query")] + retriever_params["query"] = query or state[retriever_config.get( + "query_key", "query")] output = retrieve_fn(retriever_params) for doc in output["result"]["docs"]: context_list.append(doc["page_content"]) - figure_list = figure_list + doc.get("figure",[]) - + figure_list = figure_list + doc.get("figure", []) + # Remove duplicate figures unique_set = {tuple(d.items()) for d in figure_list} unique_figure_list = [dict(t) for t in unique_set] state['extra_response']['figures'] = unique_figure_list - - send_trace(f"\n\n**rag-contexts:**\n\n {context_list}", enable_trace=state["enable_trace"]) - + + send_trace( + f"\n\n**rag-contexts:**\n\n {context_list}", enable_trace=state["enable_trace"]) + group_name = state['chatbot_config']['group_name'] llm_config = state["chatbot_config"]["private_knowledge_config"]['llm_config'] chatbot_id = state["chatbot_config"]["chatbot_id"] @@ -42,21 +45,20 @@ def rag_tool(retriever_config:dict,query=None): ) llm_config = { - **prompt_templates_from_ddb, - **llm_config, - "stream": state["stream"], - "intent_type": task_type, - } - + **prompt_templates_from_ddb, + **llm_config, + "stream": state["stream"], + "intent_type": task_type, + } + llm_input = { - "contexts": context_list, - "query": state["query"], - "chat_history": state["chat_history"] - } + "contexts": context_list, + "query": state["query"], + "chat_history": state["chat_history"] + } chain = LLMChain.get_chain( **llm_config ) output = chain.invoke(llm_input) - return output,output - + return output, output diff --git a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/step_back_rag.py b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/step_back_rag.py index cbc09a57d..b20440a17 100644 --- a/source/lambda/online/common_logic/langchain_integration/tools/common_tools/step_back_rag.py +++ b/source/lambda/online/common_logic/langchain_integration/tools/common_tools/step_back_rag.py @@ -4,6 +4,7 @@ LLMTaskType ) + def knowledge_base_retrieve(retriever_params, context=None): output: str = invoke_lambda( event_body=retriever_params, @@ -14,6 +15,7 @@ def knowledge_base_retrieve(retriever_params, context=None): contexts = [doc["page_content"] for doc in output["result"]["docs"]] return contexts + def lambda_handler(event_body, context=None): state = event_body['state'] retriever_params = state["chatbot_config"]["step_back_rag_config"]["retriever_config"] @@ -24,27 +26,27 @@ def lambda_handler(event_body, context=None): # llm generate system_prompt = (f"请根据context内的信息回答问题:\n" - "\n" - " - 回复内容需要展现出礼貌。回答内容为一句话,言简意赅。\n" - " - 每次回答总是先进行思考,并将思考过程写在标签中。\n" - " - 使用中文回答。\n" - "\n" - "\n" - f"{context}\n" - "" - ) - - output:str = invoke_lambda( + "\n" + " - 回复内容需要展现出礼貌。回答内容为一句话,言简意赅。\n" + " - 每次回答总是先进行思考,并将思考过程写在标签中。\n" + " - 使用中文回答。\n" + "\n" + "\n" + f"{context}\n" + "" + ) + + output: str = invoke_lambda( lambda_name='Online_LLM_Generate', lambda_module_path="lambda_llm_generate.llm_generate", handler_name='lambda_handler', event_body={ "llm_config": { - **state['chatbot_config']['rag_daily_reception_config']['llm_config'], + **state['chatbot_config']['rag_daily_reception_config']['llm_config'], "system_prompt": system_prompt, "intent_type": LLMTaskType.CHAT}, "llm_input": {"query": state['query'], "chat_history": 
state['chat_history']} - } - ) + } + ) - return {"code":0, "result":output} \ No newline at end of file + return {"code": 0, "result": output} diff --git a/source/lambda/online/lambda_agent/agent.py b/source/lambda/online/lambda_agent/agent.py index 495e2587c..2866841ef 100644 --- a/source/lambda/online/lambda_agent/agent.py +++ b/source/lambda/online/lambda_agent/agent.py @@ -2,23 +2,24 @@ RunnableLambda ) from common_logic.common_utils.prompt_utils import get_prompt_templates_from_ddb -from common_logic.common_utils.logger_utils import get_logger -from common_logic.common_utils.lambda_invoke_utils import invoke_lambda,chatbot_lambda_call_wrapper +from common_logic.common_utils.logger_utils import get_logger +from common_logic.common_utils.lambda_invoke_utils import invoke_lambda, chatbot_lambda_call_wrapper from common_logic.common_utils.constant import LLMTaskType from functions import get_tool_by_name logger = get_logger("agent") -def tool_calling(state:dict): + +def tool_calling(state: dict): agent_config = state["chatbot_config"]['agent_config'] tools = state['intent_fewshot_tools'] + agent_config['tools'] tool_defs = [get_tool_by_name( tool_name, - scene=state["chatbot_config"]['scene']).tool_def + scene=state["chatbot_config"]['scene']).tool_def for tool_name in tools ] - - other_chain_kwargs = state.get('other_chain_kwargs',{}) + + other_chain_kwargs = state.get('other_chain_kwargs', {}) llm_config = { **agent_config['llm_config'], **other_chain_kwargs, @@ -26,16 +27,16 @@ def tool_calling(state:dict): "fewshot_examples": state['intent_fewshot_examples'], } - agent_llm_type = state.get("agent_llm_type",None) or LLMTaskType.TOOL_CALLING_XML - + agent_llm_type = state.get( + "agent_llm_type", None) or LLMTaskType.TOOL_CALLING_XML + group_name = state['chatbot_config']['group_name'] chatbot_id = state['chatbot_config']['chatbot_id'] - # add prompt template from ddb prompt_templates_from_ddb = get_prompt_templates_from_ddb( group_name, - model_id = llm_config['model_id'], + model_id=llm_config['model_id'], task_type=agent_llm_type, chatbot_id=chatbot_id ) @@ -47,21 +48,21 @@ def tool_calling(state:dict): event_body={ "llm_config": { **prompt_templates_from_ddb, - **llm_config, + **llm_config, "intent_type": agent_llm_type }, "llm_input": state - } - ) + } + ) return { "agent_output": output, "current_agent_tools_def": tool_defs, "current_agent_model_id": agent_config['llm_config']['model_id'] - } + } @chatbot_lambda_call_wrapper -def lambda_handler(state:dict, context=None): +def lambda_handler(state: dict, context=None): output = tool_calling(state) - return output \ No newline at end of file + return output diff --git a/source/lambda/online/lambda_intention_detection/intention.py b/source/lambda/online/lambda_intention_detection/intention.py index 3499ea293..cc213e536 100644 --- a/source/lambda/online/lambda_intention_detection/intention.py +++ b/source/lambda/online/lambda_intention_detection/intention.py @@ -2,14 +2,15 @@ import pathlib import os -from common_logic.common_utils.logger_utils import get_logger -from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper,invoke_lambda +from common_logic.common_utils.logger_utils import get_logger +from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper, invoke_lambda from common_logic.langchain_integration.retrievers.retriever import lambda_handler as retrieve_fn logger = get_logger("intention") kb_enabled = os.environ["KNOWLEDGE_BASE_ENABLED"].lower() == "true" kb_type = 
json.loads(os.environ["KNOWLEDGE_BASE_TYPE"]) -intelli_agent_kb_enabled = kb_type.get("intelliAgentKb", {}).get("enabled", False) +intelli_agent_kb_enabled = kb_type.get( + "intelliAgentKb", {}).get("enabled", False) def get_intention_results(query: str, intention_config: dict): @@ -65,7 +66,7 @@ def get_intention_results(query: str, intention_config: dict): # "name": answer.get("intent","chat"), # "intent": answer.get("intent","chat"), # "kwargs": answer.get("kwargs", {}), - # }) + # }) else: intent_fewshot_examples = [] for doc in res["result"]["docs"]: @@ -82,21 +83,21 @@ def get_intention_results(query: str, intention_config: dict): "kwargs": doc.get("kwargs", {}), } intent_fewshot_examples.append(doc_item) - + return intent_fewshot_examples, True @chatbot_lambda_call_wrapper -def lambda_handler(state:dict, context=None): - intention_config = state["chatbot_config"].get("intention_config",{}) - query_key = intention_config.get("retriever_config",{}).get("query_key","query") +def lambda_handler(state: dict, context=None): + intention_config = state["chatbot_config"].get("intention_config", {}) + query_key = intention_config.get( + "retriever_config", {}).get("query_key", "query") query = state[query_key] - output:list = get_intention_results( - query, - { - **intention_config, - } - ) + output: list = get_intention_results( + query, + { + **intention_config, + } + ) return output - diff --git a/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_aos_utils.py b/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_aos_utils.py index 1ff7a3663..254909a5b 100644 --- a/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_aos_utils.py +++ b/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_aos_utils.py @@ -39,7 +39,8 @@ opensearch_client_lock = threading.Lock() abs_file_dir = os.path.dirname(__file__) -intent_example_path = os.path.join(abs_file_dir, "intent_examples", "examples.json") +intent_example_path = os.path.join( + abs_file_dir, "intent_examples", "examples.json") class LangchainOpenSearchClient: @@ -120,10 +121,12 @@ def create_index_name( return index_name def check_index_exist(self): - if_exist = self.opensearch_client.client.indices.exists(self.index_name) + if_exist = self.opensearch_client.client.indices.exists( + self.index_name) count = 0 if if_exist: - count = self.opensearch_client.client.count(index=self.index_name)["count"] + count = self.opensearch_client.client.count( + index=self.index_name)["count"] if_exist = count > 0 logger.info(f"{self.index_name} exist: {if_exist}, count: {count}") return if_exist @@ -133,7 +136,8 @@ def ingestion_intent_data(self): intent_examples = json.load(open(intent_example_path))["examples"] for intent_name, examples in intent_examples.items(): for example in examples: - doc = Document(page_content=example, metadata={"intent": intent_name}) + doc = Document(page_content=example, metadata={ + "intent": intent_name}) docs.append(doc) logger.info( f"ingestion intent doc, num: {len(docs)}, index_name: {self.index_name}" @@ -158,7 +162,7 @@ def search(self, query, top_k=5): logger.info(f"intent index search results:\n{ret}") return ret - + @staticmethod def intent_postprocess_top_1(retriever_list: list[dict]): retriever_list = sorted(retriever_list, key=lambda x: x["score"]) diff --git a/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_utils.py 
b/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_utils.py index 6796b6c53..4e4002f7b 100644 --- a/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_utils.py +++ b/source/lambda/online/lambda_intention_detection/intention_utils/intent_utils/intent_utils.py @@ -129,7 +129,8 @@ def get_strict_intent(x): ) chain = intent_type_chain | RunnableBranch( - (lambda x: x["intent_type"] == IntentType.KNOWLEDGE_QA.value, sub_intent_chain), + (lambda x: x["intent_type"] == + IntentType.KNOWLEDGE_QA.value, sub_intent_chain), RunnablePassthrough(), ) diff --git a/source/lambda/online/lambda_llm_generate/llm_generate.py b/source/lambda/online/lambda_llm_generate/llm_generate.py index 0ce1506cb..1a22e5e8a 100644 --- a/source/lambda/online/lambda_llm_generate/llm_generate.py +++ b/source/lambda/online/lambda_llm_generate/llm_generate.py @@ -1,9 +1,10 @@ -from common_logic.common_utils.logger_utils import get_logger +from common_logic.common_utils.logger_utils import get_logger from common_logic.langchain_integration.chains import LLMChain from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper logger = get_logger("llm_generate") + @chatbot_lambda_call_wrapper def lambda_handler(event_body, context=None): llm_chain_config = event_body['llm_config'] @@ -15,4 +16,3 @@ def lambda_handler(event_body, context=None): output = chain.invoke(llm_chain_inputs) return output - diff --git a/source/lambda/online/lambda_main/main.py b/source/lambda/online/lambda_main/main.py index c04963dd2..29eb42913 100644 --- a/source/lambda/online/lambda_main/main.py +++ b/source/lambda/online/lambda_main/main.py @@ -369,8 +369,10 @@ def lambda_handler(event_body: dict, context: dict): return default_event_handler(event_body, context, entry_executor) except Exception as e: error_response = {"answer": str(e), "extra_response": {}} - enable_trace = event_body.get("chatbot_config", {}).get("enable_trace", True) + enable_trace = event_body.get( + "chatbot_config", {}).get("enable_trace", True) error_trace = f"\n### Error trace\n\n{traceback.format_exc()}\n\n" + load_ws_client(websocket_url) send_trace(error_trace, enable_trace=enable_trace) process_response(event_body, error_response) logger.error(f"{traceback.format_exc()}\nAn error occurred: {str(e)}") diff --git a/source/lambda/online/lambda_main/main_utils/online_entries/agent_base.py b/source/lambda/online/lambda_main/main_utils/online_entries/agent_base.py index e3d203863..1e998390a 100644 --- a/source/lambda/online/lambda_main/main_utils/online_entries/agent_base.py +++ b/source/lambda/online/lambda_main/main_utils/online_entries/agent_base.py @@ -1,6 +1,6 @@ import json -from langgraph.graph import StateGraph,END -from common_logic.common_utils.lambda_invoke_utils import invoke_lambda,node_monitor_wrapper +from langgraph.graph import StateGraph, END +from common_logic.common_utils.lambda_invoke_utils import invoke_lambda, node_monitor_wrapper from functions.tool_calling_parse import parse_tool_calling as _parse_tool_calling from common_logic.common_utils.lambda_invoke_utils import send_trace @@ -15,13 +15,14 @@ logger = get_logger("agent_base") + @node_monitor_wrapper def tools_choose_and_results_generation(state): # check once tool calling - agent_current_output:dict = invoke_lambda( + agent_current_output: dict = invoke_lambda( event_body={ **state - }, + }, lambda_name="Online_Agent", lambda_module_path="lambda_agent.agent", handler_name="lambda_handler" @@ -64,25 +65,27 @@ def 
results_evaluation(state): ) tool_calls = output['tool_calls'] md_tool_result = format_agent_result_output(tool_calls) - send_trace(f"\n\n**tool_calls parsed:** \n\n {md_tool_result}", state["stream"], state["ws_connection_id"], state["enable_trace"]) + send_trace(f"\n\n**tool_calls parsed:** \n\n {md_tool_result}", + state["stream"], state["ws_connection_id"], state["enable_trace"]) if not state["extra_response"].get("current_agent_intent_type", None): state["extra_response"]["current_agent_intent_type"] = output['tool_calls'][0]["name"] - + return { "function_calling_parse_ok": True, "function_calling_parsed_tool_calls": tool_calls, "agent_tool_history": [output['agent_message']] } - + except (ToolNotExistError, - ToolParameterNotExistError, - MultipleToolNameError, - ToolNotFound - ) as e: - send_trace(f"\n\n**tool_calls parse failed:** \n{str(e)}", state["stream"], state["ws_connection_id"], state["enable_trace"]) + ToolParameterNotExistError, + MultipleToolNameError, + ToolNotFound + ) as e: + send_trace(f"\n\n**tool_calls parse failed:** \n{str(e)}", + state["stream"], state["ws_connection_id"], state["enable_trace"]) return { "function_calling_parse_ok": False, - "agent_tool_history":[ + "agent_tool_history": [ e.agent_message, e.error_message ] @@ -106,11 +109,11 @@ def tool_execution(state): tool_kwargs = tool_call['kwargs'] # call tool output = invoke_lambda( - event_body = { - "tool_name":tool_name, - "state":state, - "kwargs":tool_kwargs - }, + event_body={ + "tool_name": tool_name, + "state": state, + "kwargs": tool_kwargs + }, lambda_name="Online_Tool_Execute", lambda_module_path="functions.lambda_tool", handler_name="lambda_handler" @@ -125,26 +128,29 @@ def tool_execution(state): output = format_tool_call_results( tool_call['model_id'], tool_call_results) - send_trace(f'**tool_execute_res:** \n{output["tool_message"]["content"]}', enable_trace=state["enable_trace"]) + send_trace( + f'**tool_execute_res:** \n{output["tool_message"]["content"]}', enable_trace=state["enable_trace"]) return { "agent_tool_history": [output['tool_message']] - } + } def build_agent_graph(chatbot_state_cls): def _results_evaluation_route(state: dict): - #TODO: pass no need tool calling or valid tool calling? + # TODO: pass no need tool calling or valid tool calling? 
if state["agent_repeated_call_validation"] and not state["function_calling_parse_ok"]: return "invalid tool calling" return "continue" workflow = StateGraph(chatbot_state_cls) - workflow.add_node("tools_choose_and_results_generation", tools_choose_and_results_generation) + workflow.add_node("tools_choose_and_results_generation", + tools_choose_and_results_generation) workflow.add_node("results_evaluation", results_evaluation) # add all edges workflow.set_entry_point("tools_choose_and_results_generation") - workflow.add_edge("tools_choose_and_results_generation","results_evaluation") + workflow.add_edge("tools_choose_and_results_generation", + "results_evaluation") # add conditional edges # the results of agent planning will be evaluated and decide next step: @@ -159,4 +165,4 @@ def _results_evaluation_route(state: dict): } ) app = workflow.compile() - return app \ No newline at end of file + return app diff --git a/source/lambda/online/lambda_main/main_utils/online_entries/common_entry.py b/source/lambda/online/lambda_main/main_utils/online_entries/common_entry.py index 4f95607dc..4f060d3fa 100644 --- a/source/lambda/online/lambda_main/main_utils/online_entries/common_entry.py +++ b/source/lambda/online/lambda_main/main_utils/online_entries/common_entry.py @@ -1,8 +1,8 @@ import traceback -import json -import uuid +import json +import uuid import re -from typing import Annotated, Any, TypedDict, List,Union +from typing import Annotated, Any, TypedDict, List, Union from common_logic.common_utils.chatbot_utils import ChatbotManager from common_logic.common_utils.constant import ( @@ -17,7 +17,7 @@ node_monitor_wrapper, send_trace, ) -from langchain_core.messages import ToolMessage,AIMessage +from langchain_core.messages import ToolMessage, AIMessage from common_logic.common_utils.logger_utils import get_logger from common_logic.common_utils.prompt_utils import get_prompt_templates_from_ddb from common_logic.common_utils.python_utils import add_messages, update_nest_dict @@ -25,7 +25,7 @@ from common_logic.langchain_integration.tools import ToolManager from langchain_core.tools import BaseTool from langchain_core.messages.tool import ToolCall -from langgraph.prebuilt.tool_node import ToolNode,TOOL_CALL_ERROR_TEMPLATE +from langgraph.prebuilt.tool_node import ToolNode, TOOL_CALL_ERROR_TEMPLATE from common_logic.langchain_integration.chains import LLMChain from common_logic.common_utils.serialization_utils import JSONEncoder from common_logic.common_utils.monitor_utils import format_intention_output, format_preprocess_output, format_qq_data @@ -47,7 +47,6 @@ logger = get_logger("common_entry") - class ChatbotState(TypedDict): ########### input/output states ########### # inputs @@ -103,7 +102,8 @@ class ChatbotState(TypedDict): # current output of agent # agent_current_output: dict # # record messages during agent tool choose and calling, including agent message, tool ouput and error messages - agent_tool_history: Annotated[List[Union[AIMessage,ToolMessage]], add_messages] + agent_tool_history: Annotated[List[Union[AIMessage, + ToolMessage]], add_messages] # # the maximum number that agent node can be called # agent_repeated_call_limit: int # # the current call time of agent @@ -139,8 +139,8 @@ def query_preprocess(state: ChatbotState): # handler_name="lambda_handler", # ) - - query_rewrite_llm_type = state.get("query_rewrite_llm_type",None) or LLMTaskType.CONVERSATION_SUMMARY_TYPE + query_rewrite_llm_type = state.get( + "query_rewrite_llm_type", None) or LLMTaskType.CONVERSATION_SUMMARY_TYPE 
output = conversation_query_rewrite( query=state['query'], chat_history=state['chat_history'], @@ -161,7 +161,7 @@ def intention_detection(state: ChatbotState): retriever_params["query"] = state[ retriever_params.get("retriever_config", {}).get("query_key", "query") ] - + output = retrieve_fn(retriever_params) context_list = [] qq_match_contexts = [] @@ -192,8 +192,9 @@ def intention_detection(state: ChatbotState): return {"qq_match_results": context_list, "intent_type": "intention detected"} # get intention results from aos - intention_config = state["chatbot_config"].get("intention_config",{}) - query_key = intention_config.get("retriever_config",{}).get("query_key","query") + intention_config = state["chatbot_config"].get("intention_config", {}) + query_key = intention_config.get( + "retriever_config", {}).get("query_key", "query") query = state[query_key] intent_fewshot_examples, intention_ready = get_intention_results( query, @@ -211,15 +212,15 @@ def intention_detection(state: ChatbotState): group_name = state["chatbot_config"]["group_name"] chatbot_id = state["chatbot_config"]["chatbot_id"] custom_qd_index = custom_index_desc(group_name, chatbot_id) - # TODO need to modify with new intent logic if not intention_ready and not custom_qd_index: - # if not intention_ready: + # if not intention_ready: # retrieve all knowledge retriever_params = state["chatbot_config"]["private_knowledge_config"] retriever_params["query"] = state[ - retriever_params.get("retriever_config", {}).get("query_key", "query") + retriever_params.get("retriever_config", {}).get( + "query_key", "query") ] threshold = Threshold.INTENTION_ALL_KNOWLEDGE_RETRIEVAL output = retrieve_fn(retriever_params) @@ -229,7 +230,8 @@ def intention_detection(state: ChatbotState): for doc in output["result"]["docs"]: if doc['score'] >= threshold: all_knowledge_retrieved_list.append(doc["page_content"]) - info_to_log.append(f"score: {doc['score']}, page_content: {doc['page_content'][:200]}") + info_to_log.append( + f"score: {doc['score']}, page_content: {doc['page_content'][:200]}") send_trace( f"all knowledge retrieved:\n{chr(10).join(info_to_log)}", @@ -247,7 +249,7 @@ def intention_detection(state: ChatbotState): state["ws_connection_id"], state["enable_trace"], ) - + # rename tool name intent_fewshot_tools = [tool_rename(i) for i in intent_fewshot_tools] intent_fewshot_examples = [ @@ -263,6 +265,7 @@ def intention_detection(state: ChatbotState): "intent_type": "intention detected" } + @node_monitor_wrapper def agent(state: ChatbotState): # two cases to invoke rag function @@ -272,7 +275,7 @@ def agent(state: ChatbotState): last_tool_messages = state["last_tool_messages"] if last_tool_messages and len(last_tool_messages) == 1: last_tool_message = last_tool_messages[0] - tool:BaseTool = ToolManager.get_tool( + tool: BaseTool = ToolManager.get_tool( scene=SceneType.COMMON, name=last_tool_message.name ) @@ -284,7 +287,7 @@ def agent(state: ChatbotState): content = last_tool_message.content return {"answer": content, "exit_tool_calling": True} - no_intention_condition = not state.get("intent_fewshot_examples",[]) + no_intention_condition = not state.get("intent_fewshot_examples", []) if ( # no_intention_condition, @@ -299,46 +302,47 @@ def agent(state: ChatbotState): "no_intention_condition, switch to rag tool", enable_trace=state["enable_trace"], ) - + all_knowledge_rag_tool = state['all_knowledge_rag_tool'] - agent_message = AIMessage(content="",tool_calls=[ + agent_message = AIMessage(content="", tool_calls=[ ToolCall( 
id=uuid.uuid4().hex, name=all_knowledge_rag_tool.name, - args={"query":state["query"]} + args={"query": state["query"]} ) ]) tools = [ ToolManager.get_tool( scene=SceneType.COMMON, name=all_knowledge_rag_tool.name - ) - ] - return {"agent_tool_history":[agent_message],"tools":tools} + ) + ] + return {"agent_tool_history": [agent_message], "tools": tools} # normal call agent_config = state["chatbot_config"]['agent_config'] - tools_name = list(set(state['intent_fewshot_tools'] + agent_config['tools'])) + tools_name = list( + set(state['intent_fewshot_tools'] + agent_config['tools'])) # get tools from tool names tools = [ ToolManager.get_tool( scene=SceneType.COMMON, name=name - ) + ) for name in tools_name ] llm_config = { **agent_config['llm_config'], "tools": tools, "fewshot_examples": state['intent_fewshot_examples'], - "all_knowledge_retrieved_list":state['all_knowledge_retrieved_list'] + "all_knowledge_retrieved_list": state['all_knowledge_retrieved_list'] } group_name = state['chatbot_config']['group_name'] chatbot_id = state['chatbot_config']['chatbot_id'] prompt_templates_from_ddb = get_prompt_templates_from_ddb( group_name, - model_id = llm_config['model_id'], + model_id=llm_config['model_id'], task_type=LLMTaskType.TOOL_CALLING_API, chatbot_id=chatbot_id ) @@ -349,11 +353,11 @@ def agent(state: ChatbotState): scene=SceneType.COMMON, **llm_config ) - - agent_message:AIMessage = tool_calling_chain.invoke({ - "query":state['query'], - "chat_history":state['chat_history'], - "agent_tool_history":state['agent_tool_history'] + + agent_message: AIMessage = tool_calling_chain.invoke({ + "query": state['query'], + "chat_history": state['chat_history'], + "agent_tool_history": state['agent_tool_history'] }) send_trace( @@ -364,7 +368,7 @@ def agent(state: ChatbotState): if not agent_message.tool_calls: return {"answer": agent_message.content, "exit_tool_calling": True} - return {"agent_tool_history":[agent_message],"tools":tools} + return {"agent_tool_history": [agent_message], "tools": tools} @node_monitor_wrapper @@ -379,17 +383,17 @@ def llm_direct_results_generation(state: ChatbotState): logger.info(prompt_templates_from_ddb) llm_config = { - **llm_config, - "stream": state["stream"], - "intent_type": task_type, - **prompt_templates_from_ddb, - } + **llm_config, + "stream": state["stream"], + "intent_type": task_type, + **prompt_templates_from_ddb, + } llm_input = { - "query": state["query"], - "chat_history": state["chat_history"], - } - + "query": state["query"], + "chat_history": state["chat_history"], + } + chain = LLMChain.get_chain( **llm_config ) @@ -407,7 +411,7 @@ def tool_execution(state): Returns: _type_: _description_ """ - tools:List[BaseTool] = state['tools'] + tools: List[BaseTool] = state['tools'] def handle_tool_errors(e): content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) @@ -418,25 +422,27 @@ def handle_tool_errors(e): tools, handle_tool_errors=handle_tool_errors ) - last_agent_message:AIMessage = state["agent_tool_history"][-1] + last_agent_message: AIMessage = state["agent_tool_history"][-1] tool_calls = last_agent_message.tool_calls - tool_messages:List[ToolMessage] = tool_node.invoke( - [AIMessage(content="",tool_calls=tool_calls)] + tool_messages: List[ToolMessage] = tool_node.invoke( + [AIMessage(content="", tool_calls=tool_calls)] ) - send_trace(f'**tool_execute_res:** \n{tool_messages}', enable_trace=state["enable_trace"]) + send_trace( + f'**tool_execute_res:** \n{tool_messages}', enable_trace=state["enable_trace"]) return { - "agent_tool_history": 
tool_messages, - "last_tool_messages": tool_messages - } + "agent_tool_history": tool_messages, + "last_tool_messages": tool_messages + } def final_results_preparation(state: ChatbotState): answer = state['answer'] - if isinstance(answer,str): - answer = re.sub(".*?","",answer,flags=re.S).strip() + if isinstance(answer, str): + answer = re.sub(".*?", + "", answer, flags=re.S).strip() state['answer'] = answer app_response = process_response(state["event_body"], state) return {"app_response": app_response} @@ -564,12 +570,12 @@ def build_graph(chatbot_state_cls): ##################################### app = None -def tool_rename(name:str) -> str: + +def tool_rename(name: str) -> str: """ rename the tool name """ - return name.replace("-","_") - + return name.replace("-", "_") def register_rag_tool_from_config(event_body: dict): @@ -582,15 +588,14 @@ def register_rag_tool_from_config(event_body: dict): for index_type, item_dict in chatbot.index_ids.items(): if index_type != IndexType.INTENTION and index_type != IndexType.QQ: for index_content in item_dict["value"].values(): - if "indexId" in index_content and "description" in index_content: # Find retriever contain index_id retrievers = event_body["chatbot_config"]["private_knowledge_config"]['retrievers'] - retriever = None + retriever = None for retriever in retrievers: if retriever["index_name"] == index_content["indexId"]: break - assert retriever is not None,retrievers + assert retriever is not None, retrievers rerankers = event_body["chatbot_config"]["private_knowledge_config"]['rerankers'] if rerankers: rerankers = [rerankers[0]] @@ -600,8 +605,8 @@ def register_rag_tool_from_config(event_body: dict): # TODO give specific retriever config ToolManager.register_common_rag_tool( retriever_config={ - "retrievers":[retriever], - "rerankers":rerankers, + "retrievers": [retriever], + "rerankers": rerankers, "llm_config": event_body["chatbot_config"]["private_knowledge_config"]['llm_config'] }, name=index_name, @@ -610,7 +615,8 @@ def register_rag_tool_from_config(event_body: dict): return_direct=True ) registered_tool_names.append(index_name) - logger.info(f"registered rag tool: {index_name}, description: {description}") + logger.info( + f"registered rag tool: {index_name}, description: {description}") return registered_tool_names @@ -618,7 +624,7 @@ def register_custom_lambda_tools_from_config(event_body): agent_config_tools = event_body['chatbot_config']['agent_config']['tools'] new_agent_config_tools = [] for tool in agent_config_tools: - if isinstance(tool,str): + if isinstance(tool, str): new_agent_config_tools.append(tool) elif isinstance(tool, dict): tool_name = tool['name'] @@ -627,18 +633,19 @@ def register_custom_lambda_tools_from_config(event_body): ToolManager.register_aws_lambda_as_tool( lambda_name=tool["lambda_name"], tool_def={ - "description":tool["description"], - "properties":tool['properties'], - "required":tool.get('required',[]) + "description": tool["description"], + "properties": tool['properties'], + "required": tool.get('required', []) }, name=tool_name, scene=SceneType.COMMON, - return_direct=tool.get("return_direct",False) + return_direct=tool.get("return_direct", False) ) new_agent_config_tools.append(tool_name) else: - raise ValueError(f"tool type {type(tool)}: {tool} is not supported") - + raise ValueError( + f"tool type {type(tool)}: {tool} is not supported") + event_body['chatbot_config']['agent_config']['tools'] = new_agent_config_tools return new_agent_config_tools @@ -673,7 +680,7 @@ def 
common_entry(event_body): ws_connection_id = event_body["ws_connection_id"] enable_trace = chatbot_config["enable_trace"] agent_config = event_body["chatbot_config"]["agent_config"] - + # register as rag tool for each aos index # print('private_knowledge_config',event_body["chatbot_config"]["private_knowledge_config"]) registered_tool_names = register_rag_tool_from_config(event_body) @@ -684,16 +691,17 @@ def common_entry(event_body): # register lambda tools register_custom_lambda_tools_from_config(event_body) - # - logger.info(f'event body to graph:\n{json.dumps(event_body,ensure_ascii=False,cls=JSONEncoder)}') + # + logger.info( + f'event body to graph:\n{json.dumps(event_body,ensure_ascii=False,cls=JSONEncoder)}') # define all knowledge rag tool all_knowledge_rag_tool = ToolManager.register_common_rag_tool( - retriever_config=event_body["chatbot_config"]["private_knowledge_config"], - name="all_knowledge_rag_tool", - scene=SceneType.COMMON, - description="all knowledge rag tool", - return_direct=True + retriever_config=event_body["chatbot_config"]["private_knowledge_config"], + name="all_knowledge_rag_tool", + scene=SceneType.COMMON, + description="all knowledge rag tool", + return_direct=True ) # invoke graph and get results @@ -712,9 +720,9 @@ def common_entry(event_body): "debug_infos": {}, "extra_response": {}, "qq_match_results": [], - "last_tool_messages":None, - "all_knowledge_rag_tool":all_knowledge_rag_tool, - "tools":None, + "last_tool_messages": None, + "all_knowledge_rag_tool": all_knowledge_rag_tool, + "tools": None, "ddb_additional_kwargs": {} }, config={"recursion_limit": 20} diff --git a/source/lambda/online/lambda_main/main_utils/online_entries/retail_entry.py b/source/lambda/online/lambda_main/main_utils/online_entries/retail_entry.py index cdd01e5c2..9bdb06705 100644 --- a/source/lambda/online/lambda_main/main_utils/online_entries/retail_entry.py +++ b/source/lambda/online/lambda_main/main_utils/online_entries/retail_entry.py @@ -2,13 +2,13 @@ import re import os import random -from datetime import datetime +from datetime import datetime from textwrap import dedent -from typing import TypedDict,Any,Annotated +from typing import TypedDict, Any, Annotated import validators -from langgraph.graph import StateGraph,END -from common_logic.common_utils.lambda_invoke_utils import invoke_lambda,node_monitor_wrapper -from common_logic.common_utils.python_utils import update_nest_dict,add_messages +from langgraph.graph import StateGraph, END +from common_logic.common_utils.lambda_invoke_utils import invoke_lambda, node_monitor_wrapper +from common_logic.common_utils.python_utils import update_nest_dict, add_messages from common_logic.common_utils.constant import ( LLMTaskType, ToolRuningMode, @@ -18,46 +18,49 @@ from functions.lambda_retail_tools.product_information_search import goods_dict from lambda_main.main_utils.parse_config import RetailConfigParser -from common_logic.common_utils.lambda_invoke_utils import send_trace,is_running_local +from common_logic.common_utils.lambda_invoke_utils import send_trace, is_running_local from common_logic.common_utils.logger_utils import get_logger from common_logic.common_utils.response_utils import process_response from common_logic.common_utils.serialization_utils import JSONEncoder -from common_logic.common_utils.s3_utils import download_file_from_s3,check_local_folder -from lambda_main.main_utils.online_entries.agent_base import build_agent_graph,tool_execution +from common_logic.common_utils.s3_utils import download_file_from_s3, 
check_local_folder +from lambda_main.main_utils.online_entries.agent_base import build_agent_graph, tool_execution from functions import get_tool_by_name -data_bucket_name = os.environ.get("RES_BUCKET", "aws-chatbot-knowledge-base-test") +data_bucket_name = os.environ.get( + "RES_BUCKET", "aws-chatbot-knowledge-base-test") order_info_path = "/tmp/functions/retail_tools/lambda_order_info/order_info.json" check_local_folder(order_info_path) -download_file_from_s3(data_bucket_name, "retail_json/order_info.json", order_info_path) +download_file_from_s3( + data_bucket_name, "retail_json/order_info.json", order_info_path) order_dict = json.load(open(order_info_path)) logger = get_logger('retail_entry') + class ChatbotState(TypedDict): ########### input/output states ########### # inputs # origin event body event_body: dict # origianl input question - query: str + query: str # chat history between human and agent - chat_history: Annotated[list[dict], add_messages] + chat_history: Annotated[list[dict], add_messages] # complete chatbot config, consumed by all the nodes - chatbot_config: dict - goods_id:Any + chatbot_config: dict + goods_id: Any # websocket connection id for the agent - ws_connection_id: str + ws_connection_id: str # whether to enbale stream output via ws_connection_id - stream: bool + stream: bool # message id related to original input question - message_id: str = None + message_id: str = None # record running states of different nodes trace_infos: Annotated[list[str], add_messages] # whether to enbale trace info update via streaming ouput - enable_trace: bool + enable_trace: bool # outputs # final answer generated by whole app graph - answer: Any + answer: Any # information needed return to user, e.g. intention, context, figure and so on, anything you can get during execution extra_response: Annotated[dict, update_nest_dict] # addition kwargs which need to save into ddb @@ -67,13 +70,13 @@ class ChatbotState(TypedDict): ########### query rewrite states ########### # query rewrite results - query_rewrite: str = None + query_rewrite: str = None ########### intention detection states ########### # intention type of retrieved intention samples in search engine, e.g. OpenSearch - intent_type: str = None + intent_type: str = None # retrieved intention samples in search engine, e.g. OpenSearch - intent_fewshot_examples: list + intent_fewshot_examples: list # tools of retrieved intention samples in search engine, e.g. 
OpenSearch intent_fewshot_tools: list @@ -82,16 +85,16 @@ class ChatbotState(TypedDict): qq_match_results: list = [] contexts: str = None figure: list = None - + ########### agent states ########### # current output of agent agent_current_output: dict # record messages during agent tool choose and calling, including agent message, tool ouput and error messages - agent_tool_history: Annotated[list[dict], add_messages] + agent_tool_history: Annotated[list[dict], add_messages] # the maximum number that agent node can be called - agent_repeated_call_limit: int + agent_repeated_call_limit: int # the current call time of agent - agent_current_call_number: int # + agent_current_call_number: int # whehter the current call time is less than maximum number of agent call agent_repeated_call_validation: bool # function calling @@ -109,12 +112,13 @@ class ChatbotState(TypedDict): agent_llm_type: str other_chain_kwargs: dict + # class ChatbotState(TypedDict): # chatbot_config: dict # chatbot config -# query: str -# create_time: str -# ws_connection_id: str -# stream: bool +# query: str +# create_time: str +# ws_connection_id: str +# stream: bool # query_rewrite: str = None # query rewrite ret # intent_type: str = None # intent # intent_fewshot_examples: list @@ -126,12 +130,12 @@ class ChatbotState(TypedDict): # current_tool_execute_res: dict # debug_infos: Annotated[dict,update_nest_dict] # answer: Any # final answer -# current_monitor_infos: str +# current_monitor_infos: str # extra_response: Annotated[dict,update_nest_dict] # contexts: str = None # intent_fewshot_tools: list # # current_agent_intent_type: str = None -# function_calling_parsed_tool_calls:list +# function_calling_parsed_tool_calls:list # # current_agent_tools_def: list[dict] # # current_agent_model_id: str # agent_current_output: dict @@ -150,10 +154,11 @@ class ChatbotState(TypedDict): # nodes in lambdas # #################### + @node_monitor_wrapper def query_preprocess(state: ChatbotState): - output:str = invoke_lambda( - event_body={**state,"chat_history":[]}, + output: str = invoke_lambda( + event_body={**state, "chat_history": []}, lambda_name="Online_Query_Preprocess", lambda_module_path="lambda_query_preprocess.query_preprocess", handler_name="lambda_handler" @@ -161,9 +166,10 @@ def query_preprocess(state: ChatbotState): state['extra_response']['query_rewrite'] = output send_trace(f"\n\n **query_rewrite:** \n{output}") return { - "query_rewrite":output, - "current_monitor_infos":f"query_rewrite: {output}" - } + "query_rewrite": output, + "current_monitor_infos": f"query_rewrite: {output}" + } + @node_monitor_wrapper def intention_detection(state: ChatbotState): @@ -171,53 +177,58 @@ def intention_detection(state: ChatbotState): lambda_module_path='lambda_intention_detection.intention', lambda_name="Online_Intention_Detection", handler_name="lambda_handler", - event_body=state + event_body=state ) state['extra_response']['intent_fewshot_examples'] = intent_fewshot_examples # send trace - send_trace(f"\n\nintention retrieved:\n{json.dumps(intent_fewshot_examples,ensure_ascii=False,indent=2)}", state["stream"], state["ws_connection_id"]) - intent_fewshot_tools:list[str] = list(set([e['intent'] for e in intent_fewshot_examples])) + send_trace(f"\n\nintention retrieved:\n{json.dumps(intent_fewshot_examples,ensure_ascii=False,indent=2)}", + state["stream"], state["ws_connection_id"]) + intent_fewshot_tools: list[str] = list( + set([e['intent'] for e in intent_fewshot_examples])) return { "intent_fewshot_examples": 
intent_fewshot_examples, "intent_fewshot_tools": intent_fewshot_tools, "intent_type": "other" - } + } + @node_monitor_wrapper def agent(state: ChatbotState): - goods_info = state.get('goods_info',None) or "" - agent_tool_history = state.get('agent_tool_history',"") - if agent_tool_history and hasattr(agent_tool_history[-1],'additional_kwargs'): - search_result = agent_tool_history[-1]['additional_kwargs']['original'][0].get('search_result',1) + goods_info = state.get('goods_info', None) or "" + agent_tool_history = state.get('agent_tool_history', "") + if agent_tool_history and hasattr(agent_tool_history[-1], 'additional_kwargs'): + search_result = agent_tool_history[-1]['additional_kwargs']['original'][0].get( + 'search_result', 1) if search_result == 0: - context = agent_tool_history[-1]['additional_kwargs']['original'][0].get('result',"") + context = agent_tool_history[-1]['additional_kwargs']['original'][0].get( + 'result', "") system_prompt = ("你是安踏的客服助理,正在帮消费者解答问题,消费者提出的问题大多是属于商品的质量和物流规则。context列举了一些可能有关的具体场景及回复,你可以进行参考:\n" - "\n" - f"{context}\n" - "" - "你需要按照下面的guidelines对消费者的问题进行回答:\n" - "\n" - " - 回答内容为一句话,言简意赅。\n" - " - 如果问题与context内容不相关,就不要采用。\n" - " - 消费者的问题里面可能包含口语化的表达,比如鞋子开胶的意思是用胶黏合的鞋体裂开。这和胶丝遗留没有关系。\n" - ' - 如果问题涉及到订单号,请回复: "请稍等,正在帮您查询订单。"' - "" - ) + "\n" + f"{context}\n" + "" + "你需要按照下面的guidelines对消费者的问题进行回答:\n" + "\n" + " - 回答内容为一句话,言简意赅。\n" + " - 如果问题与context内容不相关,就不要采用。\n" + " - 消费者的问题里面可能包含口语化的表达,比如鞋子开胶的意思是用胶黏合的鞋体裂开。这和胶丝遗留没有关系。\n" + ' - 如果问题涉及到订单号,请回复: "请稍等,正在帮您查询订单。"' + "" + ) query = state['query'] # print('llm config',state['chatbot_config']['rag_product_aftersales_config']['llm_config']) - output:str = invoke_lambda( + output: str = invoke_lambda( lambda_name='Online_LLM_Generate', lambda_module_path="lambda_llm_generate.llm_generate", handler_name='lambda_handler', event_body={ "llm_config": { - **state['chatbot_config']['rag_product_aftersales_config']['llm_config'], + **state['chatbot_config']['rag_product_aftersales_config']['llm_config'], "system_prompt": system_prompt, - "intent_type": LLMTaskType.CHAT - }, - "llm_input": { "query": query, "chat_history": state['chat_history']} - } + "intent_type": LLMTaskType.CHAT + }, + "llm_input": {"query": query, "chat_history": state['chat_history']} + } ) agent_current_call_number = state['agent_current_call_number'] + 1 agent_current_output = {} @@ -236,34 +247,33 @@ def agent(state: ChatbotState): tool_execute_res = state['agent_tool_history'][-1]['additional_kwargs']['raw_tool_call_results'][0] tool_name = tool_execute_res['name'] output = tool_execute_res['output'] - tool = get_tool_by_name(tool_name,scene=SceneType.RETAIL) + tool = get_tool_by_name(tool_name, scene=SceneType.RETAIL) if tool.running_mode == ToolRuningMode.ONCE: send_trace("once tool") return { "answer": str(output['result']), "function_calling_is_run_once": True } - + other_chain_kwargs = { - "goods_info": goods_info, - "create_time": state['create_time'], - "agent_current_call_number":state['agent_current_call_number'] - } - + "goods_info": goods_info, + "create_time": state['create_time'], + "agent_current_call_number": state['agent_current_call_number'] + } + response = app_agent.invoke({ **state, - "other_chain_kwargs":other_chain_kwargs + "other_chain_kwargs": other_chain_kwargs }) return response - @node_monitor_wrapper def final_rag_retriever_lambda(state: ChatbotState): # call retriever retriever_params = state["chatbot_config"]["final_rag_retriever"]["retriever_config"] retriever_params["query"] = state["query"] - output:str = 

-
 @node_monitor_wrapper
 def final_rag_retriever_lambda(state: ChatbotState):
     # call retriever
     retriever_params = state["chatbot_config"]["final_rag_retriever"]["retriever_config"]
     retriever_params["query"] = state["query"]
-    output:str = invoke_lambda(
+    output: str = invoke_lambda(
         event_body=retriever_params,
         lambda_name="Online_Functions",
         lambda_module_path="functions.functions_utils.retriever.retriever",
         handler_name="lambda_handler"
@@ -275,31 +285,32 @@ def final_rag_retriever_lambda(state: ChatbotState):
     send_trace(f'**final_rag_retriever** {context}')
     return {"contexts": contexts}

+
 @node_monitor_wrapper
-def final_rag_llm_lambda(state:ChatbotState):
+def final_rag_llm_lambda(state: ChatbotState):
     context = "\n\n".join(state['contexts'])
     system_prompt = ("你是安踏的客服助理,正在帮消费者解答售前或者售后的问题。<context> 中列举了一些可能有关的具体场景及回复,你可以进行参考:\n"
-        "<context>\n"
-        f"{context}\n"
-        "</context>\n"
-        "你需要按照下面的guidelines对消费者的问题进行回答:\n"
-        "<guidelines>\n"
-        " - 回答内容为一句话,言简意赅。\n"
-        " - 如果问题与context内容不相关,就不要采用。\n"
-        "</guidelines>\n"
-    )
-    output:str = invoke_lambda(
+                     "<context>\n"
+                     f"{context}\n"
+                     "</context>\n"
+                     "你需要按照下面的guidelines对消费者的问题进行回答:\n"
+                     "<guidelines>\n"
+                     " - 回答内容为一句话,言简意赅。\n"
+                     " - 如果问题与context内容不相关,就不要采用。\n"
+                     "</guidelines>\n"
+                     )
+    output: str = invoke_lambda(
         lambda_name='Online_LLM_Generate',
         lambda_module_path="lambda_llm_generate.llm_generate",
         handler_name='lambda_handler',
         event_body={
             "llm_config": {
                 **state['chatbot_config']['final_rag_retriever']['llm_config'],
-                "system_prompt":system_prompt,
+                "system_prompt": system_prompt,
                 "intent_type": LLMTaskType.CHAT},
-            "llm_input": { "query": state["query"], "chat_history": state['chat_history']}
-        }
-    )
+            "llm_input": {"query": state["query"], "chat_history": state['chat_history']}
+        }
+    )
     return {"answer": output}

 # def transfer_reply(state:ChatbotState):
@@ -315,16 +326,16 @@ def final_rag_llm_lambda(state:ChatbotState):
 #     recent_tool_calling:list[dict] = state['function_calling_parsed_tool_calls'][0]
 #     return {"answer": recent_tool_calling['kwargs']['response']}

-def rule_url_reply(state:ChatbotState):
+def rule_url_reply(state: ChatbotState):
     state["extra_response"]["current_agent_intent_type"] = "rule reply"
-    if state['query'].endswith(('.jpg','.png')):
+    if state['query'].endswith(('.jpg', '.png')):
         answer = random.choice([
             "收到,亲。请问我们可以怎么为您效劳呢?",
             "您好,请问有什么需要帮助的吗?"
         ])
         return {"answer": answer}

     # product information
-    r = re.findall(r"item.htm\?id=(.*)",state['query'])
+    r = re.findall(r"item.htm\?id=(.*)", state['query'])
     if r:
         goods_id = r[0]
     else:
@@ -335,37 +346,38 @@ def rule_url_reply(state:ChatbotState):
         output = f"您好,该商品的特点是:\n{human_goods_info}"
         if human_goods_info:
             system_prompt = (f"你是安踏的客服助理,当前用户对下面的商品感兴趣:\n"
-                f"<{goods_info_tag}>\n{human_goods_info}\n</{goods_info_tag}>\n\n"
-                "请你结合商品的基础信息,特别是卖点信息返回一句推荐语。"
-            )
-            output:str = invoke_lambda(
+                             f"<{goods_info_tag}>\n{human_goods_info}\n</{goods_info_tag}>\n\n"
+                             "请你结合商品的基础信息,特别是卖点信息返回一句推荐语。"
+                             )
+            output: str = invoke_lambda(
                 lambda_name='Online_LLM_Generate',
                 lambda_module_path="lambda_llm_generate.llm_generate",
                 handler_name='lambda_handler',
                 event_body={
                     "llm_config": {
-                        **state['chatbot_config']['rag_daily_reception_config']['llm_config'], 
+                        **state['chatbot_config']['rag_daily_reception_config']['llm_config'],
                         "system_prompt": system_prompt,
                         "intent_type": LLMTaskType.CHAT},
                     "llm_input": {"query": state['query'], "chat_history": state['chat_history']}
-                }
-            )
-        
-        return {"answer":output}
-    
-    return {"answer":"您好"}
-
-def rule_number_reply(state:ChatbotState):
+                }
+            )
+
+            return {"answer": output}
+
+    return {"answer": "您好"}
+
+
+def rule_number_reply(state: ChatbotState):
     state["extra_response"]["current_agent_intent_type"] = "rule reply"
-    return {"answer":"收到订单信息"}
+    return {"answer": "收到订单信息"}


 def final_results_preparation(state: ChatbotState):
     state['ddb_additional_kwargs'] = {
-        "goods_id":state['goods_id'],
-        "current_agent_intent_type":state['extra_response'].get('current_agent_intent_type',"")
-    }
-    app_response = process_response(state['event_body'],state)
+        "goods_id": state['goods_id'],
+        "current_agent_intent_type": state['extra_response'].get('current_agent_intent_type', "")
+    }
+    app_response = process_response(state['event_body'], state)
     return {"app_response": app_response}
@@ -373,7 +385,7 @@ def final_results_preparation(state: ChatbotState):
 # define edges #
 ################

-def query_route(state:dict):
+def query_route(state: dict):
     # check if rule reply
     query = state['query']
     is_all_url = True
@@ -382,20 +394,20 @@ def query_route(state:dict):
             is_all_url = False
     if is_all_url:
         return "url"
-    if query.isnumeric() and len(query)>=8:
+    if query.isnumeric() and len(query) >= 8:
         return "number"
     else:
         return "continue"


-def intent_route(state:dict):
-    return state['intent_type'] 
+def intent_route(state: dict):
+    return state['intent_type']


 def agent_route(state: dict):
-    if state.get("function_calling_is_run_once",False):
+    if state.get("function_calling_is_run_once", False):
         return "no need tool calling"
-    
+
     state["agent_repeated_call_validation"] = state['agent_current_call_number'] < state['agent_repeated_call_limit']

     if state["agent_repeated_call_validation"]:
@@ -404,13 +416,14 @@ def agent_route(state: dict):
         # TODO give final strategy
         raise RuntimeError('final rag')

-    
+
 #############################
 # define whole online graph #
 #############################

 app_agent = None

+
 def build_graph(chatbot_state_cls):
     workflow = StateGraph(chatbot_state_cls)
     # add all nodes
@@ -418,8 +431,8 @@ def build_graph(chatbot_state_cls):
     workflow.add_node("intention_detection", intention_detection)
     workflow.add_node("agent", agent)
     workflow.add_node("tools_execution", tool_execution)
-    workflow.add_node("rule_url_reply",rule_url_reply)
-    workflow.add_node("rule_number_reply",rule_number_reply)
+    workflow.add_node("rule_url_reply", rule_url_reply)
+    workflow.add_node("rule_number_reply", rule_number_reply)
workflow.add_node("rag_promotion_retriever",rag_promotion_retriever_lambda) # workflow.add_node("rag_promotion_llm",rag_promotion_llm_lambda) # workflow.add_node("final_rag_retriever",final_rag_retriever_lambda) @@ -429,8 +442,8 @@ def build_graph(chatbot_state_cls): # add all edges workflow.set_entry_point("query_preprocess") - workflow.add_edge("intention_detection","agent") - workflow.add_edge("tools_execution","agent") + workflow.add_edge("intention_detection", "agent") + workflow.add_edge("tools_execution", "agent") # workflow.add_edge("agent",'parse_tool_calling') # workflow.add_edge("rag_daily_reception_retriever","rag_daily_reception_llm") # workflow.add_edge('rag_goods_exchange_retriever',"rag_goods_exchange_llm") @@ -438,7 +451,7 @@ def build_graph(chatbot_state_cls): # workflow.add_edge('rag_customer_complain_retriever',"rag_customer_complain_llm") # workflow.add_edge('rag_promotion_retriever',"rag_promotion_llm") # workflow.add_edge('final_rag_retriever',"final_rag_llm") - + # end # workflow.add_edge("transfer_reply",END) # workflow.add_edge("give_rhetorical_question",END) @@ -447,8 +460,8 @@ def build_graph(chatbot_state_cls): # workflow.add_edge("rag_goods_exchange_llm",END) # workflow.add_edge("rag_product_aftersales_llm",END) # workflow.add_edge("rag_customer_complain_llm",END) - workflow.add_edge('rule_url_reply',END) - workflow.add_edge('rule_number_reply',END) + workflow.add_edge('rule_url_reply', END) + workflow.add_edge('rule_number_reply', END) # workflow.add_edge("rag_promotion_llm",END) # workflow.add_edge("give_final_response",END) # workflow.add_edge("final_rag_llm",END) @@ -459,9 +472,9 @@ def build_graph(chatbot_state_cls): "query_preprocess", query_route, { - "url": "rule_url_reply", - "number": "rule_number_reply", - "continue": "intention_detection" + "url": "rule_url_reply", + "number": "rule_number_reply", + "continue": "intention_detection" } ) @@ -479,7 +492,9 @@ def build_graph(chatbot_state_cls): app = workflow.compile() return app -app = None + +app = None + def _prepare_chat_history(event_body): if "history_config" in event_body["chatbot_config"]: @@ -493,58 +508,62 @@ def _prepare_chat_history(event_body): current_chat['content'] = hist['content'] current_chat['addional_kwargs'] = {} if 'goods_id' in hist['additional_kwargs']: - current_chat['addional_kwargs']['goods_id'] = str(hist['additional_kwargs']['goods_id']) + current_chat['addional_kwargs']['goods_id'] = str( + hist['additional_kwargs']['goods_id']) chat_history_by_goods_id.append(current_chat) return chat_history_by_goods_id else: return event_body["chat_history"] + def retail_entry(event_body): """ Entry point for the Lambda function. :param event_body: The event body for lambda function. 
 def retail_entry(event_body):
     """
     Entry point for the Lambda function.

     :param event_body: The event body for the Lambda function.
     :return: answer (str)
     """
-    global app,app_agent
+    global app, app_agent
     if app is None:
         app = build_graph(ChatbotState)
-    
+
     if app_agent is None:
         app_agent = build_agent_graph(ChatbotState)

     # debugging
     # TODO: only write when running locally
     if is_running_local():
-        with open('retail_entry_workflow.png','wb') as f:
+        with open('retail_entry_workflow.png', 'wb') as f:
             f.write(app.get_graph().draw_mermaid_png())
-        
+
-        with open('retail_entry_agent_workflow.png','wb') as f:
+        with open('retail_entry_agent_workflow.png', 'wb') as f:
             f.write(app_agent.get_graph().draw_mermaid_png())

     ################################################################################
     # prepare inputs and invoke graph
-    event_body['chatbot_config'] = RetailConfigParser.from_chatbot_config(event_body['chatbot_config'])
+    event_body['chatbot_config'] = RetailConfigParser.from_chatbot_config(
+        event_body['chatbot_config'])
     chatbot_config = event_body['chatbot_config']
     query = event_body['query']
     stream = event_body['stream']
-    create_time = chatbot_config.get('create_time',None)
+    create_time = chatbot_config.get('create_time', None)
     message_id = event_body['custom_message_id']
     ws_connection_id = event_body['ws_connection_id']
     enable_trace = chatbot_config["enable_trace"]
     goods_info_tag = "商品信息"
-    
+
     goods_info = ""
     human_goods_info = ""
     goods_id = str(event_body['chatbot_config']['goods_id'])
     if goods_id:
         try:
-            _goods_info = json.loads(goods_dict.get(goods_id,{}).get("goods_info",""))
-            _goods_type = goods_dict.get(goods_id,{}).get("goods_type","")
+            _goods_info = json.loads(goods_dict.get(
+                goods_id, {}).get("goods_info", ""))
+            _goods_type = goods_dict.get(goods_id, {}).get("goods_type", "")
         except Exception as e:
-            import traceback 
+            import traceback
             error = traceback.format_exc()
-            logger.error(f"error meesasge {error}, invalid goods_id: {goods_id}")
+            logger.error(
+                f"error message {error}, invalid goods_id: {goods_id}")
             _goods_info = None
-    
     if _goods_info:
         logger.info(_goods_info)
@@ -553,20 +572,21 @@ def retail_entry(event_body):
         else:
             goods_info = ""
         goods_info += f"<{goods_info_tag}>\n"
-        
+
         human_goods_info = ""
-        for k,v in _goods_info.items():
-            goods_info += f"{k}:{v}\n" 
-            human_goods_info += f"{k}:{v}\n" 
-        
+        for k, v in _goods_info.items():
+            goods_info += f"{k}:{v}\n"
+            human_goods_info += f"{k}:{v}\n"
+
         goods_info = goods_info.strip()
         goods_info += f"</{goods_info_tag}>\n"

     use_history = chatbot_config['use_history']
     chat_history = _prepare_chat_history(event_body) if use_history else []
     event_body['chat_history'] = chat_history
-    logger.info(f'event_body:\n{json.dumps(event_body,ensure_ascii=False,indent=2,cls=JSONEncoder)}')
-    
+    logger.info(
+        f'event_body:\n{json.dumps(event_body,ensure_ascii=False,indent=2,cls=JSONEncoder)}')
+
     logger.info(f"goods_info: {goods_info}")
     logger.info(f"chat_history: {chat_history}")
     # invoke graph and get results
     response = app.invoke({
@@ -583,16 +603,17 @@ def retail_entry(event_body):
         "ws_connection_id": ws_connection_id,
         "debug_infos": {},
         "extra_response": {},
-        "goods_info":goods_info,
-        "human_goods_info":human_goods_info,
+        "goods_info": goods_info,
+        "human_goods_info": human_goods_info,
         "agent_llm_type": LLMTaskType.RETAIL_TOOL_CALLING,
-        "query_rewrite_llm_type":LLMTaskType.RETAIL_CONVERSATION_SUMMARY_TYPE,
+        "query_rewrite_llm_type": LLMTaskType.RETAIL_CONVERSATION_SUMMARY_TYPE,
         "agent_repeated_call_limit": chatbot_config['agent_repeated_call_limit'],
         "agent_current_call_number": 0,
-        "current_agent_intent_type":"",
+        "current_agent_intent_type": "",
         "goods_info_tag": "商品信息",
         "goods_id": goods_id
     })
     return response['app_response']

-main_chain_entry = retail_entry
\ No newline at end of file
+
+main_chain_entry = retail_entry
diff --git a/source/lambda/online/lambda_main/main_utils/parse_config.py b/source/lambda/online/lambda_main/main_utils/parse_config.py
index b6364d5af..172b5c549 100644
--- a/source/lambda/online/lambda_main/main_utils/parse_config.py
+++ b/source/lambda/online/lambda_main/main_utils/parse_config.py
@@ -16,7 +16,8 @@ class ConfigParserBase:
     default_llm_config_str = "{'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0', 'model_kwargs': {'temperature': 0.01, 'max_tokens': 4096}}"
-    default_index_names = {"intention": [], "private_knowledge": [], "qq_match": []}
+    default_index_names = {"intention": [],
+                           "private_knowledge": [], "qq_match": []}
     default_retriever_config = {
         "intention": {"top_k": 5, "query_key": "query"},
         "private_knowledge": {
@@ -60,13 +61,15 @@ class ConfigParserBase:
         chatbot_config = copy.deepcopy(chatbot_config)
         default_llm_config = cls.parse_default_llm_config(chatbot_config)
         default_index_names = cls.parse_default_index_names(chatbot_config)
-        default_retriever_config = cls.parse_default_retriever_config(chatbot_config)
+        default_retriever_config = cls.parse_default_retriever_config(
+            chatbot_config)

         group_name = chatbot_config["group_name"]
         chatbot_id = chatbot_config["chatbot_id"]

         # init chatbot config
-        chatbot_config_obj = ChatbotConfig(group_name=group_name, chatbot_id=chatbot_id)
+        chatbot_config_obj = ChatbotConfig(
+            group_name=group_name, chatbot_id=chatbot_id)

         # init default llm
         chatbot_config_obj.update_llm_config(default_llm_config)
diff --git a/source/lambda/online/lambda_query_preprocess/query_preprocess.py b/source/lambda/online/lambda_query_preprocess/query_preprocess.py
index 28955808e..9ae75bedc 100644
--- a/source/lambda/online/lambda_query_preprocess/query_preprocess.py
+++ b/source/lambda/online/lambda_query_preprocess/query_preprocess.py
@@ -3,17 +3,17 @@
     RunnableLambda
 )

-from common_logic.common_utils.logger_utils import get_logger 
+from common_logic.common_utils.logger_utils import get_logger
 from common_logic.common_utils.langchain_utils import chain_logger
-from common_logic.common_utils.lambda_invoke_utils import invoke_lambda,chatbot_lambda_call_wrapper,send_trace
+from common_logic.common_utils.lambda_invoke_utils import invoke_lambda, chatbot_lambda_call_wrapper, send_trace
 from common_logic.common_utils.constant import LLMTaskType
 from common_logic.common_utils.prompt_utils import get_prompt_templates_from_ddb

 logger = get_logger("query_preprocess")


-def conversation_query_rewrite(query:str, chat_history:list, message_id:str, trace_infos:list, chatbot_config:dict, query_rewrite_llm_type:str):
-    """rewrite query accoridng to chat history
+def conversation_query_rewrite(query: str, chat_history: list, message_id: str, trace_infos: list, chatbot_config: dict, query_rewrite_llm_type: str):
+    """rewrite query according to chat history

     Args:
         query (str): input query from human
@@ -31,15 +31,16 @@ def conversation_query_rewrite(query:str, chat_history:list, message_id:str, tra
     conversation_query_rewrite_config = chatbot_config["query_process_config"][
         "conversation_query_rewrite_config"
     ]
-    
+
     prompt_templates_from_ddb = get_prompt_templates_from_ddb(
         group_name,
         model_id=conversation_query_rewrite_config['model_id'],
         task_type=query_rewrite_llm_type,
         chatbot_id=chatbot_id
     )
-    logger.info(f'conversation summary prompt templates: {prompt_templates_from_ddb}')
-    
+    logger.info(
+        f'conversation summary prompt templates: {prompt_templates_from_ddb}')
+
     cqr_llm_chain = RunnableLambda(lambda x: invoke_lambda(
         lambda_name='Online_LLM_Generate',
         lambda_module_path="lambda_llm_generate.llm_generate",
         handler_name='lambda_handler',
@@ -48,20 +49,24 @@ def conversation_query_rewrite(query:str, chat_history:list, message_id:str, tra
             "llm_config": {**prompt_templates_from_ddb,
                            **conversation_query_rewrite_config,
                            "intent_type": query_rewrite_llm_type
-                },
-            "llm_input": {"chat_history":x['chat_history'], "query":x['query']}
-            }
-        )
+                           },
+            "llm_input": {"chat_history": x['chat_history'], "query": x['query']}
+        }
     )
-    
-    cqr_llm_chain = RunnableBranch(
-        # single turn
-        (lambda x: not x['chat_history'], RunnableLambda(lambda x:x['query'])),
-        cqr_llm_chain
-    )

+    rewrite_first_message = conversation_query_rewrite_config.get("rewrite_first_message", False)
+    logger.info("Rewrite first message: %s", str(rewrite_first_message))
+    if not rewrite_first_message:
+        cqr_llm_chain = RunnableBranch(
+            # single turn
+            (lambda x: not x['chat_history'],
+             RunnableLambda(lambda x: x['query'])),
+            cqr_llm_chain
+        )
+
     conversation_summary_chain = chain_logger(
-        cqr_llm_chain, 
+        cqr_llm_chain,
         "conversation_summary_chain",
         message_id=message_id,
         trace_infos=trace_infos
@@ -69,16 +74,22 @@ def conversation_query_rewrite(query:str, chat_history:list, message_id:str, tra
     conversation_summary_input = {}
     conversation_summary_input["chat_history"] = chat_history
     conversation_summary_input["query"] = query
-    rewrite_query = conversation_summary_chain.invoke(conversation_summary_input)
+    rewrite_query = conversation_summary_chain.invoke(
+        conversation_summary_input)
+
     return rewrite_query
-    
+

 @chatbot_lambda_call_wrapper
-def lambda_handler(state:dict, context=None):
-    query = state.get("query","")
-    chat_history = state.get("chat_history",[])
-    message_id = state.get('message_id',"")
-    trace_infos = state.get('trace_infos',[])
+def lambda_handler(state: dict, context=None):
+    query = state.get("query", "")
+    chat_history = state.get("chat_history", [])
+    message_id = state.get('message_id', "")
+    trace_infos = state.get('trace_infos', [])
     chatbot_config = state["chatbot_config"]
-    query_rewrite_llm_type = state.get("query_rewrite_llm_type",None) or LLMTaskType.CONVERSATION_SUMMARY_TYPE
-    output:dict = conversation_query_rewrite(query, chat_history, message_id, trace_infos, chatbot_config, query_rewrite_llm_type)
+    query_rewrite_llm_type = state.get(
+        "query_rewrite_llm_type", None) or LLMTaskType.CONVERSATION_SUMMARY_TYPE
+    output: dict = conversation_query_rewrite(
+        query, chat_history, message_id, trace_infos, chatbot_config, query_rewrite_llm_type)
+
     return output
diff --git a/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/query_process_utils/query_process_utils.py b/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/query_process_utils/query_process_utils.py
index 73f97dfd1..f3b936834 100644
--- a/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/query_process_utils/query_process_utils.py
+++ b/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/query_process_utils/query_process_utils.py
@@ -145,7 +145,7 @@ def get_query_process_chain(chat_history, query_process_config, message_id=None)
     query_process_chain = preprocess_chain
     query_process_chain = (
         conversation_query_rewrite_chain | preprocess_chain
-    ) # | stepback_promping_chain
+    )  # | stepback_promping_chain
     query_process_chain = chain_logger(
         query_process_chain, "query process module", message_id=message_id
diff --git a/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/service_intent_recognition/utils.py b/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/service_intent_recognition/utils.py
index 01e012a40..aeb5ffed4 100644
--- a/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/service_intent_recognition/utils.py
+++ b/source/lambda/online/lambda_query_preprocess/query_preprocess_utils/service_intent_recognition/utils.py
@@ -1,7 +1,8 @@
 import os
 import re

-SERVICE_NAME_PATH = os.path.join(os.path.dirname(__file__), "service_names.txt")
+SERVICE_NAME_PATH = os.path.join(
+    os.path.dirname(__file__), "service_names.txt")

 SERVICE_NAMES = None
 SERVICE_NAMES_UPPER = None
diff --git a/source/portal/src/utils/const.ts b/source/portal/src/utils/const.ts
index 24f80fd28..470831160 100644
--- a/source/portal/src/utils/const.ts
+++ b/source/portal/src/utils/const.ts
@@ -18,18 +16,16 @@ export const LIBRARY_DEFAULT_PREFIX = 'documents';

 export const LLM_BOT_MODEL_LIST = [
   'anthropic.claude-3-sonnet-20240229-v1:0',
   'anthropic.claude-3-haiku-20240307-v1:0',
-  // 'anthropic.claude-3-5-sonnet-20240620-v1:0',
 ];

 export const LLM_BOT_COMMON_MODEL_LIST = [
+  'anthropic.claude-3-sonnet-20240229-v1:0',
+  'anthropic.claude-3-haiku-20240307-v1:0',
   'anthropic.claude-3-5-sonnet-20240620-v1:0',
   'anthropic.claude-3-5-haiku-20241022-v1:0',
   'meta.llama3-1-70b-instruct-v1:0',
   'mistral.mistral-large-2407-v1:0',
   'cohere.command-r-plus-v1:0',
-  'anthropic.claude-3-sonnet-20240229-v1:0',
-  'anthropic.claude-3-haiku-20240307-v1:0'
-  // 'anthropic.claude-3-5-sonnet-20240620-v1:0',
 ];

 export const LLM_BOT_RETAIL_MODEL_LIST = [