diff --git a/common/py_schemas/schemas.py b/common/py_schemas/schemas.py
index 139cd86a..ebc6f17a 100644
--- a/common/py_schemas/schemas.py
+++ b/common/py_schemas/schemas.py
@@ -126,9 +126,9 @@ class QueryUpsertRequest(BaseModel):
     query_info: Optional[GSQLQueryInfo]
 
 
-class Role(enum.Enum):
-    system = enum.auto()
-    user = enum.auto()
+class Role(enum.StrEnum):
+    SYSTEM = enum.auto()
+    USER = enum.auto()
 
 
 class Message(BaseModel):
@@ -144,3 +144,13 @@ class Message(BaseModel):
     response_time: Optional[float] = None  # time in fractional seconds
     feedback: Optional[int] = None
     comment: Optional[str] = None
+
+
+class ResponseType(enum.StrEnum):
+    PROGRESS = enum.auto()
+    MESSAGE = enum.auto()
+
+
+class AgentProgress(BaseModel):
+    content: str
+    response_type: ResponseType
diff --git a/copilot-ui/src/actions/ActionProvider.tsx b/copilot-ui/src/actions/ActionProvider.tsx
index 2abe6430..127b0f8b 100644
--- a/copilot-ui/src/actions/ActionProvider.tsx
+++ b/copilot-ui/src/actions/ActionProvider.tsx
@@ -3,7 +3,7 @@ import {createClientMessage} from 'react-chatbot-kit';
 import useWebSocket, {ReadyState} from 'react-use-websocket';
 import Loader from '../components/Loader';
 
