
Commit

GML-1822 remove hallucination and usefulness check and cypher support for regression test
Lu Zhou authored and Lu Zhou committed Jul 13, 2024
1 parent 00e0a94 commit 5b9af2c
Showing 1 changed file with 71 additions and 69 deletions.
140 changes: 71 additions & 69 deletions copilot/app/agent/agent_graph.py
@@ -47,7 +47,7 @@ def __init__(
embedding_store,
mq2s_tool,
gen_func_tool,
# cypher_gen_tool=None,
cypher_gen_tool=None,
enable_human_in_loop=False,
q: Q = None,
supportai_retriever="hnsw_overlap",
@@ -59,7 +59,7 @@
self.embedding_store = embedding_store
self.mq2s = mq2s_tool
self.gen_func = gen_func_tool
# self.cypher_gen = cypher_gen_tool
self.cypher_gen = cypher_gen_tool
self.enable_human_in_loop = enable_human_in_loop
self.q = q

@@ -156,37 +156,37 @@ def generate_function(self, state):
state["lookup_source"] = "inquiryai"
return state
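Each node in this graph reads and writes a shared state dictionary. A minimal sketch of the keys visible in this file, assuming a TypedDict-style contract (the AgentState name and exact typings are illustrative, not taken from the project):

from typing import Any, List, Optional, TypedDict

class AgentState(TypedDict, total=False):
    question: str                        # user question, consumed by every node
    context: dict                        # answer/cypher/reasoning on success, error details on failure
    answer: Any                          # object exposing natural_language_response
    error_history: Optional[List[dict]]  # appended to when a step fails
    lookup_source: str                   # e.g. "inquiryai" or "cypher"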

# def generate_cypher(self, state):
# """
# Run the agent cypher generator.
# """
# self.emit_progress("Generating the Cypher to answer your question")
# cypher = self.cypher_gen._run(state["question"])

# response = self.db_connection.gsql(cypher)
# response_lines = response.split("\n")
# try:
# json_str = "\n".join(response_lines[1:])
# response_json = json.loads(json_str)
# state["context"] = {
# "answer": response_json["results"][0],
# "cypher": cypher,
# "reasoning": "The following OpenCypher query was executed to answer the question. {}".format(
# cypher
# ),
# }
# except:
# state["context"] = {
# "error": True,
# "cypher": cypher,
# }
# if state["error_history"] is None:
# state["error_history"] = []
def generate_cypher(self, state):
"""
Run the agent cypher generator.
"""
self.emit_progress("Generating the Cypher to answer your question")
cypher = self.cypher_gen._run(state["question"])

response = self.db_connection.gsql(cypher)
response_lines = response.split("\n")
try:
json_str = "\n".join(response_lines[1:])
response_json = json.loads(json_str)
state["context"] = {
"answer": response_json["results"][0],
"cypher": cypher,
"reasoning": "The following OpenCypher query was executed to answer the question. {}".format(
cypher
),
}
except:
state["context"] = {
"error": True,
"cypher": cypher,
}
if state["error_history"] is None:
state["error_history"] = []

# state["error_history"].append({"error_message": response, "error_step": "generate_cypher"})
state["error_history"].append({"error_message": response, "error_step": "generate_cypher"})

# state["lookup_source"] = "cypher"
# return state
state["lookup_source"] = "cypher"
return state
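The re-enabled Cypher path assumes db_connection.gsql() returns a status line followed by a JSON payload, which is why the first response line is dropped before json.loads. A hedged sketch of that parsing contract with stubbed dependencies (the stub classes and the sample output below are hypothetical):

import json

class StubCypherGen:
    """Hypothetical stand-in for cypher_gen_tool."""
    def _run(self, question: str) -> str:
        return "MATCH (p:Person) RETURN p.name LIMIT 1"

class StubConnection:
    """Hypothetical db_connection: first line is a status message, the rest is JSON."""
    def gsql(self, query: str) -> str:
        return 'Semantic Check OK\n{"results": [{"p.name": "Alice"}]}'

cypher = StubCypherGen()._run("Who is in the graph?")
response = StubConnection().gsql(cypher)
payload = json.loads("\n".join(response.split("\n")[1:]))  # same slicing as generate_cypher
print(payload["results"][0])  # {'p.name': 'Alice'}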

def hnsw_overlap_search(self, state):
"""
@@ -357,46 +357,47 @@ def check_answer_for_hallucinations(self, state):
"""
Run the agent hallucination check.
"""
self.emit_progress("Checking the response is relevant")
step = TigerGraphAgentHallucinationCheck(self.llm_provider)

try:
context_data_str = json.dumps(state["context"]["result"])
# logger.info(f"context_data_str: {context_data_str}")
except (TypeError, ValueError) as e:
logger.error(f"Failed to serialize context to JSON: {e}")
raise ValueError("Invalid context data format. Unable to convert to JSON.")
hallucinations = step.check_hallucination(
state["answer"].natural_language_response, context_data_str
)
logger.info(f"hallucination checker")
logger.info(f"answer: {state['answer'].natural_language_response}")
logger.info(f"context: {context_data_str}")
logger.info(f"if grounded: {hallucinations}")
if hallucinations.score == "yes":
self.emit_progress(DONE)
return "grounded"
else:
return "hallucination"
# self.emit_progress("Checking the response is relevant")
# step = TigerGraphAgentHallucinationCheck(self.llm_provider)

# try:
# context_data_str = json.dumps(state["context"]["result"])
# # logger.info(f"context_data_str: {context_data_str}")
# except (TypeError, ValueError) as e:
# logger.error(f"Failed to serialize context to JSON: {e}")
# raise ValueError("Invalid context data format. Unable to convert to JSON.")
# hallucinations = step.check_hallucination(
# state["answer"].natural_language_response, context_data_str
# )
# logger.info(f"hallucination checker")
# logger.info(f"answer: {state['answer'].natural_language_response}")
# logger.info(f"context: {context_data_str}")
# logger.info(f"if grounded: {hallucinations}")
# if hallucinations.score == "yes":
# self.emit_progress(DONE)
# return "grounded"
# else:
# return "hallucination"
return "grounded"

def check_answer_for_usefulness(self, state):
"""
Run the agent usefulness check.
"""
step = TigerGraphAgentUsefulnessCheck(self.llm_provider)

usefulness = step.check_usefulness(
state["question"], state["answer"].natural_language_response
)
logger.info(f"usefulness checker")
logger.info(f"question: {state['question']}")
logger.info(f"answer: {state['answer'].natural_language_response}")
logger.info(f"if useful: {usefulness}")
if usefulness.score == "yes":
return "useful"
else:
return "not_useful"
# return "useful"
# step = TigerGraphAgentUsefulnessCheck(self.llm_provider)

# usefulness = step.check_usefulness(
# state["question"], state["answer"].natural_language_response
# )
# logger.info(f"usefulness checker")
# logger.info(f"question: {state['question']}")
# logger.info(f"answer: {state['answer'].natural_language_response}")
# logger.info(f"if useful: {usefulness}")
# if usefulness.score == "yes":
# return "useful"
# else:
# return "not_useful"
return "useful"

def check_answer_for_usefulness_and_hallucinations(self, state):
"""
@@ -490,21 +490,22 @@ def create_graph(self):
self.check_state_for_generation_error,
{"error": "rewrite_question", "success": "generate_answer"},
)

if self.supportai_enabled:
self.workflow.add_conditional_edges(
"generate_answer",
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
"grounded": END,
"not_useful": "rewrite_question",
# "not_useful": "rewrite_question",
"inquiryai_not_useful": "supportai",
"supportai_not_useful": "map_question_to_schema",
},
)
else:
self.workflow.add_conditional_edges(
"generate_answer",
"generate_answer",
self.check_answer_for_usefulness_and_hallucinations,
{
"hallucination": "rewrite_question",
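The dictionaries passed to add_conditional_edges route on exactly the strings the check methods return ("grounded", "hallucination", "not_useful", and so on). A self-contained sketch of that pattern, assuming the LangGraph StateGraph API this file appears to build on (the node functions and state schema are illustrative):

from typing import TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    question: str
    answer: str

def generate_answer(state: State) -> State:
    return {**state, "answer": f"echo: {state['question']}"}

def rewrite_question(state: State) -> State:
    return {**state, "question": state["question"] + " (rewritten)"}

def grade(state: State) -> str:
    # Mirrors the short-circuited checks above: always report the answer as grounded.
    return "grounded"

workflow = StateGraph(State)
workflow.add_node("generate_answer", generate_answer)
workflow.add_node("rewrite_question", rewrite_question)
workflow.set_entry_point("generate_answer")
workflow.add_edge("rewrite_question", "generate_answer")
workflow.add_conditional_edges(
    "generate_answer",
    grade,
    {"grounded": END, "hallucination": "rewrite_question"},
)
app = workflow.compile()
print(app.invoke({"question": "hi", "answer": ""}))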
