Commit

Merge pull request #263 from tigergraph/GML-1821-graphrag-retriever

parkererickson-tg authored Aug 20, 2024
2 parents 43c11fe + 0f7d8eb commit 1b8b7ba
Showing 13 changed files with 217 additions and 82 deletions.
@@ -0,0 +1,7 @@
CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
comms = {Community.*};

selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;

PRINT selected_comms;
}
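For reference, a minimal sketch of calling the installed query directly from pyTigerGraph, mirroring how the GraphRAG retriever below invokes it (host and credentials are placeholders):

from pyTigerGraph import TigerGraphConnection

# Hypothetical connection; the graph name matches the query's FOR GRAPH clause.
conn = TigerGraphConnection(host="https://your-instance", graphname="pyTigerGraphRAG")

# Returns the communities whose iteration attribute equals community_level.
res = conn.runInstalledQuery(
    "GraphRAG_CommunityRetriever",
    {"community_level": 2},
    usePost=True,
)
print(res[0]["selected_comms"])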
2 changes: 2 additions & 0 deletions common/gsql/supportai/retrievers/HNSW_Search_Content.gsql
@@ -24,6 +24,8 @@ CREATE DISTRIBUTED QUERY HNSW_Search_Content(STRING v_type, STRING milvus_host,
POST-ACCUM
IF s.type == "Relationship" OR s.type == "Entity" OR s.type == "Concept" THEN
@@final_retrieval += (s.id -> s.definition)
+ELSE IF s.type == "Community" THEN
+@@final_retrieval += (s.id -> s.description)
END;

PRINT @@final_retrieval;
2 changes: 1 addition & 1 deletion common/logs/log.py
@@ -64,7 +64,7 @@ def logToRoot(message, *args, **kwargs):

addLoggingLevel("DEBUG_PII", logging.DEBUG - 5)
log_config = get_log_config()
-LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
+LOGLEVEL = os.environ.get("LOGLEVEL", logging.INFO)

log_directory = log_config.get("log_file_path", "/tmp/logs")
os.makedirs(log_directory, exist_ok=True)
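Note: with this change an unset LOGLEVEL falls back to the integer logging.INFO (20) rather than the string "INFO", and an explicitly set value is no longer uppercased, so a lowercase value such as "debug" now passes through to the logging configuration as-is.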
3 changes: 2 additions & 1 deletion common/requirements.txt
@@ -1,3 +1,4 @@
+aiochannel==1.2.1
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
@@ -100,7 +101,7 @@ minio==7.2.7
multidict==6.0.5
mypy-extensions==1.0.0
nest-asyncio==1.6.0
-nltk==3.8.2
+nltk==3.8.1
numpy==1.26.4
openai==1.40.6
ordered-set==4.1.0
17 changes: 16 additions & 1 deletion copilot/app/routers/supportai.py
@@ -16,6 +16,7 @@
HNSWOverlapRetriever,
HNSWRetriever,
HNSWSiblingRetriever,
+GraphRAG
)

from common.config import (
@@ -175,7 +176,11 @@ def search(
embedding_service, embedding_store, get_llm_service(llm_config), conn
)
res = retriever.search(query.question, query.method_params["top_k"])

+elif query.method.lower() == "graphrag":
+    retriever = GraphRAG(
+        embedding_service, embedding_store, get_llm_service(llm_config), conn
+    )
+    res = retriever.search(query.question, query.method_params["community_level"])
return res


@@ -232,6 +237,16 @@ def answer_question(
embedding_service, embedding_store, get_llm_service(llm_config), conn
)
res = retriever.retrieve_answer(query.question, query.method_params["top_k"])

+elif query.method.lower() == "graphrag":
+    retriever = GraphRAG(
+        embedding_service, embedding_store, get_llm_service(llm_config), conn
+    )
+    res = retriever.retrieve_answer(
+        query.question,
+        query.method_params["community_level"],
+        query.method_params["top_k_answer_candidates"]
+    )
else:
raise Exception("Method not implemented")

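Taken together with the /search branch above, a request exercising the new method would presumably look like the following (values are illustrative; the keys mirror the method_params lookups above):

payload = {
    "question": "What are the main themes across these documents?",
    "method": "graphrag",
    "method_params": {
        "community_level": 2,          # used by both endpoints
        "top_k_answer_candidates": 3,  # used by /answerquestion only
    },
}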
88 changes: 88 additions & 0 deletions copilot/app/supportai/retrievers/GraphRAG.py
@@ -0,0 +1,88 @@
from supportai.retrievers import BaseRetriever
import asyncio
from concurrent.futures import ThreadPoolExecutor

from common.metrics.tg_proxy import TigerGraphConnectionProxy

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator

from common.llm_services import LLM_Model


class CommunityAnswer(BaseModel):
answer: str = Field(description="The answer to the question, based off of the context provided.")
quality_score: int = Field(description="The quality of the answer, based on how well it answers the question. Rate the answer from 0 (poor) to 100 (excellent).")

output_parser = PydanticOutputParser(pydantic_object=CommunityAnswer)

ANSWER_PROMPT = PromptTemplate(template = """
You are a helpful assistant responsible for generating an answer to the question below using the data provided.
Include a quality score for the answer, based on how well it answers the question. The quality score should be between 0 (poor) and 100 (excellent).
Question: {question}
Context: {context}
{format_instructions}
""",
input_variables=["question", "context"],
partial_variables={"format_instructions": output_parser.get_format_instructions()}
)


class GraphRAG(BaseRetriever):
def __init__(
self,
embedding_service,
embedding_store,
llm_service: LLM_Model,
connection: TigerGraphConnectionProxy,
):
super().__init__(embedding_service, embedding_store, llm_service, connection)
self._check_query_install("GraphRAG_CommunityRetriever")

def search(self, question, community_level: int):
res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
return res

async def _generate_candidate(self, question, context):
model = self.llm_service.model

chain = ANSWER_PROMPT | model | output_parser

answer = await chain.ainvoke(
{
"question": question,
"context": context,
}
)
return answer

def gather_candidates(self, question, context):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
tasks = [self._generate_candidate(question, c) for c in context]
res = loop.run_until_complete(asyncio.gather(*tasks))
loop.close()
return res

def retrieve_answer(self,
question: str,
community_level: int,
top_k_answer_candidates: int = 1):
retrieved = self.search(question, community_level)
context = [x["attributes"] for x in retrieved[0]["selected_comms"]]

with ThreadPoolExecutor() as executor:
res = executor.submit(self.gather_candidates, question, context).result()

# sort list by quality score
res.sort(key=lambda x: x.quality_score, reverse=True)

new_context = [{"candidate_answer": x.answer,
"score": x.quality_score} for x in res[:top_k_answer_candidates]]

return self._generate_response(question, new_context)
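A minimal usage sketch of the new retriever, assuming services are configured as in the router above (all names illustrative):

retriever = GraphRAG(embedding_service, embedding_store, llm_service, conn)
answer = retriever.retrieve_answer(
    "What are the main themes across these documents?",
    community_level=2,
    top_k_answer_candidates=3,
)

Design note: gather_candidates creates a fresh event loop so the async fan-out over communities can be driven from synchronous code, and retrieve_answer runs it inside a ThreadPoolExecutor so that loop never collides with a loop the caller may already be running.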
1 change: 1 addition & 0 deletions copilot/app/supportai/retrievers/__init__.py
@@ -3,3 +3,4 @@
from .HNSWOverlapRetriever import HNSWOverlapRetriever
from .HNSWSiblingRetriever import HNSWSiblingRetriever
from .EntityRelationshipRetriever import EntityRelationshipRetriever
+from .GraphRAG import GraphRAG
4 changes: 2 additions & 2 deletions copilot/docs/notebooks/SupportAIDemo.ipynb
@@ -159,7 +159,7 @@
}
],
"source": [
"conn.ai.forceConsistencyUpdate()"
"conn.ai.forceConsistencyUpdate(method=\"graphrag\")"
]
},
{
@@ -546,7 +546,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.9.16"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion copilot/requirements.txt
@@ -101,7 +101,7 @@ minio==7.2.7
multidict==6.0.5
mypy-extensions==1.0.0
nest-asyncio==1.6.0
-nltk==3.8.2
+nltk==3.8.1
numpy==1.26.4
openai==1.40.6
ordered-set==4.1.0
1 change: 1 addition & 0 deletions eventual-consistency-service/.dockerignore
@@ -3,3 +3,4 @@ Dockerfile.tests
docs
tests
udfs
+__pycache__
90 changes: 50 additions & 40 deletions eventual-consistency-service/app/graphrag/graph_rag.py
@@ -12,6 +12,7 @@
init,
make_headers,
stream_ids,
+tg_sem,
)
from pyTigerGraph import TigerGraphConnection

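tg_sem is imported but not defined in this diff; presumably it is a shared semaphore capping concurrent REST++ calls to TigerGraph, along the lines of this hypothetical definition (assumed to live alongside init, make_headers, and stream_ids):

import asyncio

# Hypothetical: one shared semaphore so at most N requests hit TigerGraph at a time.
tg_sem = asyncio.Semaphore(20)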
@@ -44,11 +45,12 @@ async def stream_docs(

for d in doc_ids["ids"]:
try:
-res = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
-    params={"doc": d},
-    headers=headers,
-)
+async with tg_sem:
+    res = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
+        params={"doc": d},
+        headers=headers,
+    )
if res.status_code != 200:
# continue to the next doc.
# This doc will not be marked as processed, so the ecc will process it eventually.
@@ -85,7 +87,7 @@ async def chunk_docs(
txt = content["attributes"]["text"]
# send the document to be embedded
logger.info("chunk writes to extract")
-await embed_chan.put((v_id, txt, "Document"))
+# await embed_chan.put((v_id, txt, "Document"))

task = grp.create_task(
workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan)
@@ -221,10 +223,11 @@
# Copy RELATIONSHIP edges to RESOLVED_RELATIONSHIP
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=http_timeout) as client:
-res = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/",
-    headers=headers,
-)
+async with tg_sem:
+    res = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/",
+        headers=headers,
+    )
res.raise_for_status()


@@ -236,19 +239,21 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
logger.info("Initializing Communities (first louvain pass)")
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=None) as client:
-res = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init",
-    params={"n_batches": 1},
-    headers=headers,
-)
+async with tg_sem:
+    res = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init",
+        params={"n_batches": 1},
+        headers=headers,
+    )
res.raise_for_status()
# get the modularity
async with httpx.AsyncClient(timeout=None) as client:
-res = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/modularity",
-    params={"iteration": 1, "batch_num": 1},
-    headers=headers,
-)
+async with tg_sem:
+    res = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/modularity",
+        params={"iteration": 1, "batch_num": 1},
+        headers=headers,
+    )
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"****mod pass 1: {mod}")
@@ -263,21 +268,23 @@
logger.info(f"Running louvain on Communities (iteration: {i})")
# louvain pass
async with httpx.AsyncClient(timeout=None) as client:
-res = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities",
-    params={"n_batches": 1, "iteration": i},
-    headers=headers,
-)
+async with tg_sem:
+    res = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities",
+        params={"n_batches": 1, "iteration": i},
+        headers=headers,
+    )

res.raise_for_status()

# get the modularity
async with httpx.AsyncClient(timeout=None) as client:
-res = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/modularity",
-    params={"iteration": i + 1, "batch_num": 1},
-    headers=headers,
-)
+async with tg_sem:
+    res = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/modularity",
+        params={"iteration": i + 1, "batch_num": 1},
+        headers=headers,
+    )
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
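Reviewer note: each pass reruns Louvain one level up and re-reads modularity for iteration i + 1; the loop presumably stops once the logged gain abs(prev_mod - mod) falls below a threshold, though that check sits outside this hunk.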
@@ -307,11 +314,12 @@ async def stream_communities(
# async for i in community_chan:
# get the community from that layer
async with httpx.AsyncClient(timeout=None) as client:
-resp = await client.get(
-    f"{conn.restppUrl}/query/{conn.graphname}/stream_community",
-    params={"iter": i},
-    headers=headers,
-)
+async with tg_sem:
+    resp = await client.get(
+        f"{conn.restppUrl}/query/{conn.graphname}/stream_community",
+        params={"iter": i},
+        headers=headers,
+    )
resp.raise_for_status()
comms = resp.json()["results"][0]["Comms"]

@@ -345,7 +353,7 @@ async def summarize_communities(
embed_chan.close()


-async def run(graphname: str, conn: TigerGraphConnection):
+async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
"""
Set up GraphRAG:
- Install necessary queries.
@@ -362,8 +370,8 @@ async def run(graphname: str, conn: TigerGraphConnection):
init_start = time.perf_counter()

doc_process_switch = True
-entity_resolution_switch = True
-community_detection_switch = True
+entity_resolution_switch =True
+community_detection_switch =True
if doc_process_switch:
logger.info("Doc Processing Start")
docs_chan = Channel(1)
@@ -378,7 +386,7 @@ async def run(graphname: str, conn: TigerGraphConnection):
chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
)
# upsert chunks
-grp.create_task(upsert(upsert_chan))
+grp.create_task(upsert( upsert_chan))
# embed
grp.create_task(embed(embed_chan, index_stores, graphname))
# extract entities
@@ -437,5 +445,7 @@ async def run(graphname: str, conn: TigerGraphConnection):
end = time.perf_counter()
logger.info(f"DONE. graphrag system initializer dT: {init_end-init_start}")
logger.info(f"DONE. graphrag entity resolution dT: {entity_end-entity_start}")
-logger.info(f"DONE. graphrag community initializer dT: {community_end-community_start}")
+logger.info(
+    f"DONE. graphrag community initializer dT: {community_end-community_start}"
+)
logger.info(f"DONE. graphrag.run() total time elapsed: {end-init_start}")