Merge pull request #166 from tigergraph/GML-1716-rev1-prompt
feat(agent): add output format instructions to all steps
parkererickson-tg authored Jun 6, 2024
2 parents 5cfc9e0 + 4f5ed08 commit 674850c
Showing 15 changed files with 185 additions and 72 deletions.
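The common thread across these changes is swapping free-form StrOutputParser / JsonOutputParser chains for PydanticOutputParser instances whose get_format_instructions() output is injected into each prompt as a partial variable, so every agent step returns a typed object instead of a raw string. A minimal sketch of that pattern, assuming a generic AnswerOutput model and an llm object that are illustrative rather than taken from this commit:

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain.pydantic_v1 import BaseModel, Field

class AnswerOutput(BaseModel):
    # Hypothetical schema; the commit's real models are CoPilotAnswerOutput,
    # HallucinationCheckResponse, QuestionRewriteResponse, and similar.
    generated_answer: str = Field(description="Concise answer to the question.")

parser = PydanticOutputParser(pydantic_object=AnswerOutput)

prompt = PromptTemplate(
    template=(
        "Given the question and the context, generate an answer.\n"
        "Question: {question}\n"
        "Context: {context}\n"
        "Format: {format_instructions}"
    ),
    input_variables=["question", "context"],
    # get_format_instructions() renders the Pydantic schema as JSON formatting
    # guidance and is baked into every prompt via a partial variable.
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

# chain = prompt | llm | parser   # yields an AnswerOutput instance, not a raw string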
21 changes: 15 additions & 6 deletions app/agent/agent_generation.py
@@ -1,12 +1,15 @@
from langchain.prompts import PromptTemplate
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import PydanticOutputParser
from app.tools.logwriter import LogWriter
import logging
from app.log import req_id_cv
from langchain.pydantic_v1 import BaseModel, Field

logger = logging.getLogger(__name__)

class CoPilotAnswerOutput(BaseModel):
generated_answer: str = Field(description="The generated answer to the question. Make sure to maintain a professional tone and keep the answer concise.")

class TigerGraphAgentGenerator:
def __init__(self, llm_model):
self.llm = llm_model
@@ -20,17 +23,23 @@ def generate_answer(self, question: str, context: str) -> str:
str: The answer to the question.
"""
LogWriter.info(f"request_id={req_id_cv.get()} ENTRY generate_answer")

answer_parser = PydanticOutputParser(pydantic_object=CoPilotAnswerOutput)

prompt = PromptTemplate(
template="""Given the question and the context, generate an answer. \n
Make sure to answer the question in a friendly and informative way. \n
Question: {question} \n
Context: {context}""",
input_variables=["question", "context"]
Context: {context}
Format: {format_instructions}""",
input_variables=["question", "context"],
partial_variables={
"format_instructions": answer_parser.get_format_instructions()
}
)

# Chain
rag_chain = prompt | self.llm.model | StrOutputParser()

rag_chain = prompt | self.llm.model | answer_parser
generation = rag_chain.invoke({"context": context, "question": question})
LogWriter.info(f"request_id={req_id_cv.get()} EXIT generate_answer")
return generation
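Because the chain now ends in answer_parser rather than StrOutputParser, callers receive a CoPilotAnswerOutput object and read its generated_answer field, which is exactly what the agent_graph.py changes below switch to. A usage sketch, assuming `generator` is an already constructed TigerGraphAgentGenerator and the question and context strings are illustrative:

# Sketch only; `generator` wraps a real LLM provider in practice.
result = generator.generate_answer(
    question="Which vertex types exist in the graph?",
    context="The graph contains Person and Company vertices.",
)
print(type(result).__name__)    # CoPilotAnswerOutput
print(result.generated_answer)  # structured field, no manual string parsing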
144 changes: 102 additions & 42 deletions app/agent/agent_graph.py
@@ -15,6 +15,8 @@
from app.py_schemas import (MapQuestionToSchemaResponse,
CoPilotResponse)

from pyTigerGraph.pyTigerGraphException import TigerGraphException

import logging
from app.log import req_id_cv

@@ -53,6 +55,13 @@ def __init__(self, llm_provider,
self.cypher_gen = cypher_gen_tool
self.enable_human_in_loop = enable_human_in_loop

self.supportai_enabled = True
try:
self.db_connection.getQueryMetadata("HNSW_Overlap")
except TigerGraphException as e:
logger.info("HNSW_Overlap not found in the graph. Disabling supportai.")
self.supportai_enabled = False

def route_question(self, state):
"""
Run the agent router.
@@ -64,11 +73,14 @@ def route_question(self, state):
return "apologize"
state["question_retry_count"] += 1
logger.debug_pii(f"request_id={req_id_cv.get()} Routing question: {state['question']}")
source = step.route_question(state['question'])
logger.debug_pii(f"request_id={req_id_cv.get()} Routing question to: {source}")
if source.datasource == "vectorstore":
return "supportai_lookup"
elif source.datasource == "functions":
if self.supportai_enabled:
source = step.route_question(state['question'])
logger.debug_pii(f"request_id={req_id_cv.get()} Routing question to: {source}")
if source.datasource == "vectorstore":
return "supportai_lookup"
elif source.datasource == "functions":
return "inquiryai_lookup"
else:
return "inquiryai_lookup"

def apologize(self, state):
@@ -152,11 +164,16 @@ def generate_answer(self, state):
"""
step = TigerGraphAgentGenerator(self.llm_provider)
logger.debug_pii(f"request_id={req_id_cv.get()} Generating answer for question: {state['question']}")
answer = step.generate_answer(state['question'], state["context"])
logger.debug_pii(f"request_id={req_id_cv.get()} Generated answer: {answer}")
if state["lookup_source"] == "supportai":
answer = step.generate_answer(state['question'], state["context"])
elif state["lookup_source"] == "inquiryai":
answer = step.generate_answer(state['question'], state["context"]["result"])
elif state["lookup_source"] == "cypher":
answer = step.generate_answer(state['question'], state["context"]["answer"])
logger.debug_pii(f"request_id={req_id_cv.get()} Generated answer: {answer.generated_answer}")

try:
resp = CoPilotResponse(natural_language_response=answer,
resp = CoPilotResponse(natural_language_response=answer.generated_answer,
answered_question=True,
response_type=state["lookup_source"],
query_sources=state["context"])
@@ -183,7 +200,7 @@ def check_answer_for_hallucinations(self, state):
"""
step = TigerGraphAgentHallucinationCheck(self.llm_provider)
hallucinations = step.check_hallucination(state["answer"].natural_language_response, state["context"])
if hallucinations["score"] == "yes":
if hallucinations.score == "yes":
return "grounded"
else:
return "hallucination"
@@ -194,7 +211,7 @@ def check_answer_for_usefulness(self, state):
"""
step = TigerGraphAgentUsefulnessCheck(self.llm_provider)
usefulness = step.check_usefulness(state["question"], state["answer"].natural_language_response)
if usefulness["score"] == "yes":
if usefulness.score == "yes":
return "useful"
else:
return "not_useful"
@@ -236,7 +253,8 @@ def create_graph(self):
self.workflow.add_node("generate_answer", self.generate_answer)
self.workflow.add_node("map_question_to_schema", self.map_question_to_schema)
self.workflow.add_node("generate_function", self.generate_function)
self.workflow.add_node("hnsw_overlap_search", self.hnsw_overlap_search)
if self.supportai_enabled:
self.workflow.add_node("hnsw_overlap_search", self.hnsw_overlap_search)
self.workflow.add_node("rewrite_question", self.rewrite_question)
self.workflow.add_node("apologize", self.apologize)

@@ -256,44 +274,86 @@ def create_graph(self):
"error": "apologize",
"success": "generate_answer"
})
if self.supportai_enabled:
self.workflow.add_conditional_edges(
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"inquiryai_not_useful": "generate_cypher",
"cypher_not_useful": "hnsw_overlap_search",
"supportai_not_useful": "map_question_to_schema"
}
)
else:
self.workflow.add_conditional_edges(
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"inquiryai_not_useful": "generate_cypher",
"cypher_not_useful": "apologize"
}
)
else:
self.workflow.add_conditional_edges(
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"inquiryai_not_useful": "generate_cypher",
"cypher_not_useful": "hnsw_overlap_search",
"supportai_not_useful": "map_question_to_schema"
}
"generate_function",
self.check_state_for_generation_error,
{
"error": "rewrite_question",
"success": "generate_answer"
}
)
if self.supportai_enabled:
self.workflow.add_conditional_edges(
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"not_useful": "rewrite_question",
"inquiryai_not_useful": "hnsw_overlap_search",
"supportai_not_useful": "map_question_to_schema"
}
)
else:
self.workflow.add_conditional_edges(
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"not_useful": "rewrite_question",
"inquiryai_not_useful": "apologize",
"supportai_not_useful": "map_question_to_schema"
}
)

if self.supportai_enabled:
self.workflow.add_conditional_edges(
"entry",
self.route_question,
{
"supportai_lookup": "hnsw_overlap_search",
"inquiryai_lookup": "map_question_to_schema",
"apologize": "apologize"
}
)
else:
self.workflow.add_edge("generate_function", "generate_answer")
self.workflow.add_conditional_edges(
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"not_useful": "rewrite_question",
"inquiryai_not_useful": "hnsw_overlap_search",
"supportai_not_useful": "map_question_to_schema"
}
"entry",
self.route_question,
{
"inquiryai_lookup": "map_question_to_schema",
"apologize": "apologize"
}
)


self.workflow.add_conditional_edges(
"entry",
self.route_question,
{
"supportai_lookup": "hnsw_overlap_search",
"inquiryai_lookup": "map_question_to_schema",
"apologize": "apologize"
}
)

self.workflow.add_edge("map_question_to_schema", "generate_function")
self.workflow.add_edge("hnsw_overlap_search", "generate_answer")
if self.supportai_enabled:
self.workflow.add_edge("hnsw_overlap_search", "generate_answer")
self.workflow.add_edge("rewrite_question", "entry")
self.workflow.add_edge("apologize", END)

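The agent_graph.py changes above follow a probe-then-branch pattern: the constructor calls getQueryMetadata("HNSW_Overlap") once and, if the query is missing, clears supportai_enabled so that the SupportAI node and its edges are never wired into the LangGraph workflow. A minimal sketch of the same pattern under assumed names (the state schema, node callables, and `conn` connection object are placeholders, not code from this commit):

from typing import TypedDict
from langgraph.graph import StateGraph, END
from pyTigerGraph.pyTigerGraphException import TigerGraphException

class AgentState(TypedDict):
    # Placeholder state schema for the sketch.
    question: str
    context: str

def route_entry(state: AgentState) -> AgentState:      # stand-in node functions
    return state

def hnsw_lookup(state: AgentState) -> AgentState:
    return {**state, "context": "documents retrieved via HNSW_Overlap"}

def generate_answer(state: AgentState) -> AgentState:
    return state

# Probe the database once; a missing query means the SupportAI path is unavailable.
supportai_enabled = True
try:
    conn.getQueryMetadata("HNSW_Overlap")   # conn: an authenticated TigerGraphConnection (assumed)
except TigerGraphException:
    supportai_enabled = False

workflow = StateGraph(AgentState)
workflow.add_node("entry", route_entry)
workflow.add_node("generate_answer", generate_answer)
if supportai_enabled:
    # Only add the SupportAI node and its edges when the backing GSQL query exists.
    workflow.add_node("hnsw_overlap_search", hnsw_lookup)
    workflow.add_edge("entry", "hnsw_overlap_search")
    workflow.add_edge("hnsw_overlap_search", "generate_answer")
else:
    workflow.add_edge("entry", "generate_answer")
workflow.set_entry_point("entry")
workflow.add_edge("generate_answer", END)
app = workflow.compile()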
20 changes: 15 additions & 5 deletions app/agent/agent_hallucination_check.py
@@ -1,12 +1,15 @@
from langchain.prompts import PromptTemplate
from langchain import hub
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers import PydanticOutputParser
from app.tools.logwriter import LogWriter
import logging
from app.log import req_id_cv
from langchain.pydantic_v1 import BaseModel, Field

logger = logging.getLogger(__name__)

class HallucinationCheckResponse(BaseModel):
score: str = Field(description="The score of the hallucination check. Either 'yes' or 'no', indicating if the answer is hallucinated.")

class TigerGraphAgentHallucinationCheck:
def __init__(self, llm_model):
self.llm = llm_model
@@ -20,6 +23,9 @@ def check_hallucination(self, generation: str, context: str) -> dict:
dict: The answer to the question and a boolean indicating if the answer is hallucinated.
"""
LogWriter.info(f"request_id={req_id_cv.get()} ENTRY check_hallucination")

hallucination_parser = PydanticOutputParser(pydantic_object=HallucinationCheckResponse)

prompt = PromptTemplate(
template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n
Here are the facts:
@@ -28,12 +34,16 @@ def check_hallucination(self, generation: str, context: str) -> dict:
\n ------- \n
Here is the answer: {generation}
Give a binary score 'yes' or 'no' to indicate whether the answer is grounded in / supported by a set of facts. \n
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
input_variables=["generation", "context"]
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
Format: {format_instructions}""",
input_variables=["generation", "context"],
partial_variables={
"format_instructions": hallucination_parser.get_format_instructions()
}
)

# Chain
rag_chain = prompt | self.llm.model | JsonOutputParser()
rag_chain = prompt | self.llm.model | hallucination_parser

prediction = rag_chain.invoke({"context": context, "generation": generation})
LogWriter.info(f"request_id={req_id_cv.get()} EXIT check_hallucination")
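With the grader output parsed into HallucinationCheckResponse, downstream code moves from dict indexing (hallucinations["score"]) to attribute access (hallucinations.score), as the agent_graph.py hunks above show. A short sketch of consuming such a grader result; the object is constructed directly here only for illustration, whereas in the real chain it comes from prompt | llm | hallucination_parser:

from langchain.pydantic_v1 import BaseModel, Field

class HallucinationCheckResponse(BaseModel):
    score: str = Field(description="'yes' if the answer is grounded in the facts, otherwise 'no'.")

# Hypothetical parsed result standing in for the chain's output.
hallucinations = HallucinationCheckResponse(score="yes")

route = "grounded" if hallucinations.score == "yes" else "hallucination"
print(route)   # grounded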
20 changes: 15 additions & 5 deletions app/agent/agent_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from langchain.prompts import PromptTemplate
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import PydanticOutputParser
from app.tools.logwriter import LogWriter
import logging
from app.log import req_id_cv
from langchain.pydantic_v1 import BaseModel, Field

logger = logging.getLogger(__name__)

class QuestionRewriteResponse(BaseModel):
rewritten_question: str = Field(description="The rewritten question.")

class TigerGraphAgentRewriter:
def __init__(self, llm_model):
self.llm = llm_model
@@ -19,16 +22,23 @@ def rewrite_question(self, question: str) -> str:
str: The rewritten question.
"""
LogWriter.info(f"request_id={req_id_cv.get()} ENTRY generate_answer")

rewrite_parser = PydanticOutputParser(pydantic_object=QuestionRewriteResponse)

re_write_prompt = PromptTemplate(
            template="""You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the initial and formulate an improved question. \n
Here is the initial question: \n\n {question}. Improved question with no preamble: \n """,
            for AI agent question answering. Look at the initial question and formulate an improved question. \n
Here is the initial question: \n\n {question}.
Format your response in the following manner {format_instructions}""",
input_variables=["question"],
partial_variables={
"format_instructions": rewrite_parser.get_format_instructions()
}
)


# Chain
question_rewriter = re_write_prompt | self.llm.model | StrOutputParser()
question_rewriter = re_write_prompt | self.llm.model | rewrite_parser

generation = question_rewriter.invoke({"question": question})
LogWriter.info(f"request_id={req_id_cv.get()} EXIT generate_answer")
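The rewriter now returns a QuestionRewriteResponse rather than a bare string, and its prompt is retargeted from vectorstore retrieval to general agent question answering. Since the graph routes rewrite_question back to entry, the rewritten text has to be unwrapped from the structured object before the question is re-routed. A hedged sketch of that hand-off (the state keys mirror the diff, but the unwrapping line is an assumption, not code from this commit):

# `rewriter` stands for a TigerGraphAgentRewriter built with a real LLM wrapper.
state = {"question": "graph vertices?", "question_retry_count": 0}

rewritten = rewriter.rewrite_question(state["question"])   # -> QuestionRewriteResponse
state["question"] = rewritten.rewritten_question           # unwrap the structured field
state["question_retry_count"] += 1                         # routing gives up after too many retries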
3 changes: 2 additions & 1 deletion app/agent/agent_router.py
@@ -49,5 +49,6 @@ def route_question(self, question: str) -> str:
)

question_router = prompt | self.llm.model | router_parser
res = question_router.invoke({"question": question, "v_types": v_types, "e_types": e_types})
LogWriter.info(f"request_id={req_id_cv.get()} EXIT route_question")
return question_router.invoke({"question": question, "v_types": v_types, "e_types": e_types})
return res