GML-1821 GraphRAG retriever #263

Merged
merged 4 commits on Aug 20, 2024
Changes from 3 commits
@@ -0,0 +1,7 @@
CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
    comms = {Community.*};

    selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;

    PRINT selected_comms;
}
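For reference, a minimal sketch of invoking this installed query from pyTigerGraph (the host and credentials below are placeholders); it mirrors the `runInstalledQuery` call that `GraphRAG.search` makes later in this PR:

```python
from pyTigerGraph import TigerGraphConnection

# Placeholder connection details; substitute your own deployment's values.
conn = TigerGraphConnection(
    host="https://your-tg-host",
    graphname="pyTigerGraphRAG",
    username="tigergraph",
    password="tigergraph",
)

# Returns all Community vertices whose `iteration` attribute equals
# community_level, matching the PRINT selected_comms statement above.
res = conn.runInstalledQuery(
    "GraphRAG_CommunityRetriever",
    {"community_level": 2},
    usePost=True,
)
communities = res[0]["selected_comms"]
```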
2 changes: 2 additions & 0 deletions common/gsql/supportai/retrievers/HNSW_Search_Content.gsql
@@ -24,6 +24,8 @@ CREATE DISTRIBUTED QUERY HNSW_Search_Content(STRING v_type, STRING milvus_host,
    POST-ACCUM
        IF s.type == "Relationship" OR s.type == "Entity" OR s.type == "Concept" THEN
            @@final_retrieval += (s.id -> s.definition)
        ELSE IF s.type == "Community" THEN
            @@final_retrieval += (s.id -> s.description)
        END;

    PRINT @@final_retrieval;
2 changes: 1 addition & 1 deletion common/logs/log.py
@@ -64,7 +64,7 @@ def logToRoot(message, *args, **kwargs):

addLoggingLevel("DEBUG_PII", logging.DEBUG - 5)
log_config = get_log_config()
LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
LOGLEVEL = os.environ.get("LOGLEVEL", logging.INFO)

log_directory = log_config.get("log_file_path", "/tmp/logs")
os.makedirs(log_directory, exist_ok=True)
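A note on this change: `LOGLEVEL` is now either a string taken from the environment or the integer `logging.INFO`, and the old `.upper()` normalization is gone. The standard library accepts both forms, as this self-contained sketch (names illustrative) shows, but an env value must now already be a valid upper-case level name:

```python
import logging
import os

# Either an env-supplied name like "DEBUG" or the int constant logging.INFO.
LOGLEVEL = os.environ.get("LOGLEVEL", logging.INFO)

logger = logging.getLogger("demo")
# Logger.setLevel accepts an int or a registered level-name string, so both
# cases work; note that "info" (lower case) would raise ValueError now that
# the .upper() call has been dropped.
logger.setLevel(LOGLEVEL)
logger.warning("effective level: %s", logging.getLevelName(logger.getEffectiveLevel()))
```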
3 changes: 2 additions & 1 deletion common/requirements.txt
@@ -1,3 +1,4 @@
aiochannel==1.2.1
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
@@ -100,7 +101,7 @@ minio==7.2.7
multidict==6.0.5
mypy-extensions==1.0.0
nest-asyncio==1.6.0
nltk==3.8.2
nltk==3.8.1
numpy==1.26.4
openai==1.40.6
ordered-set==4.1.0
17 changes: 16 additions & 1 deletion copilot/app/routers/supportai.py
@@ -16,6 +16,7 @@
    HNSWOverlapRetriever,
    HNSWRetriever,
    HNSWSiblingRetriever,
    GraphRAG
)

from common.config import (
@@ -175,7 +176,11 @@ def search(
            embedding_service, embedding_store, get_llm_service(llm_config), conn
        )
        res = retriever.search(query.question, query.method_params["top_k"])

    elif query.method.lower() == "graphrag":
        retriever = GraphRAG(
            embedding_service, embedding_store, get_llm_service(llm_config), conn
        )
        res = retriever.search(query.question, query.method_params["community_level"])
    return res


@@ -232,6 +237,16 @@ def answer_question(
            embedding_service, embedding_store, get_llm_service(llm_config), conn
        )
        res = retriever.retrieve_answer(query.question, query.method_params["top_k"])

    elif query.method.lower() == "graphrag":
        retriever = GraphRAG(
            embedding_service, embedding_store, get_llm_service(llm_config), conn
        )
        res = retriever.retrieve_answer(
            query.question,
            query.method_params["community_level"],
            query.method_params["top_k_answer_candidates"]
        )
    else:
        raise Exception("Method not implemented")

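Request bodies for the new branches would look like the following sketch; the payload keys come straight from the handler code above, while the question text and candidate count are illustrative:

```python
# Hypothetical payloads for the search and answer_question endpoints.
search_request = {
    "question": "What are the main themes in the corpus?",
    "method": "graphrag",
    "method_params": {"community_level": 2},
}

answer_request = {
    "question": "What are the main themes in the corpus?",
    "method": "graphrag",
    "method_params": {
        "community_level": 2,
        "top_k_answer_candidates": 3,
    },
}
```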
88 changes: 88 additions & 0 deletions copilot/app/supportai/retrievers/GraphRAG.py
@@ -0,0 +1,88 @@
from supportai.retrievers import BaseRetriever
import asyncio
from concurrent.futures import ThreadPoolExecutor

from common.metrics.tg_proxy import TigerGraphConnectionProxy

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator

from common.llm_services import LLM_Model


class CommunityAnswer(BaseModel):
    answer: str = Field(description="The answer to the question, based off of the context provided.")
    quality_score: int = Field(description="The quality of the answer, based on how well it answers the question. Rate the answer from 0 (poor) to 100 (excellent).")

output_parser = PydanticOutputParser(pydantic_object=CommunityAnswer)

ANSWER_PROMPT = PromptTemplate(template = """
You are a helpful assistant responsible for generating an answer to the question below using the data provided.
Include a quality score for the answer, based on how well it answers the question. The quality score should be between 0 (poor) and 100 (excellent).

Question: {question}
Context: {context}

{format_instructions}
""",
input_variables=["question", "context"],
partial_variables={"format_instructions": output_parser.get_format_instructions()}
)


class GraphRAG(BaseRetriever):
    def __init__(
        self,
        embedding_service,
        embedding_store,
        llm_service: LLM_Model,
        connection: TigerGraphConnectionProxy,
    ):
        super().__init__(embedding_service, embedding_store, llm_service, connection)
        self._check_query_install("GraphRAG_CommunityRetriever")

    def search(self, question, community_level: int):
        res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
        return res

    async def _generate_candidate(self, question, context):
        model = self.llm_service.model

        chain = ANSWER_PROMPT | model | output_parser

        answer = await chain.ainvoke(
            {
                "question": question,
                "context": context,
            }
        )
        return answer

    def gather_candidates(self, question, context):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        tasks = [self._generate_candidate(question, c) for c in context]
        res = loop.run_until_complete(asyncio.gather(*tasks))
        loop.close()
        return res

    def retrieve_answer(self,
                        question: str,
                        community_level: int,
                        top_k_answer_candidates: int = 1):
        retrieved = self.search(question, community_level)
        context = [x["attributes"] for x in retrieved[0]["selected_comms"]]

        with ThreadPoolExecutor() as executor:
            res = executor.submit(self.gather_candidates, question, context).result()

        # sort list by quality score
        res.sort(key=lambda x: x.quality_score, reverse=True)

        new_context = [{"candidate_answer": x.answer,
                        "score": x.quality_score} for x in res[:top_k_answer_candidates]]

        return self._generate_response(question, new_context)
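`retrieve_answer` fans the question out to every retrieved community, scores each candidate answer, and keeps the top `top_k_answer_candidates` as context for the final response. Since the method is synchronous, the async fan-out runs on a fresh event loop inside a worker thread. A standalone sketch of that pattern, with a stub coroutine in place of the LLM chain:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def generate_candidate(question: str, context: str) -> str:
    # Stub standing in for the ANSWER_PROMPT | model | output_parser chain.
    await asyncio.sleep(0.1)
    return f"answer derived from {context!r}"


def gather_candidates(question: str, contexts: list) -> list:
    # A fresh loop is created because this runs in a worker thread,
    # which has no event loop of its own.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            asyncio.gather(*(generate_candidate(question, c) for c in contexts))
        )
    finally:
        loop.close()


# Submitting the fan-out to a thread keeps the caller's own event loop
# (e.g. FastAPI's) free while candidates are generated concurrently.
with ThreadPoolExecutor() as executor:
    answers = executor.submit(gather_candidates, "q", ["c1", "c2", "c3"]).result()
print(answers)
```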
1 change: 1 addition & 0 deletions copilot/app/supportai/retrievers/__init__.py
@@ -3,3 +3,4 @@
from .HNSWOverlapRetriever import HNSWOverlapRetriever
from .HNSWSiblingRetriever import HNSWSiblingRetriever
from .EntityRelationshipRetriever import EntityRelationshipRetriever
from .GraphRAG import GraphRAG
4 changes: 2 additions & 2 deletions copilot/docs/notebooks/SupportAIDemo.ipynb
@@ -159,7 +159,7 @@
}
],
"source": [
"conn.ai.forceConsistencyUpdate()"
"conn.ai.forceConsistencyUpdate(method=\"graphrag\")"
]
},
{
@@ -546,7 +546,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.9.16"
}
},
"nbformat": 4,
1 change: 1 addition & 0 deletions eventual-consistency-service/.dockerignore
@@ -3,3 +3,4 @@ Dockerfile.tests
docs
tests
udfs
__pycache__
90 changes: 50 additions & 40 deletions eventual-consistency-service/app/graphrag/graph_rag.py
@@ -12,6 +12,7 @@
init,
make_headers,
stream_ids,
tg_sem,
)
from pyTigerGraph import TigerGraphConnection

@@ -44,11 +45,12 @@ async def stream_docs(

for d in doc_ids["ids"]:
try:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
params={"doc": d},
headers=headers,
)
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
params={"doc": d},
headers=headers,
)
if res.status_code != 200:
# continue to the next doc.
# This doc will not be marked as processed, so the ecc will process it eventually.
@@ -85,7 +87,7 @@ async def chunk_docs(
txt = content["attributes"]["text"]
# send the document to be embedded
logger.info("chunk writes to extract")
await embed_chan.put((v_id, txt, "Document"))
# await embed_chan.put((v_id, txt, "Document"))

task = grp.create_task(
workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan)
@@ -221,10 +223,11 @@ async def resolve_entities(
# Copy RELATIONSHIP edges to RESOLVED_RELATIONSHIP
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=http_timeout) as client:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/",
headers=headers,
)
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/ResolveRelationships/",
headers=headers,
)
res.raise_for_status()


@@ -236,19 +239,21 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
logger.info("Initializing Communities (first louvain pass)")
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=None) as client:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init",
params={"n_batches": 1},
headers=headers,
)
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_init",
params={"n_batches": 1},
headers=headers,
)
res.raise_for_status()
# get the modularity
async with httpx.AsyncClient(timeout=None) as client:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/modularity",
params={"iteration": 1, "batch_num": 1},
headers=headers,
)
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/modularity",
params={"iteration": 1, "batch_num": 1},
headers=headers,
)
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"****mod pass 1: {mod}")
@@ -263,21 +268,23 @@
logger.info(f"Running louvain on Communities (iteration: {i})")
# louvain pass
async with httpx.AsyncClient(timeout=None) as client:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities",
params={"n_batches": 1, "iteration": i},
headers=headers,
)
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/graphrag_louvain_communities",
params={"n_batches": 1, "iteration": i},
headers=headers,
)

res.raise_for_status()

# get the modularity
async with httpx.AsyncClient(timeout=None) as client:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/modularity",
params={"iteration": i + 1, "batch_num": 1},
headers=headers,
)
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/modularity",
params={"iteration": i + 1, "batch_num": 1},
headers=headers,
)
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
@@ -307,11 +314,12 @@ async def stream_communities(
# async for i in community_chan:
# get the community from that layer
async with httpx.AsyncClient(timeout=None) as client:
resp = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/stream_community",
params={"iter": i},
headers=headers,
)
async with tg_sem:
resp = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/stream_community",
params={"iter": i},
headers=headers,
)
resp.raise_for_status()
comms = resp.json()["results"][0]["Comms"]

@@ -345,7 +353,7 @@ async def summarize_communities(
embed_chan.close()


async def run(graphname: str, conn: TigerGraphConnection):
async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
"""
Set up GraphRAG:
- Install necessary queries.
@@ -362,8 +370,8 @@ async def run(graphname: str, conn: TigerGraphConnection):
init_start = time.perf_counter()

doc_process_switch = True
entity_resolution_switch = True
community_detection_switch = True
entity_resolution_switch =True
community_detection_switch =True
if doc_process_switch:
logger.info("Doc Processing Start")
docs_chan = Channel(1)
@@ -378,7 +386,7 @@
chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
)
# upsert chunks
grp.create_task(upsert(upsert_chan))
grp.create_task(upsert( upsert_chan))
# embed
grp.create_task(embed(embed_chan, index_stores, graphname))
# extract entities
Expand Down Expand Up @@ -437,5 +445,7 @@ async def run(graphname: str, conn: TigerGraphConnection):
end = time.perf_counter()
logger.info(f"DONE. graphrag system initializer dT: {init_end-init_start}")
logger.info(f"DONE. graphrag entity resolution dT: {entity_end-entity_start}")
logger.info(f"DONE. graphrag community initializer dT: {community_end-community_start}")
logger.info(
f"DONE. graphrag community initializer dT: {community_end-community_start}"
)
logger.info(f"DONE. graphrag.run() total time elaplsed: {end-init_start}")