-const WS_URL = 'ws://0.0.0.0:8000/ui/Demo_Graph1/chat';
+const WS_URL = "ws://0.0.0.0:8000/ui/Demo_Graph1/chat";
 
 interface ActionProviderProps {
   createChatBotMessage: any;
@@ -30,10 +30,16 @@ export interface Message {
   comment: string;
 }
 
-const ActionProvider: React.FC<ActionProviderProps> = ({createChatBotMessage, setState, children}) => {
+const ActionProvider: React.FC<ActionProviderProps> = ({
+  createChatBotMessage,
+  setState,
+  children,
+}) => {
   const [socketUrl, setSocketUrl] = useState(WS_URL);
-  const [messageHistory, setMessageHistory] = useState<MessageEvent<any>[]>([]);
-  const {sendMessage, lastMessage, readyState} = useWebSocket(socketUrl);
+  const [messageHistory, setMessageHistory] = useState<MessageEvent<any>[]>(
+    [],
+  );
+  const { sendMessage, lastMessage, readyState } = useWebSocket(socketUrl);
 
   // eslint-disable-next-line
   // @ts-ignore
@@ -43,8 +49,8 @@ const ActionProvider: React.FC<ActionProviderProps> = ({createChatBotMessage, se
 
   useWebSocket(WS_URL, {
     onOpen: () => {
-      queryCopilotWs2('dXNlcl8yOlRoaXNpc3RoZWFkbWluITE=');
-      console.log('WebSocket connection established.');
+      queryCopilotWs2(localStorage.getItem("creds")!);
+      console.log("WebSocket connection established.");
     },
   });
diff --git a/copilot-ui/src/components/CustomChatMessage.tsx b/copilot-ui/src/components/CustomChatMessage.tsx
index 824cd707..8c16be66 100755
--- a/copilot-ui/src/components/CustomChatMessage.tsx
+++ b/copilot-ui/src/components/CustomChatMessage.tsx
@@ -99,7 +99,11 @@ export const CustomChatMessage: FC = ({
         message
       ) : (
-

{message.content}

+ {message.response_type === "progress" ? ( +

{message.content}

+ ) : ( +

{message.content}

+ )}
diff --git a/copilot/app/agent/Q.py b/copilot/app/agent/Q.py
new file mode 100644
--- /dev/null
+++ b/copilot/app/agent/Q.py
@@ -0,0 +1,26 @@
+import threading
+
+DONE = "DONE"  # sentinel put on the queue when the agent run is finished
+
+
+class Q:
+    # minimal thread-safe FIFO for passing progress messages to the websocket
+    def __init__(self):
+        self.q = []
+        self.l = threading.Lock()
+
+    def put(self, item):
+        with self.l:
+            self.q.append(item)
+
+    def pop(self):
+        # non-blocking: returns None when the queue is empty
+        if len(self.q) > 0:
+            with self.l:
+                item = self.q[0]
+                self.q = self.q[1:]
+                return item
+
+
+    def clear(self):
+        self.q.clear()
diff --git a/copilot/app/agent/agent.py b/copilot/app/agent/agent.py
index fb6b4f98..d265af33 100644
--- a/copilot/app/agent/agent.py
+++ b/copilot/app/agent/agent.py
@@ -3,14 +3,23 @@
 from typing import Dict, List
 
 from agent.agent_graph import TigerGraphAgentGraph
+from agent.Q import Q
+from fastapi import WebSocket
 from tools import GenerateCypher, GenerateFunction, MapQuestionToSchema
 
 from common.config import embedding_service, embedding_store, llm_config
 from common.embeddings.base_embedding_store import EmbeddingStore
 from common.embeddings.embedding_services import EmbeddingModel
-from common.llm_services import (AWS_SageMaker_Endpoint, AWSBedrock,
-                                 AzureOpenAI, GoogleVertexAI, Groq,
-                                 HuggingFaceEndpoint, Ollama, OpenAI)
+from common.llm_services import (
+    AWS_SageMaker_Endpoint,
+    AWSBedrock,
+    AzureOpenAI,
+    GoogleVertexAI,
+    Groq,
+    HuggingFaceEndpoint,
+    Ollama,
+    OpenAI,
+)
 from common.llm_services.base_llm import LLM_Model
 from common.logs.log import req_id_cv
 from common.logs.logwriter import LogWriter
@@ -43,6 +52,7 @@ def __init__(
         embedding_model: EmbeddingModel,
         embedding_store: EmbeddingStore,
         use_cypher: bool = False,
+        ws=None,
     ):
         self.conn = db_connection
 
@@ -62,26 +72,23 @@ def __init__(
             embedding_store,
         )
 
+        self.cypher_tool = None
         if use_cypher:
             self.cypher_tool = GenerateCypher(self.conn, self.llm)
-            self.agent = TigerGraphAgentGraph(
-                self.llm,
-                self.conn,
-                self.embedding_model,
-                self.embedding_store,
-                self.mq2s,
-                self.gen_func,
-                self.cypher_tool,
-            ).create_graph()
-        else:
-            self.agent = TigerGraphAgentGraph(
-                self.llm,
-                self.conn,
-                self.embedding_model,
-                self.embedding_store,
-                self.mq2s,
-                self.gen_func,
-            ).create_graph()
+
+        self.q = None  # progress queue is only needed when serving a websocket
+        if ws is not None:
+            self.q = Q()
+
+        self.agent = TigerGraphAgentGraph(
+            self.llm,
+            self.conn,
+            self.embedding_model,
+            self.embedding_store,
+            self.mq2s,
+            self.gen_func,
+            cypher_gen_tool=self.cypher_tool,
+            q=self.q,
+        ).create_graph()
 
         logger.debug(f"request_id={req_id_cv.get()} agent initialized")
@@ -141,7 +148,7 @@ def question_for_agent(
         )
 
 
-def make_agent(graphname, conn, use_cypher) -> TigerGraphAgent:
+def make_agent(graphname, conn, use_cypher, ws: WebSocket = None) -> TigerGraphAgent:
     if llm_config["completion_service"]["llm_service"].lower() == "openai":
         llm_service_name = "openai"
         print(llm_config["completion_service"])
@@ -176,7 +183,13 @@ def make_agent(graphname, conn, use_cypher) -> TigerGraphAgent:
     logger.debug(
         f"/{graphname}/query_with_history request_id={req_id_cv.get()} llm_service={llm_service_name} agent created"
     )
+
     agent = TigerGraphAgent(
-        llm_provider, conn, embedding_service, embedding_store, use_cypher=use_cypher
+        llm_provider,
+        conn,
+        embedding_service,
+        embedding_store,
+        use_cypher=use_cypher,
+        ws=ws,
     )
     return agent
diff --git a/copilot/app/agent/agent_graph.py b/copilot/app/agent/agent_graph.py
index d74ca064..daff2325 100644
--- a/copilot/app/agent/agent_graph.py
+++ b/copilot/app/agent/agent_graph.py
@@ -1,32 +1,31 @@
-from typing_extensions import TypedDict
 import json
+import logging
 from typing import Optional
 
-from langgraph.graph import END, StateGraph
 from agent.agent_generation import TigerGraphAgentGenerator
-from agent.agent_router import TigerGraphAgentRouter
 from agent.agent_hallucination_check import TigerGraphAgentHallucinationCheck
-from agent.agent_usefulness_check import TigerGraphAgentUsefulnessCheck
 from agent.agent_rewrite import TigerGraphAgentRewriter
-
-from tools import MapQuestionToSchemaException
-from supportai.retrievers import HNSWOverlapRetriever
-
-from common.py_schemas import (MapQuestionToSchemaResponse,
-                               CoPilotResponse)
-
+from agent.agent_router import TigerGraphAgentRouter
+from agent.agent_usefulness_check import TigerGraphAgentUsefulnessCheck
+from agent.Q import DONE, Q
+from langgraph.graph import END, StateGraph
 from pyTigerGraph.pyTigerGraphException import TigerGraphException
+from supportai.retrievers import HNSWOverlapRetriever
+from tools import MapQuestionToSchemaException
+from typing_extensions import TypedDict
 
-import logging
 from common.logs.log import req_id_cv
+from common.py_schemas import CoPilotResponse, MapQuestionToSchemaResponse
 
 logger = logging.getLogger(__name__)
 
+
 class GraphState(TypedDict):
     """
     Represents the state of the agent graph.
-    
     """
+
     question: str
     generation: str
     context: str
@@ -37,14 +36,18 @@ class GraphState(TypedDict):
 
 
 class TigerGraphAgentGraph:
-    def __init__(self, llm_provider,
-                 db_connection,
-                 embedding_model,
-                 embedding_store,
-                 mq2s_tool,
-                 gen_func_tool,
-                 cypher_gen_tool = None,
-                 enable_human_in_loop=False):
+    def __init__(
+        self,
+        llm_provider,
+        db_connection,
+        embedding_model,
+        embedding_store,
+        mq2s_tool,
+        gen_func_tool,
+        cypher_gen_tool=None,
+        enable_human_in_loop=False,
+        q: Q = None,
+    ):
         self.workflow = StateGraph(GraphState)
         self.llm_provider = llm_provider
         self.db_connection = db_connection
@@ -54,6 +57,7 @@ def __init__(self, llm_provider,
         self.gen_func = gen_func_tool
         self.cypher_gen = cypher_gen_tool
         self.enable_human_in_loop = enable_human_in_loop
+        self.q = q
 
         self.supportai_enabled = True
         try:
@@ -62,6 +66,10 @@ def __init__(self, llm_provider,
             logger.info("HNSW_Overlap not found in the graph. Disabling supportai.")
             self.supportai_enabled = False
 
+    def emit_progress(self, msg):
+        if self.q is not None:
+            self.q.put(msg)
+
     def route_question(self, state):
         """
         Run the agent router.
@@ -72,150 +80,185 @@ def route_question(self, state):
         elif state["question_retry_count"] > 2:
             return "apologize"
         state["question_retry_count"] += 1
-        logger.debug_pii(f"request_id={req_id_cv.get()} Routing question: {state['question']}")
+        logger.debug_pii(
+            f"request_id={req_id_cv.get()} Routing question: {state['question']}"
+        )
         step = TigerGraphAgentRouter(self.llm_provider, self.db_connection)
         if self.supportai_enabled:
-            source = step.route_question(state['question'])
-            logger.debug_pii(f"request_id={req_id_cv.get()} Routing question to: {source}")
+            source = step.route_question(state["question"])
+            logger.debug_pii(
+                f"request_id={req_id_cv.get()} Routing question to: {source}"
+            )
             if source.datasource == "vectorstore":
                 return "supportai_lookup"
            elif source.datasource == "functions":
                 return "inquiryai_lookup"
         else:
             return "inquiryai_lookup"
-
 
     def apologize(self, state):
         """
         Apologize for not being able to answer the question.
         """
-        state["answer"] = CoPilotResponse(natural_language_response="I'm sorry, I don't know the answer to that question. Please try rephrasing your question.",
-                                          answered_question=False,
-                                          response_type="error",
-                                          query_sources={"error": "Question could not be routed to a datasource."})
+        self.emit_progress(DONE)
+        state["answer"] = CoPilotResponse(
+            natural_language_response="I'm sorry, I don't know the answer to that question. Please try rephrasing your question.",
+            answered_question=False,
+            response_type="error",
+            query_sources={"error": "Question could not be routed to a datasource."},
+        )
         return state
-
 
     def map_question_to_schema(self, state):
         """
         Run the agent schema mapping.
         """
+        self.emit_progress("Mapping your question to the graph's schema")
         try:
-            step = self.mq2s._run(state['question'])
+            step = self.mq2s._run(state["question"])
             state["schema_mapping"] = step
             return state
         except MapQuestionToSchemaException:
             return "failure"
-
-
+
     def generate_function(self, state):
         """
         Run the agent function generator.
         """
+        self.emit_progress("Generating the code to answer your question")
         try:
-            step = self.gen_func._run(state['question'],
-                                      state["schema_mapping"].target_vertex_types,
-                                      state["schema_mapping"].target_vertex_attributes,
-                                      state["schema_mapping"].target_vertex_ids,
-                                      state["schema_mapping"].target_edge_types,
-                                      state["schema_mapping"].target_edge_attributes)
+            step = self.gen_func._run(
+                state["question"],
+                state["schema_mapping"].target_vertex_types,
+                state["schema_mapping"].target_vertex_attributes,
+                state["schema_mapping"].target_vertex_ids,
+                state["schema_mapping"].target_edge_types,
+                state["schema_mapping"].target_edge_attributes,
+            )
             state["context"] = step
         except Exception as e:
             state["context"] = {"error": str(e)}
 
         state["lookup_source"] = "inquiryai"
         return state
 
     def generate_cypher(self, state):
         """
         Run the agent cypher generator.
         """
-        cypher = self.cypher_gen._run(state['question'])
+        self.emit_progress("Generating the Cypher to answer your question")
+        cypher = self.cypher_gen._run(state["question"])
 
         response = self.db_connection.gsql(cypher)
-        response_lines = response.split('\n')
+        response_lines = response.split("\n")
         try:
-            json_str = '\n'.join(response_lines[1:])
+            json_str = "\n".join(response_lines[1:])
             response_json = json.loads(json_str)
             state["context"] = {"answer": response_json["results"][0], "cypher": cypher}
         except:
-            state["context"] = {"error": True, "error_message": response, "cypher": cypher}
+            state["context"] = {
+                "error": True,
+                "error_message": response,
+                "cypher": cypher,
+            }
         state["lookup_source"] = "cypher"
         return state
 
     def hnsw_overlap_search(self, state):
         """
         Run the agent overlap search.
         """
-        retriever = HNSWOverlapRetriever(self.embedding_model,
-                                         self.embedding_store,
-                                         self.llm_provider.model,
-                                         self.db_connection)
-        step = retriever.search(state['question'],
-                                indices=["Document", "DocumentChunk",
-                                         "Entity", "Relationship"],
-                                num_seen_min=2)
+        self.emit_progress("Searching the graph for relevant information")
+        retriever = HNSWOverlapRetriever(
+            self.embedding_model,
+            self.embedding_store,
+            self.llm_provider.model,
+            self.db_connection,
+        )
+        step = retriever.search(
+            state["question"],
+            indices=["Document", "DocumentChunk", "Entity", "Relationship"],
+            num_seen_min=2,
+        )
 
         state["context"] = step[0]
         state["lookup_source"] = "supportai"
         return state
-
-
+
     def generate_answer(self, state):
         """
         Run the agent generator.
""" + self.emit_progress("Putting the pieces together") step = TigerGraphAgentGenerator(self.llm_provider) - logger.debug_pii(f"request_id={req_id_cv.get()} Generating answer for question: {state['question']}") + logger.debug_pii( + f"request_id={req_id_cv.get()} Generating answer for question: {state['question']}" + ) if state["lookup_source"] == "supportai": - answer = step.generate_answer(state['question'], state["context"]) + answer = step.generate_answer(state["question"], state["context"]) elif state["lookup_source"] == "inquiryai": - answer = step.generate_answer(state['question'], state["context"]["result"]) + answer = step.generate_answer(state["question"], state["context"]["result"]) elif state["lookup_source"] == "cypher": - answer = step.generate_answer(state['question'], state["context"]["answer"]) - logger.debug_pii(f"request_id={req_id_cv.get()} Generated answer: {answer.generated_answer}") + answer = step.generate_answer(state["question"], state["context"]["answer"]) + logger.debug_pii( + f"request_id={req_id_cv.get()} Generated answer: {answer.generated_answer}" + ) try: - resp = CoPilotResponse(natural_language_response=answer.generated_answer, - answered_question=True, - response_type=state["lookup_source"], - query_sources=state["context"]) + resp = CoPilotResponse( + natural_language_response=answer.generated_answer, + answered_question=True, + response_type=state["lookup_source"], + query_sources=state["context"], + ) except Exception as e: - resp = CoPilotResponse(natural_language_response="I'm sorry, I don't know the answer to that question.", - answered_question=False, - response_type=state["lookup_source"], - query_sources={"error": str(e)}) + resp = CoPilotResponse( + natural_language_response="I'm sorry, I don't know the answer to that question.", + answered_question=False, + response_type=state["lookup_source"], + query_sources={"error": str(e)}, + ) state["answer"] = resp - + return state - + def rewrite_question(self, state): """ Run the agent question rewriter. """ + self.emit_progress("Thinking") step = TigerGraphAgentRewriter(self.llm_provider) state["question"] = step.rewrite_question(state["question"]) return state - + def check_answer_for_hallucinations(self, state): """ Run the agent hallucination check. """ + self.emit_progress("Checking the response for mistakes") step = TigerGraphAgentHallucinationCheck(self.llm_provider) - hallucinations = step.check_hallucination(state["answer"].natural_language_response, state["context"]) + hallucinations = step.check_hallucination( + state["answer"].natural_language_response, state["context"] + ) if hallucinations.score == "yes": + self.emit_progress(DONE) return "grounded" else: return "hallucination" - + def check_answer_for_usefulness(self, state): """ Run the agent usefulness check. """ + # self.emit_progress("check usefulness") step = TigerGraphAgentUsefulnessCheck(self.llm_provider) - usefulness = step.check_usefulness(state["question"], state["answer"].natural_language_response) + usefulness = step.check_usefulness( + state["question"], state["answer"].natural_language_response + ) if usefulness.score == "yes": return "useful" else: return "not_useful" - + def check_answer_for_usefulness_and_hallucinations(self, state): """ Run the agent usefulness and hallucination check. 
@@ -226,6 +269,7 @@ def check_answer_for_usefulness_and_hallucinations(self, state):
         else:
             useful = self.check_answer_for_usefulness(state)
             if useful == "useful":
+                self.emit_progress(DONE)
                 return "grounded"
             else:
                 if state["lookup_source"] == "supportai":
@@ -234,7 +278,7 @@ def check_answer_for_usefulness_and_hallucinations(self, state):
                     return "inquiryai_not_useful"
                 elif state["lookup_source"] == "cypher":
                     return "cypher_not_useful"
-
+
     def check_state_for_generation_error(self, state):
         """
         Check if the state has an error.
@@ -263,72 +307,65 @@ def create_graph(self):
             self.workflow.add_conditional_edges(
                 "generate_function",
                 self.check_state_for_generation_error,
-                {
-                    "error": "generate_cypher",
-                    "success": "generate_answer"
-                }
+                {"error": "generate_cypher", "success": "generate_answer"},
+            )
+            self.workflow.add_conditional_edges(
+                "generate_cypher",
+                self.check_state_for_generation_error,
+                {"error": "apologize", "success": "generate_answer"},
             )
-            self.workflow.add_conditional_edges("generate_cypher",
-                                                self.check_state_for_generation_error,
-                                                {
-                                                    "error": "apologize",
-                                                    "success": "generate_answer"
-                                                })
 
             if self.supportai_enabled:
                 self.workflow.add_conditional_edges(
                     "generate_answer",
                     self.check_answer_for_usefulness_and_hallucinations,
-                    {
-                        "hallucination": "rewrite_question",
-                        "grounded": END,
-                        "inquiryai_not_useful": "generate_cypher",
-                        "cypher_not_useful": "hnsw_overlap_search",
-                        "supportai_not_useful": "map_question_to_schema"
-                    }
-                )
+                    {
+                        "hallucination": "rewrite_question",
+                        "grounded": END,
+                        "inquiryai_not_useful": "generate_cypher",
+                        "cypher_not_useful": "hnsw_overlap_search",
+                        "supportai_not_useful": "map_question_to_schema",
+                    },
+                )
             else:
                 self.workflow.add_conditional_edges(
                     "generate_answer",
                     self.check_answer_for_usefulness_and_hallucinations,
-                    {
-                        "hallucination": "rewrite_question",
-                        "grounded": END,
-                        "inquiryai_not_useful": "generate_cypher",
-                        "cypher_not_useful": "apologize"
-                    }
-                )
+                    {
+                        "hallucination": "rewrite_question",
+                        "grounded": END,
+                        "inquiryai_not_useful": "generate_cypher",
+                        "cypher_not_useful": "apologize",
+                    },
+                )
         else:
             self.workflow.add_conditional_edges(
                 "generate_function",
                 self.check_state_for_generation_error,
-                {
-                    "error": "rewrite_question",
-                    "success": "generate_answer"
-                }
+                {"error": "rewrite_question", "success": "generate_answer"},
             )
             if self.supportai_enabled:
                 self.workflow.add_conditional_edges(
                     "generate_answer",
                     self.check_answer_for_usefulness_and_hallucinations,
-                    {
-                        "hallucination": "rewrite_question",
-                        "grounded": END,
-                        "not_useful": "rewrite_question",
-                        "inquiryai_not_useful": "hnsw_overlap_search",
-                        "supportai_not_useful": "map_question_to_schema"
-                    }
+                    {
+                        "hallucination": "rewrite_question",
+                        "grounded": END,
+                        "not_useful": "rewrite_question",
+                        "inquiryai_not_useful": "hnsw_overlap_search",
+                        "supportai_not_useful": "map_question_to_schema",
+                    },
                 )
             else:
                 self.workflow.add_conditional_edges(
                     "generate_answer",
                     self.check_answer_for_usefulness_and_hallucinations,
-                    {
-                        "hallucination": "rewrite_question",
-                        "grounded": END,
-                        "not_useful": "rewrite_question",
-                        "inquiryai_not_useful": "apologize",
-                        "supportai_not_useful": "map_question_to_schema"
-                    }
+                    {
+                        "hallucination": "rewrite_question",
+                        "grounded": END,
+                        "not_useful": "rewrite_question",
+                        "inquiryai_not_useful": "apologize",
+                        "supportai_not_useful": "map_question_to_schema",
+                    },
                 )
 
         if self.supportai_enabled:
@@ -338,8 +375,8 @@ def create_graph(self):
                 {
                     "supportai_lookup": "hnsw_overlap_search",
                     "inquiryai_lookup": "map_question_to_schema",
"map_question_to_schema", - "apologize": "apologize" - } + "apologize": "apologize", + }, ) else: self.workflow.add_conditional_edges( @@ -347,8 +384,8 @@ def create_graph(self): self.route_question, { "inquiryai_lookup": "map_question_to_schema", - "apologize": "apologize" - } + "apologize": "apologize", + }, ) self.workflow.add_edge("map_question_to_schema", "generate_function") @@ -356,9 +393,6 @@ def create_graph(self): self.workflow.add_edge("hnsw_overlap_search", "generate_answer") self.workflow.add_edge("rewrite_question", "entry") self.workflow.add_edge("apologize", END) - app = self.workflow.compile() return app - - diff --git a/copilot/app/routers/ui.py b/copilot/app/routers/ui.py index 6b04ac94..fd9537ef 100644 --- a/copilot/app/routers/ui.py +++ b/copilot/app/routers/ui.py @@ -1,3 +1,4 @@ +import asyncio import base64 import logging import os @@ -7,9 +8,11 @@ import uuid from typing import Annotated +import asyncer import httpx import requests from agent.agent import TigerGraphAgent, make_agent +from agent.Q import DONE from fastapi import APIRouter, Depends, HTTPException, WebSocket, status from fastapi.security import HTTPBasic, HTTPBasicCredentials from pyTigerGraph import TigerGraphConnection @@ -20,7 +23,7 @@ from common.logs.log import req_id_cv from common.logs.logwriter import LogWriter from common.metrics.prometheus_metrics import metrics as pmetrics -from common.py_schemas.schemas import CoPilotResponse, Message, Role +from common.py_schemas.schemas import AgentProgess, CoPilotResponse, Message, ResponseType, Role logger = logging.getLogger(__name__) @@ -105,19 +108,46 @@ def add_feedback( return {"message": "feedback saved", "message_id": message.message_id} -def run_agent( +async def emit_progress(agent: TigerGraphAgent, ws: WebSocket): + # loop on q until done token emit events through ws + msg = None + pop = asyncer.asyncify(agent.q.pop) + + while msg != DONE: + msg = await pop() + if msg is not None and msg != DONE: + message = AgentProgess( + content=msg, + response_type=ResponseType.PROGRESS, + ) + await ws.send_text(message.model_dump_json()) + + +async def run_agent( agent: TigerGraphAgent, data: str, conversation_history: list[dict[str, str]], graphname, + ws: WebSocket, ) -> CoPilotResponse: resp = CoPilotResponse( natural_language_response="", answered_question=False, response_type="inquiryai" ) + a_question_for_agent = asyncer.asyncify(agent.question_for_agent) try: - # TODO: make num mesages in history configureable - resp = agent.question_for_agent(data, conversation_history[-4:]) + # start agent and sample from Q to emit progress + + async with asyncio.TaskGroup() as tg: + # run agent + a_resp = tg.create_task( + # TODO: make num mesages in history configureable + a_question_for_agent(data, conversation_history[-4:]) + ) + # sample Q and emit events + tg.create_task(emit_progress(agent, ws)) pmetrics.llm_success_response_total.labels(embedding_service.model_name).inc() + resp = a_resp.result() + agent.q.clear() except MapQuestionToSchemaException: resp.natural_language_response = ( @@ -190,7 +220,9 @@ async def chat( # create convo_id conversation_history = [] # TODO: go get history instead of starting from 0 convo_id = str(uuid.uuid4()) - agent = make_agent(graphname, conn, use_cypher) + agent = make_agent(graphname, conn, use_cypher, ws=websocket) + + # from anyio import sleep as asleep prev_id = None while True: @@ -203,7 +235,7 @@ async def chat( parent_id=prev_id, model=llm_config["model_name"], content=data, - role=Role.user.name, + 
+            role=Role.USER,
         )
         # save message
         await write_message_to_history(message, usr_auth)
 
@@ -211,7 +243,7 @@ async def chat(
         # generate response and keep track of response time
         start = time.monotonic()
-        resp = run_agent(agent, data, conversation_history, graphname)
+        resp = await run_agent(agent, data, conversation_history, graphname, websocket)
         elapsed = time.monotonic() - start
 
         # save message
@@ -221,7 +253,7 @@ async def chat(
             parent_id=prev_id,
             model=llm_config["model_name"],
             content=resp.natural_language_response,
-            role=Role.system.name,
+            role=Role.SYSTEM,
             response_time=elapsed,
             answered_question=resp.answered_question,
             response_type=resp.response_type,
diff --git a/copilot/requirements.txt b/copilot/requirements.txt
index 6cfb4557..b678a702 100644
--- a/copilot/requirements.txt
+++ b/copilot/requirements.txt
@@ -6,6 +6,7 @@ appdirs==1.4.4
 argon2-cffi==23.1.0
 argon2-cffi-bindings==21.2.0
 async-timeout==4.0.3
+asyncer==0.0.7
 attrs==23.1.0
 azure-core==1.30.1
 azure-storage-blob==12.19.1