feat(agent): agent refactor #30

Closed · wants to merge 1 commit
90 changes: 70 additions & 20 deletions app/agent.py
@@ -1,4 +1,16 @@
-from langchain.agents import AgentType, initialize_agent
+from langchain.agents import AgentType, AgentExecutor
+from langchain.tools.render import render_text_description_and_args
+from langchain.agents.output_parsers import (
+    ReActJsonSingleInputOutputParser,
+)
+
+from langchain.agents.format_scratchpad import format_log_to_messages
+
+from langchain_core.agents import AgentAction
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.pydantic_v1 import BaseModel, Field
+from typing import List, Union
 import logging

@@ -35,25 +47,61 @@ def __init__(self, llm_provider: LLM_Model, db_connection: "TigerGraphConnection
         self.mq2s = MapQuestionToSchema(self.conn, self.llm.model, self.llm.map_question_schema_prompt)
         self.gen_func = GenerateFunction(self.conn, self.llm.model, self.llm.generate_function_prompt, embedding_model, embedding_store)
 
-        tools = [self.mq2s, self.gen_func]
+        self.tools = [self.mq2s, self.gen_func]
 
         logger.debug(f"request_id={req_id_cv.get()} agent tools created")
-        self.agent = initialize_agent(tools,
-                                      self.llm.model,
-                                      agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-                                      verbose=False,
-                                      return_intermediate_steps=True,
-                                      #max_iterations=7,
-                                      early_stopping_method="generate",
-                                      handle_parsing_errors=True)
-
-        '''
-        agent_kwargs={
-            "prefix": """DIRECTLY TRANSFER THE OBSERVATION INTO ACTION INPUTS AS NECESSARY.
-                         BE VERBOSE IN ACTION INPUTS AND THOUGHTS.
-                         NEVER HALLUCINATE FUNCTION CALLS, MY JOB DEPENDS ON CORRECT ANSWERS.
-                         ALWAYS USE THE MapQuestionToSchema TOOL BEFORE GenerateFunction.'"""
-        }
-        '''
+        system_message = f"""Answer the following questions as best you can. You can answer directly if the user is greeting you or similar.
+Otherwise, you have access to the following tools:
+
+{render_text_description_and_args(self.tools).replace('{', '{{').replace('}', '}}')}
+
+You use the tools by specifying a JSON blob.
+Specifically, this JSON should have an `action` key (with the name of the tool to use) and an `action_input` key (with the input to the tool going here).
+The only values that should be in the "action" field are: {[t.name for t in self.tools]}
+The $JSON_BLOB should only contain a SINGLE action; do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:
+```
+{{{{
+    "action": $TOOL_NAME,
+    "action_input": $INPUT
+}}}}
+```
+ALWAYS use the following format:
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action:```$JSON_BLOB```
+Observation: the result of the action
+... (this Thought/Action/Observation can repeat N times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin! Reminder to always use the exact characters `Final Answer` when responding.
+"""
+
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "user",
+                    system_message,
+                ),
+                MessagesPlaceholder(variable_name="chat_history"),
+                ("user", "{input}"),
+                MessagesPlaceholder(variable_name="agent_scratchpad"),
+            ]
+        )
+
+        chat_model_with_stop = self.llm.model.bind(stop=["\nObservation"])
+        self.agent = (
+            {
+                "input": lambda x: x["input"],
+                "agent_scratchpad": lambda x: format_log_to_messages(x["intermediate_steps"]),
+                "chat_history": lambda x: x['chat_history']
+            }
+            | prompt
+            | chat_model_with_stop
+            | ReActJsonSingleInputOutputParser()
+        )
 
         logger.debug(f"request_id={req_id_cv.get()} agent initialized")

@@ -68,9 +116,11 @@ def question_for_agent(self, question: str):
         logger.info(f"request_id={req_id_cv.get()} ENTRY question_for_agent")
         logger.debug_pii(f"request_id={req_id_cv.get()} question_for_agent question={question}")
         try:
-            resp = self.agent({"input": question})
+            agent_executor = AgentExecutor(agent=self.agent, tools=self.tools, return_intermediate_steps=True)
+            resp = agent_executor.invoke({"input": question, "chat_history": []})
             logger.info(f"request_id={req_id_cv.get()} EXIT question_for_agent")
             return resp
         except Exception as e:
+            print(e)
             logger.error(f"request_id={req_id_cv.get()} FAILURE question_for_agent")
             raise e
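
Reviewer note: the refactor above swaps `initialize_agent` for a hand-built runnable pipeline (prompt | model | ReAct JSON parser) executed by `AgentExecutor`. Below is a minimal, self-contained sketch of the same pattern, assuming the LangChain APIs imported in this diff; the `echo` tool and `ChatOpenAI` model are illustrative stand-ins, not part of this PR.

```python
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_messages
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import tool
from langchain.tools.render import render_text_description_and_args

@tool
def echo(query: str) -> str:
    """Stand-in tool: returns its input unchanged."""
    return query

tools = [echo]

# Same message layout as the diff: instructions plus the rendered tool list,
# then chat history, the user input, and the ReAct scratchpad.
system_message = (
    "Answer the question, using a tool if needed.\n"
    f"{render_text_description_and_args(tools)}\n"
    "Respond with a JSON blob with `action` and `action_input` keys, "
    "or a `Final Answer:` line when done."
)
prompt = ChatPromptTemplate.from_messages([
    # Escape braces so the rendered tool schema is not treated as a template variable.
    ("user", system_message.replace("{", "{{").replace("}", "}}")),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# Any LangChain chat model works here; stop on "\nObservation" as in the diff.
model = ChatOpenAI(temperature=0).bind(stop=["\nObservation"])
agent = (
    {
        "input": lambda x: x["input"],
        "agent_scratchpad": lambda x: format_log_to_messages(x["intermediate_steps"]),
        "chat_history": lambda x: x["chat_history"],
    }
    | prompt
    | model
    | ReActJsonSingleInputOutputParser()
)

executor = AgentExecutor(agent=agent, tools=tools, return_intermediate_steps=True)
print(executor.invoke({"input": "Say hi back.", "chat_history": []}))
```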
1 change: 1 addition & 0 deletions app/main.py
@@ -138,6 +138,7 @@ def retrieve_answer(graphname, query: NaturalLanguageQuery, credentials: Annotat
 
     try:
         steps = agent.question_for_agent(query.query)
+        print(steps)
         logger.debug(f"/{graphname}/query request_id={req_id_cv.get()} agent executed")
         try:
             generate_func_output = steps["intermediate_steps"][-1][-1]
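
Reviewer note: the `steps["intermediate_steps"][-1][-1]` indexing here relies on the executor's output shape. With `return_intermediate_steps=True`, `AgentExecutor.invoke` returns the input keys plus `output` and `intermediate_steps`, a list of `(AgentAction, observation)` tuples, so `[-1][-1]` picks the observation from the last tool call. A sketch of that shape, with illustrative values:

```python
from langchain_core.agents import AgentAction

# Illustrative shape of the dict consumed in retrieve_answer:
resp = {
    "input": "How many accounts does Jane own?",
    "chat_history": [],
    "output": "Jane owns 3 accounts.",
    "intermediate_steps": [
        (AgentAction(tool="MapQuestionToSchema", tool_input="...", log="..."), "schema mapping ..."),
        (AgentAction(tool="GenerateFunction", tool_input="...", log="..."), {"answer": 3}),
    ],
}
generate_func_output = resp["intermediate_steps"][-1][-1]  # observation of the last tool call
```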
3 changes: 3 additions & 0 deletions app/schemas/tool_io_schemas.py
@@ -1,6 +1,9 @@
 from langchain.pydantic_v1 import BaseModel, Field
 from typing import List, Dict, Type
 
+class MapQuestionToSchemaInput(BaseModel):
+    query: str = Field(description="The user's question")
+
 class MapQuestionToSchemaResponse(BaseModel):
     question: str = Field(description="The question restated in terms of the graph schema")
     target_vertex_types: List[str] = Field(description="The list of vertices mentioned in the question. If there are no vertices mentioned, then use an empty list.")
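
Reviewer note: `MapQuestionToSchemaInput` exists so the tool can declare an explicit `args_schema` (next file), which `render_text_description_and_args` in `app/agent.py` turns into a named-argument signature in the prompt. A toy sketch of that effect, with a stand-in tool in place of the real `MapQuestionToSchema`:

```python
from typing import Type
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from langchain.tools.render import render_text_description_and_args

class MapQuestionToSchemaInput(BaseModel):
    query: str = Field(description="The user's question")

class DemoTool(BaseTool):
    name = "MapQuestionToSchema"
    description = "Restate a question in terms of the graph schema"
    args_schema: Type[MapQuestionToSchemaInput] = MapQuestionToSchemaInput

    def _run(self, query: str) -> str:
        return query

# Renders roughly as:
# MapQuestionToSchema: Restate a question in terms of the graph schema,
# args: {'query': {'title': 'Query', 'description': "The user's question", 'type': 'string'}}
print(render_text_description_and_args([DemoTool()]))
```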
5 changes: 3 additions & 2 deletions app/tools/map_schema_to_question.py
@@ -6,8 +6,8 @@
 from langchain.output_parsers import PydanticOutputParser
 from pyTigerGraph import TigerGraphConnection
 from langchain.pydantic_v1 import BaseModel, Field, validator
-from app.schemas import MapQuestionToSchemaResponse, MapAttributeToAttributeResponse
-from typing import List, Dict
+from app.schemas import MapQuestionToSchemaResponse, MapAttributeToAttributeResponse, MapQuestionToSchemaInput
+from typing import List, Dict, Type
 from .validation_utils import validate_schema, MapQuestionToSchemaException
 import re
 import logging
@@ -26,6 +26,7 @@ class MapQuestionToSchema(BaseTool):
     llm: LLM = None
     prompt: str = None
     handle_tool_error: bool = True
+    args_schema: Type[MapQuestionToSchemaInput] = MapQuestionToSchemaInput
 
     def __init__(self, conn, llm, prompt):
         """ Initialize MapQuestionToSchema.
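
Reviewer note: setting `args_schema` also means `BaseTool` validates the parsed `action_input` against the pydantic model before `_run` executes, so a malformed tool call from the model fails fast instead of reaching the tool body. Continuing the stand-in `DemoTool` from the earlier sketch:

```python
# Stand-in usage; DemoTool is the toy tool defined in the sketch above.
tool = DemoTool()
tool.run({"query": "How many accounts are there?"})  # validated, then _run is called
tool.run({"question": "oops"})  # raises a pydantic ValidationError before _run
```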