diff --git a/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql b/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql
new file mode 100644
index 00000000..2d6ef9b0
--- /dev/null
+++ b/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql
@@ -0,0 +1,7 @@
+CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
+  comms = {Community.*};
+
+  selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;
+
+  PRINT selected_comms;
+}
\ No newline at end of file
diff --git a/common/gsql/supportai/retrievers/HNSW_Search_Content.gsql b/common/gsql/supportai/retrievers/HNSW_Search_Content.gsql
index 9de116b7..a0f7d009 100644
--- a/common/gsql/supportai/retrievers/HNSW_Search_Content.gsql
+++ b/common/gsql/supportai/retrievers/HNSW_Search_Content.gsql
@@ -24,6 +24,8 @@ CREATE DISTRIBUTED QUERY HNSW_Search_Content(STRING v_type, STRING milvus_host,
         POST-ACCUM
             IF s.type == "Relationship" OR s.type == "Entity" OR s.type == "Concept" THEN
                 @@final_retrieval += (s.id -> s.definition)
+            ELSE IF s.type == "Community" THEN
+                @@final_retrieval += (s.id -> s.description)
             END;
 
     PRINT @@final_retrieval;
diff --git a/common/logs/log.py b/common/logs/log.py
index b4f11b77..0f974d77 100644
--- a/common/logs/log.py
+++ b/common/logs/log.py
@@ -64,7 +64,7 @@ def logToRoot(message, *args, **kwargs):
 addLoggingLevel("DEBUG_PII", logging.DEBUG - 5)
 
 log_config = get_log_config()
-LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
+LOGLEVEL = os.environ.get("LOGLEVEL", logging.INFO)
 log_directory = log_config.get("log_file_path", "/tmp/logs")
 os.makedirs(log_directory, exist_ok=True)
diff --git a/common/requirements.txt b/common/requirements.txt
index af45c357..86bdc50c 100644
--- a/common/requirements.txt
+++ b/common/requirements.txt
@@ -1,3 +1,4 @@
+aiochannel==1.2.1
 aiohappyeyeballs==2.3.5
 aiohttp==3.10.3
 aiosignal==1.3.1
@@ -100,7 +101,7 @@ minio==7.2.7
 multidict==6.0.5
 mypy-extensions==1.0.0
 nest-asyncio==1.6.0
-nltk==3.8.2
+nltk==3.8.1
 numpy==1.26.4
 openai==1.40.6
 ordered-set==4.1.0
diff --git a/copilot/app/routers/supportai.py b/copilot/app/routers/supportai.py
index 0eff3c41..fac6601d 100644
--- a/copilot/app/routers/supportai.py
+++ b/copilot/app/routers/supportai.py
@@ -16,6 +16,7 @@
     HNSWOverlapRetriever,
     HNSWRetriever,
     HNSWSiblingRetriever,
+    GraphRAG
 )
 
 from common.config import (
@@ -175,7 +176,11 @@ def search(
             embedding_service, embedding_store, get_llm_service(llm_config), conn
         )
         res = retriever.search(query.question, query.method_params["top_k"])
-
+    elif query.method.lower() == "graphrag":
+        retriever = GraphRAG(
+            embedding_service, embedding_store, get_llm_service(llm_config), conn
+        )
+        res = retriever.search(query.question, query.method_params["community_level"])
     return res
 
 
@@ -232,6 +237,16 @@ def answer_question(
             embedding_service, embedding_store, get_llm_service(llm_config), conn
         )
         res = retriever.retrieve_answer(query.question, query.method_params["top_k"])
+
+    elif query.method.lower() == "graphrag":
+        retriever = GraphRAG(
+            embedding_service, embedding_store, get_llm_service(llm_config), conn
+        )
+        res = retriever.retrieve_answer(
+            query.question,
+            query.method_params["community_level"],
+            query.method_params["top_k_answer_candidates"]
+        )
     else:
         raise Exception("Method not implemented")
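For orientation, here is a minimal sketch of how a client might exercise the new `graphrag` method against the search route above. Only the payload shape comes from the router code; the host, route prefix, graph name, and credentials are assumptions to adapt for your deployment.

```python
# Hypothetical request against the new "graphrag" search method.
# Host, route, graph name, and auth are placeholders, not the documented API.
import requests

payload = {
    "question": "What are the main themes across these documents?",
    "method": "graphrag",
    "method_params": {"community_level": 2},
}
resp = requests.post(
    "http://localhost:8000/MyGraph/search",  # assumed route
    json=payload,
    auth=("tigergraph", "tigergraph"),  # assumed credentials
)
print(resp.json())
```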
diff --git a/copilot/app/supportai/retrievers/GraphRAG.py b/copilot/app/supportai/retrievers/GraphRAG.py
new file mode 100644
index 00000000..442f8fcb
--- /dev/null
+++ b/copilot/app/supportai/retrievers/GraphRAG.py
@@ -0,0 +1,88 @@
+from supportai.retrievers import BaseRetriever
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+from common.metrics.tg_proxy import TigerGraphConnectionProxy
+
+from langchain_core.output_parsers import PydanticOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_core.pydantic_v1 import BaseModel, Field, validator
+
+from common.llm_services import LLM_Model
+
+
+class CommunityAnswer(BaseModel):
+    answer: str = Field(description="The answer to the question, based off of the context provided.")
+    quality_score: int = Field(description="The quality of the answer, based on how well it answers the question. Rate the answer from 0 (poor) to 100 (excellent).")
+
+output_parser = PydanticOutputParser(pydantic_object=CommunityAnswer)
+
+ANSWER_PROMPT = PromptTemplate(template = """
+You are a helpful assistant responsible for generating an answer to the question below using the data provided.
+Include a quality score for the answer, based on how well it answers the question. The quality score should be between 0 (poor) and 100 (excellent).
+
+Question: {question}
+Context: {context}
+
+{format_instructions}
+""",
+input_variables=["question", "context"],
+partial_variables={"format_instructions": output_parser.get_format_instructions()}
+)
+
+
+class GraphRAG(BaseRetriever):
+    def __init__(
+        self,
+        embedding_service,
+        embedding_store,
+        llm_service: LLM_Model,
+        connection: TigerGraphConnectionProxy,
+    ):
+        super().__init__(embedding_service, embedding_store, llm_service, connection)
+        self._check_query_install("GraphRAG_CommunityRetriever")
+
+    def search(self, question, community_level: int):
+        res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
+        return res
+
+    async def _generate_candidate(self, question, context):
+        model = self.llm_service.model
+
+
+        chain = ANSWER_PROMPT | model | output_parser
+
+        answer = await chain.ainvoke(
+            {
+                "question": question,
+                "context": context,
+            }
+        )
+        return answer
+
+    def gather_candidates(self, question, context):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        tasks = [self._generate_candidate(question, c) for c in context]
+        res = loop.run_until_complete(asyncio.gather(*tasks))
+        loop.close()
+        return res
+
+    def retrieve_answer(self,
+                        question: str,
+                        community_level: int,
+                        top_k_answer_candidates: int = 1):
+        retrieved = self.search(question, community_level)
+        context = [x["attributes"] for x in retrieved[0]["selected_comms"]]
+
+        with ThreadPoolExecutor() as executor:
+            res = executor.submit(self.gather_candidates, question, context).result()
+
+        # sort list by quality score
+        res.sort(key=lambda x: x.quality_score, reverse=True)
+
+        new_context = [{"candidate_answer": x.answer,
+                        "score": x.quality_score} for x in res[:top_k_answer_candidates]]
+
+        return self._generate_response(question, new_context)
diff --git a/copilot/app/supportai/retrievers/__init__.py b/copilot/app/supportai/retrievers/__init__.py
index 9bdcefa4..aa6cd324 100644
--- a/copilot/app/supportai/retrievers/__init__.py
+++ b/copilot/app/supportai/retrievers/__init__.py
@@ -3,3 +3,4 @@
 from .HNSWOverlapRetriever import HNSWOverlapRetriever
 from .HNSWSiblingRetriever import HNSWSiblingRetriever
 from .EntityRelationshipRetriever import EntityRelationshipRetriever
+from .GraphRAG import GraphRAG
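In the retriever above, `gather_candidates` fans one LLM call out per community on a private event loop, and `retrieve_answer` runs that loop on a worker thread so the synchronous retriever API can be called from async server code. A self-contained sketch of the same pattern, with a sleep standing in for the LLM chain:

```python
# Sketch of the private-loop-on-a-worker-thread pattern used by GraphRAG.
# The _candidate coroutine is a stub for chain.ainvoke(...).
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def _candidate(context: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for an LLM call
    return f"answer from {context}"


def gather_candidates(contexts: list[str]) -> list[str]:
    # A private loop keeps this callable from sync code even when the
    # calling process already runs an event loop on another thread.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            asyncio.gather(*(_candidate(c) for c in contexts))
        )
    finally:
        loop.close()


with ThreadPoolExecutor() as pool:
    print(pool.submit(gather_candidates, ["comm-1", "comm-2"]).result())
```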
diff --git a/copilot/docs/notebooks/SupportAIDemo.ipynb b/copilot/docs/notebooks/SupportAIDemo.ipynb
index 29519463..c5a11c34 100644
--- a/copilot/docs/notebooks/SupportAIDemo.ipynb
+++ b/copilot/docs/notebooks/SupportAIDemo.ipynb
@@ -159,7 +159,7 @@
     }
    ],
    "source": [
-    "conn.ai.forceConsistencyUpdate()"
+    "conn.ai.forceConsistencyUpdate(method=\"graphrag\")"
    ]
   },
   {
@@ -546,7 +546,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.11.9"
+  "version": "3.9.16"
  }
 },
 "nbformat": 4,
diff --git a/copilot/requirements.txt b/copilot/requirements.txt
index 4a5ac3d1..e057eb90 100644
--- a/copilot/requirements.txt
+++ b/copilot/requirements.txt
@@ -101,7 +101,7 @@ minio==7.2.7
 multidict==6.0.5
 mypy-extensions==1.0.0
 nest-asyncio==1.6.0
-nltk==3.8.2
+nltk==3.8.1
 numpy==1.26.4
 openai==1.40.6
 ordered-set==4.1.0
diff --git a/eventual-consistency-service/.dockerignore b/eventual-consistency-service/.dockerignore
index 5b04df42..2bf1da45 100644
--- a/eventual-consistency-service/.dockerignore
+++ b/eventual-consistency-service/.dockerignore
@@ -3,3 +3,4 @@ Dockerfile.tests
 docs
 tests
 udfs
+__pycache__
\ No newline at end of file
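The recurring edit across the next three files wraps each TigerGraph REST call in a shared `asyncio.Semaphore` (`tg_sem`, defined in `util.py` below with a limit of 100) so only a bounded number of requests are in flight at once. A self-contained sketch of that throttling pattern:

```python
# Sketch of the semaphore-throttled request pattern added throughout the diff.
# The URLs are placeholders; only the `async with tg_sem:` shape mirrors the PR.
import asyncio

import httpx

tg_sem = asyncio.Semaphore(100)  # shared cap on in-flight requests


async def fetch(client: httpx.AsyncClient, url: str) -> httpx.Response:
    # Hold the semaphore only around the request itself so slow response
    # processing elsewhere doesn't occupy a slot.
    async with tg_sem:
        res = await client.get(url)
    res.raise_for_status()
    return res


async def main() -> None:
    urls = [f"https://example.com/{i}" for i in range(10)]  # placeholders
    async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client:
        await asyncio.gather(*(fetch(client, u) for u in urls))


asyncio.run(main())
```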
diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py
index ecca36b2..29f03dce 100644
--- a/eventual-consistency-service/app/graphrag/graph_rag.py
+++ b/eventual-consistency-service/app/graphrag/graph_rag.py
@@ -12,6 +12,7 @@
     init,
     make_headers,
     stream_ids,
+    tg_sem,
 )
 from pyTigerGraph import TigerGraphConnection
 
@@ -44,11 +45,12 @@ async def stream_docs(
         for d in doc_ids["ids"]:
             try:
-                res = await client.get(
-                    f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
-                    params={"doc": d},
-                    headers=headers,
-                )
+                async with tg_sem:
+                    res = await client.get(
+                        f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
+                        params={"doc": d},
+                        headers=headers,
+                    )
                 if res.status_code != 200:
                     # continue to the next doc.
                     # This doc will not be marked as processed, so the ecc will process it eventually.
@@ -85,7 +87,7 @@ async def chunk_docs(
             txt = content["attributes"]["text"]
             # send the document to be embedded
             logger.info("chunk writes to extract")
-            await embed_chan.put((v_id, txt, "Document"))
+            # await embed_chan.put((v_id, txt, "Document"))
 
             task = grp.create_task(
                 workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan)
@@ -221,10 +223,11 @@ async def resolve_entities(
     # Copy RELATIONSHIP edges to RESOLVED_RELATIONSHIP
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=http_timeout) as client:
-        res = await client.get(
-            f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/",
-            headers=headers,
-        )
+        async with tg_sem:
+            res = await client.get(
+                f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/",
+                headers=headers,
+            )
         res.raise_for_status()
 
@@ -236,19 +239,21 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
     logger.info("Initializing Communities (first louvain pass)")
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=None) as client:
-        res = await client.get(
-            f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init",
-            params={"n_batches": 1},
-            headers=headers,
-        )
+        async with tg_sem:
+            res = await client.get(
+                f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init",
+                params={"n_batches": 1},
+                headers=headers,
+            )
         res.raise_for_status()
     # get the modularity
     async with httpx.AsyncClient(timeout=None) as client:
-        res = await client.get(
-            f"{conn.restppUrl}/query/{conn.graphname}/modularity",
-            params={"iteration": 1, "batch_num": 1},
-            headers=headers,
-        )
+        async with tg_sem:
+            res = await client.get(
+                f"{conn.restppUrl}/query/{conn.graphname}/modularity",
+                params={"iteration": 1, "batch_num": 1},
+                headers=headers,
+            )
         res.raise_for_status()
     mod = res.json()["results"][0]["mod"]
     logger.info(f"****mod pass 1: {mod}")
@@ -263,21 +268,23 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
         logger.info(f"Running louvain on Communities (iteration: {i})")
         # louvain pass
         async with httpx.AsyncClient(timeout=None) as client:
-            res = await client.get(
-                f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities",
-                params={"n_batches": 1, "iteration": i},
-                headers=headers,
-            )
+            async with tg_sem:
+                res = await client.get(
+                    f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities",
+                    params={"n_batches": 1, "iteration": i},
+                    headers=headers,
+                )
             res.raise_for_status()
         # get the modularity
         async with httpx.AsyncClient(timeout=None) as client:
-            res = await client.get(
-                f"{conn.restppUrl}/query/{conn.graphname}/modularity",
-                params={"iteration": i + 1, "batch_num": 1},
-                headers=headers,
-            )
+            async with tg_sem:
+                res = await client.get(
+                    f"{conn.restppUrl}/query/{conn.graphname}/modularity",
+                    params={"iteration": i + 1, "batch_num": 1},
+                    headers=headers,
+                )
             res.raise_for_status()
         mod = res.json()["results"][0]["mod"]
         logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
@@ -307,11 +314,12 @@ async def stream_communities(
     # async for i in community_chan:
     # get the community from that layer
     async with httpx.AsyncClient(timeout=None) as client:
-        resp = await client.get(
-            f"{conn.restppUrl}/query/{conn.graphname}/stream_community",
-            params={"iter": i},
-            headers=headers,
-        )
+        async with tg_sem:
+            resp = await client.get(
+                f"{conn.restppUrl}/query/{conn.graphname}/stream_community",
+                params={"iter": i},
+                headers=headers,
+            )
         resp.raise_for_status()
     comms = resp.json()["results"][0]["Comms"]
 
@@ -345,7 +353,7 @@ async def summarize_communities(
     embed_chan.close()
 
 
-async def run(graphname: str, conn: TigerGraphConnection):
+async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
     """
     Set up GraphRAG:
         - Install necessary queries.
@@ -362,8 +370,8 @@ async def run(graphname: str, conn: TigerGraphConnection):
     init_start = time.perf_counter()
 
     doc_process_switch = True
     entity_resolution_switch = True
     community_detection_switch = True
     if doc_process_switch:
         logger.info("Doc Processing Start")
         docs_chan = Channel(1)
@@ -378,7 +386,7 @@ async def run(graphname: str, conn: TigerGraphConnection):
                 chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
             )
             # upsert chunks
             grp.create_task(upsert(upsert_chan))
             # embed
             grp.create_task(embed(embed_chan, index_stores, graphname))
             # extract entities
@@ -437,5 +445,7 @@ async def run(graphname: str, conn: TigerGraphConnection):
     end = time.perf_counter()
     logger.info(f"DONE. graphrag system initializer dT: {init_end-init_start}")
     logger.info(f"DONE. graphrag entity resolution dT: {entity_end-entity_start}")
-    logger.info(f"DONE. graphrag community initializer dT: {community_end-community_start}")
+    logger.info(
+        f"DONE. graphrag community initializer dT: {community_end-community_start}"
+    )
     logger.info(f"DONE. graphrag.run() total time elaplsed: {end-init_start}")
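`communities()` above runs Louvain passes until modularity stops improving, building one level of the community hierarchy per pass. A toy sketch of that control flow; the fake modularity values and the 1e-4 convergence test are assumptions (the hunks only show the loop logging `abs(prev_mod - mod)` each pass):

```python
# Toy version of the iterate-until-modularity-converges loop in communities().
import asyncio

FAKE_MODULARITY = [0.41, 0.52, 0.555, 0.5551]  # stand-in query results


async def modularity(iteration: int) -> float:
    # Stand-in for the semaphore-guarded /query/<graph>/modularity call.
    await asyncio.sleep(0)
    return FAKE_MODULARITY[min(iteration - 1, len(FAKE_MODULARITY) - 1)]


async def build_hierarchy() -> int:
    prev_mod = await modularity(1)  # after graphrag_louvain_init
    i = 1
    while True:
        i += 1  # one graphrag_louvain_communities pass per iteration
        mod = await modularity(i)
        if abs(prev_mod - mod) < 1e-4:  # assumed convergence test
            return i  # depth of the community hierarchy
        prev_mod = mod


print(asyncio.run(build_hierarchy()))  # -> 5 with the fake values
```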
diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py
index 186ab11a..04e15afb 100644
--- a/eventual-consistency-service/app/graphrag/util.py
+++ b/eventual-consistency-service/app/graphrag/util.py
@@ -5,6 +5,7 @@
 import re
 import traceback
 from glob import glob
+from typing import Callable
 
 import httpx
 from graphrag import workers
@@ -23,7 +24,9 @@
 from common.logs.logwriter import LogWriter
 
 logger = logging.getLogger(__name__)
 http_timeout = httpx.Timeout(15.0)
+
+tg_sem = asyncio.Semaphore(100)
 
 
 async def install_queries(
@@ -141,15 +144,16 @@ async def stream_ids(
 
     try:
         async with httpx.AsyncClient(timeout=http_timeout) as client:
-            res = await client.post(
-                f"{conn.restppUrl}/query/{conn.graphname}/StreamIds",
-                params={
-                    "current_batch": current_batch,
-                    "ttl_batches": ttl_batches,
-                    "v_type": v_type,
-                },
-                headers=headers,
-            )
+            async with tg_sem:
+                res = await client.post(
+                    f"{conn.restppUrl}/query/{conn.graphname}/StreamIds",
+                    params={
+                        "current_batch": current_batch,
+                        "ttl_batches": ttl_batches,
+                        "v_type": v_type,
+                    },
+                    headers=headers,
+                )
         ids = res.json()["results"][0]["@@ids"]
         return {"error": False, "ids": ids}
 
@@ -199,9 +203,10 @@ async def upsert_vertex(
     data = json.dumps({"vertices": {vertex_type: {vertex_id: attrs}}})
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=http_timeout) as client:
-        res = await client.post(
-            f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers
-        )
+        async with tg_sem:
+            res = await client.post(
+                f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers
+            )
 
         res.raise_for_status()
 
@@ -209,10 +214,11 @@ async def upsert_vertex(
 async def check_vertex_exists(conn, v_id: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=http_timeout) as client:
-        res = await client.get(
-            f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}",
-            headers=headers,
-        )
+        async with tg_sem:
+            res = await client.get(
+                f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}",
+                headers=headers,
+            )
 
         res.raise_for_status()
     return res.json()
 
@@ -250,20 +256,22 @@ async def upsert_edge(
     )
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=http_timeout) as client:
-        res = await client.post(
-            f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers
-        )
+        async with tg_sem:
+            res = await client.post(
+                f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers
+            )
         res.raise_for_status()
 
 
 async def get_commuinty_children(conn, i: int, c: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=None) as client:
-        resp = await client.get(
-            f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
-            params={"comm": c, "iter": i},
-            headers=headers,
-        )
+        async with tg_sem:
+            resp = await client.get(
+                f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
+                params={"comm": c, "iter": i},
+                headers=headers,
+            )
         resp.raise_for_status()
     descrs = []
     for d in resp.json()["results"][0]["children"]:
@@ -281,11 +289,12 @@ async def get_commuinty_children(conn, i: int, c: str):
 async def check_vertex_has_desc(conn, i: int):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=None) as client:
-        resp = await client.get(
-            f"{conn.restppUrl}/query/{conn.graphname}/communities_have_desc",
-            params={"iter": i},
-            headers=headers,
-        )
+        async with tg_sem:
+            resp = await client.get(
+                f"{conn.restppUrl}/query/{conn.graphname}/communities_have_desc",
+                params={"iter": i},
+                headers=headers,
+            )
         resp.raise_for_status()
 
     res = resp.json()["results"][0]["all_have_desc"]
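`upsert_vertex` and `upsert_edge` above POST a nested JSON document to RESTPP's `/graph/<graphname>` endpoint. A sketch of the vertex payload shape; the vertex type, id, and attribute are illustrative, and the attribute value is assumed to use RESTPP's `{"value": ...}` wrapping (as `attrs` is expected to already be in that form):

```python
# Illustrative body in the shape upsert_vertex() builds above.
# "Entity", "entity_42", and "definition" are examples, not a fixed schema.
import json

vertex_payload = {
    "vertices": {
        "Entity": {  # vertex type
            "entity_42": {  # vertex id
                "definition": {"value": "A short description of the entity."},
            }
        }
    }
}
print(json.dumps(vertex_payload, indent=2))
```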
diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py
index 755b1085..37786aee 100644
--- a/eventual-consistency-service/app/graphrag/workers.py
+++ b/eventual-consistency-service/app/graphrag/workers.py
@@ -37,11 +37,12 @@ async def install_query(
     headers = {"Authorization": f"Basic {tkn}"}
 
     async with httpx.AsyncClient(timeout=None) as client:
-        res = await client.post(
-            conn.gsUrl + "/gsqlserver/gsql/file",
-            data=quote_plus(query.encode("utf-8")),
-            headers=headers,
-        )
+        async with util.tg_sem:
+            res = await client.post(
+                conn.gsUrl + "/gsqlserver/gsql/file",
+                data=quote_plus(query.encode("utf-8")),
+                headers=headers,
+            )
 
     if "error" in res.text.lower():
         LogWriter.error(res.text)
@@ -78,7 +79,7 @@ async def chunk_doc(
 
         # send chunks to be embedded
         logger.info("chunk writes to embed_chan")
-        await embed_chan.put((v_id, chunk, "DocumentChunk"))
+        await embed_chan.put((chunk_id, chunk, "DocumentChunk"))
 
         # send chunks to have entities extracted
         logger.info("chunk writes to extract_chan")
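The pipeline stages in `run()` above communicate through `aiochannel.Channel` queues (hence the new `aiochannel==1.2.1` pin in common/requirements.txt). A self-contained sketch of that producer/consumer wiring, with stubs in place of the real workers:

```python
# Sketch of the channel-based pipeline wiring used by graphrag.run().
# produce/consume are stand-ins for stream_docs and the chunk/embed/extract workers.
import asyncio

from aiochannel import Channel


async def produce(chan: Channel):
    # Emit work, then close the channel so consumers' `async for`
    # loops drain the buffer and terminate.
    for i in range(3):
        await chan.put(f"doc-{i}")
    chan.close()


async def consume(chan: Channel):
    async for item in chan:
        print("processing", item)


async def main():
    chan = Channel(1)  # capacity 1, as with docs_chan in run() above
    await asyncio.gather(produce(chan), consume(chan))


asyncio.run(main())
```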