Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support query rewrite for first sentence #436

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
this.dynamodbStatement = createKnowledgeBaseTablesAndPoliciesResult.dynamodbStatement;

this.sfnOutput = this.createKnowledgeBaseJob(props);

}

private createKnowledgeBaseTablesAndPolicies(props: any) {
Expand Down Expand Up @@ -153,7 +153,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac

// If this.region is cn-north-1 or cn-northwest-1, use the glue-job-script-cn.py
const glueJobScript = "glue-job-script.py";


const extraPythonFiles = new s3deploy.BucketDeployment(
this,
Expand All @@ -172,9 +172,9 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
const extraPythonFilesList = [
this.glueLibS3Bucket.s3UrlForObject("llm_bot_dep-0.1.0-py3-none-any.whl"),
].join(",");





const glueRole = new iam.Role(this, "ETLGlueJobRole", {
assumedBy: new iam.ServicePrincipal("glue.amazonaws.com"),
// The role is used by the Glue job to access AOS. By default it has a 1-hour session duration, which is not enough for the Glue job to finish the embedding injection.
Expand Down
6 changes: 2 additions & 4 deletions source/lambda/online/common_logic/common_utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ class LLMTaskType(ConstantBase):
AUTO_EVALUATION = "auto_evaluation"



class MessageType(ConstantBase):
HUMAN_MESSAGE_TYPE = 'human'
AI_MESSAGE_TYPE = 'ai'
Expand Down Expand Up @@ -149,7 +148,6 @@ class LLMModelType(ConstantBase):
COHERE_COMMAND_R_PLUS = "cohere.command-r-plus-v1:0"



class EmbeddingModelType(ConstantBase):
BEDROCK_TITAN_V1 = "amazon.titan-embed-text-v1"

Expand Down Expand Up @@ -179,7 +177,8 @@ class IndexTag(Enum):

@unique
class KBType(Enum):
AOS = "aos"
AOS = "aos"


GUIDE_INTENTION_NOT_FOUND = "Intention not found, please add intentions first when using agent mode, refer to https://amzn-chn.feishu.cn/docx/HlxvduJYgoOz8CxITxXc43XWn8e"
INDEX_DESC = "Answer question based on search result"
Expand All @@ -188,4 +187,3 @@ class KBType(Enum):
class Threshold(ConstantBase):
QQ_IN_RAG_CONTEXT = 0.5
INTENTION_ALL_KNOWLEDGE_RETRIEVAL = 0.4

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
import json
import time
import os
import os
from typing import Any, Dict, Optional, Callable, Union
import threading

Expand Down Expand Up @@ -34,23 +34,23 @@

class StateContext:

def __init__(self,state):
self.state=state
def __init__(self, state):
self.state = state

@classmethod
def get_current_state(cls):
# print("thread id",threading.get_ident(),'parent id',threading.)
# state = getattr(thread_local,'state',None)
state = CURRENT_STATE
assert state is not None,"There is not a valid state in current context"
assert state is not None, "There is not a valid state in current context"
return state

@classmethod
def set_current_state(cls, state):
global CURRENT_STATE
global CURRENT_STATE
assert CURRENT_STATE is None, "Parallel node executions are not alowed"
CURRENT_STATE = state

@classmethod
def clear_state(cls):
global CURRENT_STATE
Expand Down Expand Up @@ -108,7 +108,8 @@ def validate_environment(cls, values: Dict):
session = boto3.Session()
values["client"] = session.client(
"lambda",
region_name=values.get("region_name",os.environ['AWS_REGION'])
region_name=values.get(
"region_name", os.environ['AWS_REGION'])
)
except Exception as e:
raise ValueError(
Expand Down Expand Up @@ -320,16 +321,17 @@ def wrapper(state: Dict[str, Any]) -> Dict[str, Any]:
current_stream_use, ws_connection_id, enable_trace)
state['trace_infos'].append(
f"Enter: {func.__name__}, time: {time.time()}")
with StateContext(state):

with StateContext(state):
output = func(state)

current_monitor_infos = output.get(monitor_key, None)
if current_monitor_infos is not None:
send_trace(f"\n\n {current_monitor_infos}",
current_stream_use, ws_connection_id, enable_trace)
exit_time = time.time()
state['trace_infos'].append(f"Exit: {func.__name__}, time: {time.time()}")
state['trace_infos'].append(
f"Exit: {func.__name__}, time: {time.time()}")
send_trace(f"\n\n Elapsed time: {round((exit_time-enter_time)*100)/100} s",
current_stream_use, ws_connection_id, enable_trace)
return output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,4 @@ def format_trace_infos(trace_infos: list, use_pretty_table=True):


class NestUpdateState(TypedDict):
keys: Annotated[dict, update_nest_dict]
keys: Annotated[dict, update_nest_dict]
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,3 @@ def _inner(*args, **kwargs):
print_llm_messages(kwargs)
return fn(*args, **kwargs)
return _inner

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def format_qq_data(data) -> str:
"""
if is_null_or_empty(data):
return ""

markdown_table = "**QQ Match Result**\n"
markdown_table += "| Source | Score | Question | Answer |\n"
markdown_table += "|-----|-----|-----|-----|\n"
Expand Down Expand Up @@ -66,7 +66,7 @@ def format_rag_data(data, qq_result) -> str:
score = item.get("score", -1)
page_content = item.get("retrieval_content", "").replace("\n", "<br>")
markdown_table += f"| {source} | {raw_source} | {score} | {page_content} |\n"

if not is_null_or_empty(qq_result):
markdown_table += "\n**QQ Match Result**\n"
markdown_table += "| Source File Name | Source URI | Score | Question | Answer |\n"
Expand Down
14 changes: 6 additions & 8 deletions source/lambda/online/common_logic/common_utils/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def prompt_template_render(self, prompt_template: dict):
)


################
################
# query rewrite prompt template from paper https://arxiv.org/pdf/2401.10225
###################
CQR_SYSTEM_PROMPT = """You are a helpful, pattern-following assistant."""
Expand Down Expand Up @@ -273,7 +273,7 @@ def prompt_template_render(self, prompt_template: dict):
LLMModelType.LLAMA3_2_90B_INSTRUCT,
LLMModelType.MISTRAL_LARGE_2407,
LLMModelType.COHERE_COMMAND_R_PLUS,

],
task_type=LLMTaskType.CONVERSATION_SUMMARY_TYPE,
prompt_template=CQR_SYSTEM_PROMPT,
Expand Down Expand Up @@ -330,7 +330,6 @@ def prompt_template_render(self, prompt_template: dict):
)



############## xml agent prompt #############
# AGENT_USER_PROMPT = "你是一个AI助理。今天是{date},{weekday}. "
# register_prompt_templates(
Expand Down Expand Up @@ -397,31 +396,31 @@ def prompt_template_render(self, prompt_template: dict):
</context>"""

# AGENT_SYSTEM_PROMPT = """\
# You are a helpful and honest AI assistant. Today is {date},{weekday}.
# You are a helpful and honest AI assistant. Today is {date},{weekday}.
# Here are some guidelines for you:
# <guidlines>
# - Output your step by step thinking in one pair of <thinking> and </thinking> tags, here are steps for you to think about deciding to use which tool:
# 1. If the context contains the result of last tool call, it needs to be analyzed.
# 2. Determine whether the current context is sufficient to answer the user's question.
# 3. If the current context is sufficient to answer the user's question, call the `give_final_response` tool.
# 4. If the current context is not sufficient to answer the user's question, you can consider calling one of the provided tools.
# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information.
# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information.
# - Always output with the same language as the content from user. If the content is English, use English to output. If the content is Chinese, use Chinese to output.
# - Always invoke one tool.
# - Before invoking any tool, be sure to first output your thought process in one pair of <thinking> and </thinking> tag.
# </guidlines>"""


# AGENT_SYSTEM_PROMPT = """\
# You are a helpful and honest AI assistant. Today is {date},{weekday}.
# You are a helpful and honest AI assistant. Today is {date},{weekday}.
# Here are some guidelines for you:
# <guidlines>
# - Output your step by step thinking in one pair of <thinking> and </thinking> tags, here are steps for you to think about deciding to use which tool:
# 1. If the context contains the result of last tool call, it needs to be analyzed.
# 2. Determine whether the current context is sufficient to answer the user's question.
# 3. If the current context is sufficient to answer the user's question, call the `give_final_response` tool.
# 4. If the current context is not sufficient to answer the user's question, you can consider calling one of the provided tools.
# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information.
# 5. If any of required parameters of the tool you want to call do not appears in context, call the `give_rhetorical_question` tool to ask the user for more information.
# - Always output with the same language as the content from user. If the content is English, use English to output. If the content is Chinese, use Chinese to output.
# - Always invoke one tool.
# </guidlines>
Expand Down Expand Up @@ -505,6 +504,5 @@ def prompt_template_render(self, prompt_template: dict):
)



if __name__ == "__main__":
print(get_all_templates())
21 changes: 12 additions & 9 deletions source/lambda/online/common_logic/common_utils/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ class LLMConfig(AllowBaseModel):
model_kwargs: dict = {"temperature": 0.01, "max_tokens": 4096}


class QueryRewriteConfig(LLMConfig):
rewrite_first_message: bool = False

class QueryProcessConfig(ForbidBaseModel):
conversation_query_rewrite_config: LLMConfig = Field(
default_factory=LLMConfig)
conversation_query_rewrite_config: QueryRewriteConfig = Field(
default_factory=QueryRewriteConfig)


class RetrieverConfigBase(AllowBaseModel):
Expand Down Expand Up @@ -88,7 +91,7 @@ class RagToolConfig(AllowBaseModel):

class AgentConfig(ForbidBaseModel):
llm_config: LLMConfig = Field(default_factory=LLMConfig)
tools: list[Union[str,dict]] = Field(default_factory=list)
tools: list[Union[str, dict]] = Field(default_factory=list)
only_use_rag_tool: bool = False


Expand Down Expand Up @@ -178,8 +181,7 @@ def get_index_infos_from_ddb(cls, group_name, chatbot_id):
cls.format_index_info(info)
for info in list(index_info["value"].values())
]
infos[index_type] = {info["index_name"]
: info for info in info_list}
infos[index_type] = {info["index_name"] : info for info in info_list}

for index_type in IndexType.all_values():
if index_type not in infos:
Expand All @@ -194,17 +196,18 @@ def update_retrievers(
):
index_infos = self.get_index_infos_from_ddb(
self.group_name, self.chatbot_id)
print(f"index_infos: {index_infos}")
print(f"default_index_names: {default_index_names}")
print(f"default_retriever_config: {default_retriever_config}")
logger.info(f"index_infos: {index_infos}")
logger.info(f"default_index_names: {default_index_names}")
logger.info(f"default_retriever_config: {default_retriever_config}")
for task_name, index_names in default_index_names.items():
assert task_name in ("qq_match", "intention", "private_knowledge")
if task_name == "qq_match":
index_type = IndexType.QQ
elif task_name == "intention":
index_type = IndexType.INTENTION
elif task_name == "private_knowledge":
index_type = IndexType.QD
else:
raise ValueError(f"Invalid task_name: {task_name}")

# Default to using all indexes when no index names are given
if not index_names:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections.abc

def update_nest_dict(d:dict, u:dict):

def update_nest_dict(d: dict, u: dict):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = update_nest_dict(d.get(k, {}), v)
Expand Down
Loading
Loading