diff --git a/common/config.py b/common/config.py
index 9f1d3cab..4548aa3a 100644
--- a/common/config.py
+++ b/common/config.py
@@ -187,7 +187,7 @@ def get_llm_service(llm_config) -> LLM_Model:
     ):
         doc_processing_config = {
             "chunker": "semantic",
-            "chunker_config": {"method": "percentile", "threshold": 0.95},
+            "chunker_config": {"method": "percentile", "threshold": 0.90},
             "extractor": "graphrag",
             "extractor_config": {},
         }
diff --git a/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql b/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql
index 2d6ef9b0..97e44d10 100644
--- a/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql
+++ b/common/gsql/supportai/retrievers/GraphRAG_Community_Retriever.gsql
@@ -1,4 +1,4 @@
-CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
+CREATE DISTRIBUTED QUERY GraphRAG_Community_Retriever(INT community_level=2) {
   comms = {Community.*};
   selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;

diff --git a/copilot/app/supportai/retrievers/GraphRAG.py b/copilot/app/supportai/retrievers/GraphRAG.py
index 442f8fcb..4a973dc8 100644
--- a/copilot/app/supportai/retrievers/GraphRAG.py
+++ b/copilot/app/supportai/retrievers/GraphRAG.py
@@ -40,10 +40,10 @@ def __init__(
         connection: TigerGraphConnectionProxy,
     ):
         super().__init__(embedding_service, embedding_store, llm_service, connection)
-        self._check_query_install("GraphRAG_CommunityRetriever")
+        self._check_query_install("GraphRAG_Community_Retriever")

     def search(self, question, community_level: int):
-        res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
+        res = self.conn.runInstalledQuery("GraphRAG_Community_Retriever", {"community_level": community_level}, usePost=True)
         return res

     async def _generate_candidate(self, question, context):
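A note on the common/config.py hunk above: lowering the percentile threshold from 0.95 to 0.90 makes the semantic chunker split more aggressively, since a lower breakpoint percentile means more sentence-to-sentence embedding gaps qualify as chunk boundaries. A rough sketch of what that knob typically maps to, assuming a LangChain-style SemanticChunker underneath (an assumption; the repo's own chunker and its 0-1 threshold scale may differ):

```python
# Hypothetical illustration of {"method": "percentile", "threshold": 0.90}.
# Assumes langchain_experimental's SemanticChunker; the actual chunker
# wired up by doc_processing_config may be implemented differently.
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

chunker = SemanticChunker(
    OpenAIEmbeddings(),
    breakpoint_threshold_type="percentile",
    # Split wherever the embedding distance between adjacent sentences
    # exceeds the 90th percentile; 90 splits more often than 95.
    breakpoint_threshold_amount=90,
)
chunks = chunker.split_text("Some long document text...")
```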
diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py
index e4457a77..71e2f0f0 100644
--- a/eventual-consistency-service/app/graphrag/graph_rag.py
+++ b/eventual-consistency-service/app/graphrag/graph_rag.py
@@ -14,6 +14,7 @@
     http_timeout,
     init,
     load_q,
+    loading_event,
     make_headers,
     stream_ids,
     tg_sem,
@@ -124,7 +125,7 @@ async def upsert(upsert_chan: Channel):
 async def load(conn: TigerGraphConnection):
     logger.info("Reading from load_q")
     dd = lambda: defaultdict(dd)  # infinite default dict
-    batch_size = 1000
+    batch_size = 500
     # while the load q is still open or has contents
     while not load_q.closed() or not load_q.empty():
         if load_q.closed():
@@ -169,7 +170,11 @@ async def load(conn: TigerGraphConnection):
             logger.info(
                 f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)"
             )
+
+            loading_event.clear()
             await upsert_batch(conn, data)
+            await asyncio.sleep(5)
+            loading_event.set()
         else:
             await asyncio.sleep(1)

@@ -435,12 +440,12 @@ async def run(graphname: str, conn: TigerGraphConnection):
     if doc_process_switch:
         logger.info("Doc Processing Start")
         docs_chan = Channel(1)
-        embed_chan = Channel(100)
-        upsert_chan = Channel(100)
-        extract_chan = Channel(100)
+        embed_chan = Channel()
+        upsert_chan = Channel()
+        extract_chan = Channel()
         async with asyncio.TaskGroup() as grp:
             # get docs
-            grp.create_task(stream_docs(conn, docs_chan, 10))
+            grp.create_task(stream_docs(conn, docs_chan, 100))
             # process docs
             grp.create_task(
                 chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
@@ -462,8 +467,8 @@ async def run(graphname: str, conn: TigerGraphConnection):

     if entity_resolution_switch:
         logger.info("Entity Processing Start")
-        entities_chan = Channel(100)
-        upsert_chan = Channel(100)
+        entities_chan = Channel()
+        upsert_chan = Channel()
         load_q.reopen()
         async with asyncio.TaskGroup() as grp:
             grp.create_task(stream_entities(conn, entities_chan, 50))
@@ -487,10 +492,10 @@ async def run(graphname: str, conn: TigerGraphConnection):
     community_start = time.perf_counter()
     if community_detection_switch:
         logger.info("Community Processing Start")
-        upsert_chan = Channel(10)
-        comm_process_chan = Channel(100)
-        upsert_chan = Channel(100)
-        embed_chan = Channel(100)
+        upsert_chan = Channel()
+        comm_process_chan = Channel()
+        upsert_chan = Channel()
+        embed_chan = Channel()
         load_q.reopen()
         async with asyncio.TaskGroup() as grp:
             # run louvain
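Together with the util.py changes below, the load() hunk above implements a pause gate: the loader clears loading_event before each upsert batch and sets it again afterwards, and every worker awaits the event before touching TigerGraph, so nothing competes with the database while a batch loads. A stripped-down, runnable sketch of the pattern (toy coroutines and names, not the module's exact code):

```python
import asyncio

loading_event = asyncio.Event()
loading_event.set()  # set = workers may run; cleared = a load is in progress

async def upsert_batch(batch: list[int]) -> None:
    await asyncio.sleep(0.1)  # stand-in for the REST++ POST

async def loader(batches: list[list[int]]) -> None:
    for batch in batches:
        loading_event.clear()      # pause the workers
        await upsert_batch(batch)  # loader gets the DB to itself
        loading_event.set()        # resume the workers

async def worker(name: str) -> None:
    for _ in range(3):
        if not loading_event.is_set():
            await loading_event.wait()  # block while a load runs
        await asyncio.sleep(0.05)  # stand-in for chunk/embed/extract work

async def main() -> None:
    async with asyncio.TaskGroup() as grp:
        grp.create_task(loader([[1, 2], [3, 4]]))
        for i in range(4):
            grp.create_task(worker(f"w{i}"))

asyncio.run(main())
```

The swap from bounded Channel(100) to unbounded Channel() trades backpressure for throughput; the event gate and the per-worker semaphores below become the only brakes on the pipeline.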
diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py
index 5838ec7e..35f5bcdf 100644
--- a/eventual-consistency-service/app/graphrag/util.py
+++ b/eventual-consistency-service/app/graphrag/util.py
@@ -6,6 +6,9 @@
 from glob import glob

 import httpx
+from graphrag import reusable_channel, workers
+from pyTigerGraph import TigerGraphConnection
+
 from common.config import (
     doc_processing_config,
     embedding_service,
@@ -17,15 +20,17 @@
 from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor
 from common.extractors.BaseExtractor import BaseExtractor
 from common.logs.logwriter import LogWriter
-from graphrag import reusable_channel, workers
-from pyTigerGraph import TigerGraphConnection

 logger = logging.getLogger(__name__)
 http_timeout = httpx.Timeout(15.0)

-tg_sem = asyncio.Semaphore(20)
+tg_sem = asyncio.Semaphore(2)
 load_q = reusable_channel.ReuseableChannel()

+# workers pause while this event is cleared (i.e., while a load is in progress)
+loading_event = asyncio.Event()
+loading_event.set()  # set the event to true to allow the workers to run
+
 async def install_queries(
     requried_queries: list[str],
     conn: TigerGraphConnection,
@@ -207,7 +212,6 @@ async def upsert_batch(conn: TigerGraphConnection, data: str):
         res.raise_for_status()


-
 async def check_vertex_exists(conn, v_id: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=http_timeout) as client:
@@ -219,7 +223,8 @@ async def check_vertex_exists(conn, v_id: str):
             )

         except Exception as e:
-            logger.error(f"Check err:\n{e}")
+            err = traceback.format_exc()
+            logger.error(f"Check err:\n{err}")
             return {"error": True}

     try:
@@ -264,17 +269,26 @@ async def get_commuinty_children(conn, i: int, c: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=None) as client:
         async with tg_sem:
-            resp = await client.get(
-                f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
-                params={"comm": c, "iter": i},
-                headers=headers,
-            )
+            try:
+                resp = await client.get(
+                    f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
+                    params={"comm": c, "iter": i},
+                    headers=headers,
+                )
+            except Exception:
+                logger.error(f"Get Children err:\n{traceback.format_exc()}")
+                return []
     try:
         resp.raise_for_status()
     except Exception as e:
         logger.error(f"Get Children err:\n{e}")
     descrs = []
-    for d in resp.json()["results"][0]["children"]:
+    try:
+        res = resp.json()["results"][0]["children"]
+    except Exception as e:
+        logger.error(f"Get Children err:\n{e}")
+        res = []
+    for d in res:
         desc = d["attributes"]["description"]
         # if it's the entity iteration
         if i == 1:
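Dropping tg_sem from 20 to 2 means at most two REST++ requests are in flight process-wide, whatever the number of workers; everything else queues at the semaphore. The shape of the pattern, reduced to its essentials (placeholder URL, not the module's exact code):

```python
import asyncio

import httpx

tg_sem = asyncio.Semaphore(2)  # at most 2 concurrent requests to TigerGraph

async def fetch(client: httpx.AsyncClient, url: str) -> int:
    async with tg_sem:  # queue here when 2 requests are already in flight
        resp = await client.get(url)
        return resp.status_code

async def main() -> None:
    async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client:
        urls = ["https://example.com/"] * 6  # placeholder endpoint
        print(await asyncio.gather(*(fetch(client, u) for u in urls)))

asyncio.run(main())
```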
diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py
index 98b3e69c..c00b9187 100644
--- a/eventual-consistency-service/app/graphrag/workers.py
+++ b/eventual-consistency-service/app/graphrag/workers.py
@@ -1,3 +1,4 @@
+import asyncio
 import base64
 import logging
 import time
@@ -7,14 +8,15 @@
 import ecc_util
 import httpx
 from aiochannel import Channel
+from graphrag import community_summarizer, util
+from langchain_community.graphs.graph_document import GraphDocument, Node
+from pyTigerGraph import TigerGraphConnection
+
 from common.config import milvus_config
 from common.embeddings.embedding_services import EmbeddingModel
 from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore
 from common.extractors.BaseExtractor import BaseExtractor
 from common.logs.logwriter import LogWriter
-from graphrag import community_summarizer, util
-from langchain_community.graphs.graph_document import GraphDocument, Node
-from pyTigerGraph import TigerGraphConnection

 vertex_field = milvus_config.get("vertex_field", "vertex_id")

@@ -55,6 +57,9 @@ async def install_query(
     return {"result": res, "error": False}


+chunk_sem = asyncio.Semaphore(20)
+
+
 async def chunk_doc(
     conn: TigerGraphConnection,
     doc: dict[str, str],
@@ -67,23 +72,30 @@ async def chunk_doc(
     Places the resulting chunks into the upsert channel (to be upserted to TG)
    and the embed channel (to be embedded and written to the vector store)
     """
-    chunker = ecc_util.get_chunker()
-    chunks = chunker.chunk(doc["attributes"]["text"])
-    v_id = util.process_id(doc["v_id"])
-    logger.info(f"Chunking {v_id}")
-    for i, chunk in enumerate(chunks):
-        chunk_id = f"{v_id}_chunk_{i}"
-        # send chunks to be upserted (func, args)
-        logger.info("chunk writes to upsert_chan")
-        await upsert_chan.put((upsert_chunk, (conn, v_id, chunk_id, chunk)))
-
-        # send chunks to be embedded
-        logger.info("chunk writes to embed_chan")
-        await embed_chan.put((chunk_id, chunk, "DocumentChunk"))
-
-        # send chunks to have entities extracted
-        logger.info("chunk writes to extract_chan")
-        await extract_chan.put((chunk, chunk_id))
+
+    # if loader is running, wait until it's done
+    if not util.loading_event.is_set():
+        logger.info("Chunk worker waiting for loading event to finish")
+        await util.loading_event.wait()
+
+    async with chunk_sem:
+        chunker = ecc_util.get_chunker()
+        chunks = chunker.chunk(doc["attributes"]["text"])
+        v_id = util.process_id(doc["v_id"])
+        logger.info(f"Chunking {v_id}")
+        for i, chunk in enumerate(chunks):
+            chunk_id = f"{v_id}_chunk_{i}"
+            # send chunks to be upserted (func, args)
+            logger.info("chunk writes to upsert_chan")
+            await upsert_chan.put((upsert_chunk, (conn, v_id, chunk_id, chunk)))
+
+            # send chunks to have entities extracted
+            logger.info("chunk writes to extract_chan")
+            await extract_chan.put((chunk, chunk_id))
+
+            # send chunks to be embedded
+            logger.info("chunk writes to embed_chan")
+            await embed_chan.put((chunk_id, chunk, "DocumentChunk"))

     return doc["v_id"]

@@ -120,6 +132,9 @@ async def upsert_chunk(conn: TigerGraphConnection, doc_id, chunk_id, chunk):
     )


+embed_sem = asyncio.Semaphore(20)
+
+
 async def embed(
     embed_svc: EmbeddingModel,
     embed_store: MilvusEmbeddingStore,
@@ -141,10 +156,22 @@ async def embed(
     index_name: str
         the vertex index to write to
     """
-    logger.info(f"Embedding {v_id}")
-
-    vec = await embed_svc.aembed_query(content)
-    await embed_store.aadd_embeddings([(content, vec)], [{vertex_field: v_id}])
+    async with embed_sem:
+        logger.info(f"Embedding {v_id}")
+
+        # if loader is running, wait until it's done
+        if not util.loading_event.is_set():
+            logger.info("Embed worker waiting for loading event to finish")
+            await util.loading_event.wait()
+        try:
+            vec = await embed_svc.aembed_query(content)
+        except Exception as e:
+            logger.error(f"Failed to embed {v_id}: {e}")
+            return
+        try:
+            await embed_store.aadd_embeddings([(content, vec)], [{vertex_field: v_id}])
+        except Exception as e:
+            logger.error(f"Failed to add embeddings for {v_id}: {e}")


 async def get_vert_desc(conn, v_id, node: Node):
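The try/except blocks added in embed() matter because these workers run under asyncio.TaskGroup (see graph_rag.py above): one unhandled exception cancels every sibling task and kills the whole stage. A self-contained toy reproduction of the failure mode and of the fix used in the diff (illustrative only, not the service code):

```python
import asyncio

async def flaky(i: int) -> None:
    if i == 2:
        raise ValueError("embedding failed")  # one bad item...
    await asyncio.sleep(0.1)

async def safe(i: int) -> None:
    try:
        await flaky(i)  # catching inside the task keeps the group alive
    except ValueError as e:
        print(f"skipped item {i}: {e}")

async def main() -> None:
    try:
        async with asyncio.TaskGroup() as grp:
            for i in range(5):
                grp.create_task(flaky(i))  # ...cancels the other four tasks
    except* ValueError as eg:
        print(f"group died: {eg.exceptions}")

    async with asyncio.TaskGroup() as grp:  # all five items processed
        for i in range(5):
            grp.create_task(safe(i))

asyncio.run(main())
```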
@@ -158,6 +185,9 @@ async def get_vert_desc(conn, v_id, node: Node):
     return desc


+extract_sem = asyncio.Semaphore(20)
+
+
 async def extract(
     upsert_chan: Channel,
     embed_chan: Channel,
@@ -166,117 +196,133 @@ async def extract(
     chunk: str,
     chunk_id: str,
 ):
-    logger.info(f"Extracting chunk: {chunk_id}")
-    extracted: list[GraphDocument] = await extractor.aextract(chunk)
-    # upsert nodes and edges to the graph
-    for doc in extracted:
-        for node in doc.nodes:
-            logger.info(f"extract writes entity vert to upsert\nNode: {node.id}")
-            v_id = util.process_id(str(node.id))
-            if len(v_id) == 0:
-                continue
-            desc = await get_vert_desc(conn, v_id, node)
-
-            # embed the entity
-            # embed with the v_id if the description is blank
-            if len(desc[0]) == 0:
-                await embed_chan.put((v_id, v_id, "Entity"))
-            else:
-                # (v_id, content, index_name)
-                await embed_chan.put((v_id, desc[0], "Entity"))
-
-            await upsert_chan.put(
-                (
-                    util.upsert_vertex,  # func to call
-                    (
-                        conn,
-                        "Entity",  # v_type
-                        v_id,  # v_id
-                        {  # attrs
-                            "description": desc,
-                            "epoch_added": int(time.time()),
-                        },
-                    ),
-                )
-            )
-
-            # link the entity to the chunk it came from
-            logger.info("extract writes contains edge to upsert")
-            await upsert_chan.put(
-                (
-                    util.upsert_edge,
-                    (
-                        conn,
-                        "DocumentChunk",  # src_type
-                        chunk_id,  # src_id
-                        "CONTAINS_ENTITY",  # edge_type
-                        "Entity",  # tgt_type
-                        v_id,  # tgt_id
-                        None,  # attributes
-                    ),
-                )
-            )
-
-        for edge in doc.relationships:
-            logger.info(
-                f"extract writes relates edge to upsert:{edge.source.id} -({edge.type})-> {edge.target.id}"
-            )
-            # upsert verts first to make sure their ID becomes an attr
-            v_id = util.process_id(edge.source.id)  # src_id
-            if len(v_id) == 0:
-                continue
-            desc = await get_vert_desc(conn, v_id, edge.source)
-            await upsert_chan.put(
-                (
-                    util.upsert_vertex,  # func to call
-                    (
-                        conn,
-                        "Entity",  # v_type
-                        v_id,
-                        {  # attrs
-                            "description": desc,
-                            "epoch_added": int(time.time()),
-                        },
-                    ),
-                )
-            )
-            v_id = util.process_id(edge.target.id)
-            if len(v_id) == 0:
-                continue
-            desc = await get_vert_desc(conn, v_id, edge.target)
-            await upsert_chan.put(
-                (
-                    util.upsert_vertex,  # func to call
-                    (
-                        conn,
-                        "Entity",  # v_type
-                        v_id,  # src_id
-                        {  # attrs
-                            "description": desc,
-                            "epoch_added": int(time.time()),
-                        },
-                    ),
-                )
-            )
-
-            # upsert the edge between the two entities
-            await upsert_chan.put(
-                (
-                    util.upsert_edge,
-                    (
-                        conn,
-                        "Entity",  # src_type
-                        util.process_id(edge.source.id),  # src_id
-                        "RELATIONSHIP",  # edgeType
-                        "Entity",  # tgt_type
-                        util.process_id(edge.target.id),  # tgt_id
-                        {"relation_type": edge.type},  # attributes
-                    ),
-                )
-            )
-            # embed "Relationship",
-            # (v_id, content, index_name)
-            # right now, we're not embedding relationships in graphrag
+    # if loader is running, wait until it's done
+    if not util.loading_event.is_set():
+        logger.info("Extract worker waiting for loading event to finish")
+        await util.loading_event.wait()
+
+    async with extract_sem:
+        try:
+            extracted: list[GraphDocument] = await extractor.aextract(chunk)
+            logger.info(
+                f"Extracting chunk: {chunk_id} ({len(extracted)} graph docs extracted)"
+            )
+        except Exception as e:
+            logger.error(f"Failed to extract chunk {chunk_id}: {e}")
+            extracted = []
+
+        # upsert nodes and edges to the graph
+        for doc in extracted:
+            for node in doc.nodes:
+                logger.info(f"extract writes entity vert to upsert\nNode: {node.id}")
+                v_id = util.process_id(str(node.id))
+                if len(v_id) == 0:
+                    continue
+                desc = await get_vert_desc(conn, v_id, node)
+
+                # embed the entity
+                # embed with the v_id if the description is blank
+                if len(desc[0]) == 0:
+                    await embed_chan.put((v_id, v_id, "Entity"))
+                else:
+                    # (v_id, content, index_name)
+                    await embed_chan.put((v_id, desc[0], "Entity"))
+
+                await upsert_chan.put(
+                    (
+                        util.upsert_vertex,  # func to call
+                        (
+                            conn,
+                            "Entity",  # v_type
+                            v_id,  # v_id
+                            {  # attrs
+                                "description": desc,
+                                "epoch_added": int(time.time()),
+                            },
+                        ),
+                    )
+                )
+
+                # link the entity to the chunk it came from
+                logger.info("extract writes contains edge to upsert")
+                await upsert_chan.put(
+                    (
+                        util.upsert_edge,
+                        (
+                            conn,
+                            "DocumentChunk",  # src_type
+                            chunk_id,  # src_id
+                            "CONTAINS_ENTITY",  # edge_type
+                            "Entity",  # tgt_type
+                            v_id,  # tgt_id
+                            None,  # attributes
+                        ),
+                    )
+                )
+
+            for edge in doc.relationships:
+                logger.info(
+                    f"extract writes relates edge to upsert:{edge.source.id} -({edge.type})-> {edge.target.id}"
+                )
+                # upsert verts first to make sure their ID becomes an attr
+                v_id = util.process_id(edge.source.id)  # src_id
+                if len(v_id) == 0:
+                    continue
+                desc = await get_vert_desc(conn, v_id, edge.source)
+                await upsert_chan.put(
+                    (
+                        util.upsert_vertex,  # func to call
+                        (
+                            conn,
+                            "Entity",  # v_type
+                            v_id,
+                            {  # attrs
+                                "description": desc,
+                                "epoch_added": int(time.time()),
+                            },
+                        ),
+                    )
+                )
+                v_id = util.process_id(edge.target.id)
+                if len(v_id) == 0:
+                    continue
+                desc = await get_vert_desc(conn, v_id, edge.target)
+                await upsert_chan.put(
+                    (
+                        util.upsert_vertex,  # func to call
+                        (
+                            conn,
+                            "Entity",  # v_type
+                            v_id,  # src_id
+                            {  # attrs
+                                "description": desc,
+                                "epoch_added": int(time.time()),
+                            },
+                        ),
+                    )
+                )
+
+                # upsert the edge between the two entities
+                await upsert_chan.put(
+                    (
+                        util.upsert_edge,
+                        (
+                            conn,
+                            "Entity",  # src_type
+                            util.process_id(edge.source.id),  # src_id
+                            "RELATIONSHIP",  # edgeType
+                            "Entity",  # tgt_type
+                            util.process_id(edge.target.id),  # tgt_id
+                            {"relation_type": edge.type},  # attributes
+                        ),
+                    )
+                )
+                # embed "Relationship",
+                # (v_id, content, index_name)
+                # right now, we're not embedding relationships in graphrag
+
+
+resolve_sem = asyncio.Semaphore(20)

 async def resolve_entity(
@@ -295,58 +341,68 @@ async def resolve_entity(
         mark as processed
     """
-    try:
-        results = await emb_store.aget_k_closest(entity_id)
-    except Exception:
-        err = traceback.format_exc()
-        logger.error(err)
-        return
-
-    if len(results) == 0:
-        logger.error(
-            f"aget_k_closest should, minimally, return the entity itself.\n{results}"
-        )
-        raise Exception()
-
-    # merge all entities into the ResolvedEntity vertex
-    # use the longest v_id as the resolved entity's v_id
-    resolved_entity_id = entity_id
-    for v in results:
-        if len(v) > len(resolved_entity_id):
-            resolved_entity_id = v
-
-    # upsert the resolved entity
-    await upsert_chan.put(
-        (
-            util.upsert_vertex,  # func to call
-            (
-                conn,
-                "ResolvedEntity",  # v_type
-                resolved_entity_id,  # v_id
-                {  # attrs
-                },
-            ),
-        )
-    )
-
-    # create RESOLVES_TO edges from each entity to the ResolvedEntity
-    for v in results:
-        await upsert_chan.put(
-            (
-                util.upsert_edge,
-                (
-                    conn,
-                    "Entity",  # src_type
-                    v,  # src_id
-                    "RESOLVES_TO",  # edge_type
-                    "ResolvedEntity",  # tgt_type
-                    resolved_entity_id,  # tgt_id
-                    None,  # attributes
-                ),
-            )
-        )
+    # if loader is running, wait until it's done
+    if not util.loading_event.is_set():
+        logger.info("Entity Resolution worker waiting for loading event to finish")
+        await util.loading_event.wait()
+
+    async with resolve_sem:
+        try:
+            results = await emb_store.aget_k_closest(entity_id)
+
+        except Exception:
+            err = traceback.format_exc()
+            logger.error(err)
+            return
+
+        if len(results) == 0:
+            logger.error(
+                f"aget_k_closest should, minimally, return the entity itself.\n{results}"
+            )
+            raise Exception()
+
+        # merge all entities into the ResolvedEntity vertex
+        # use the longest v_id as the resolved entity's v_id
+        resolved_entity_id = entity_id
+        for v in results:
+            if len(v) > len(resolved_entity_id):
+                resolved_entity_id = v
+
+        # upsert the resolved entity
+        await upsert_chan.put(
+            (
+                util.upsert_vertex,  # func to call
+                (
+                    conn,
+                    "ResolvedEntity",  # v_type
+                    resolved_entity_id,  # v_id
+                    {  # attrs
+                    },
+                ),
+            )
+        )
+
+        # create RESOLVES_TO edges from each entity to the ResolvedEntity
+        for v in results:
+            await upsert_chan.put(
+                (
+                    util.upsert_edge,
+                    (
+                        conn,
+                        "Entity",  # src_type
+                        v,  # src_id
+                        "RESOLVES_TO",  # edge_type
+                        "ResolvedEntity",  # tgt_type
+                        resolved_entity_id,  # tgt_id
+                        None,  # attributes
+                    ),
+                )
+            )
+
+
+comm_sem = asyncio.Semaphore(20)
+

 async def process_community(
     conn: TigerGraphConnection,
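For clarity, the merge rule in the resolve_entity hunk above ("use the longest v_id") is equivalent to a single max() call. A tiny illustration with made-up ids, shown only to make the rule explicit, not as a suggested change to the diff:

```python
# results holds entity ids whose embeddings sit close to entity_id;
# the longest id wins and becomes the ResolvedEntity's v_id.
entity_id = "NASA"
results = ["NASA", "N_A_S_A", "National_Aeronautics_and_Space_Administration"]
resolved_entity_id = max([entity_id, *results], key=len)
print(resolved_entity_id)  # National_Aeronautics_and_Space_Administration
```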
@@ -363,35 +419,40 @@ async def process_community(
         embed summaries
     """
-
-    logger.info(f"Processing Community: {comm_id}")
-    # get the children of the community
-    children = await util.get_commuinty_children(conn, i, comm_id)
-    comm_id = util.process_id(comm_id)
-
-    # if the community only has one child, use its description
-    if len(children) == 1:
-        summary = children[0]
-    else:
-        llm = ecc_util.get_llm_service()
-        summarizer = community_summarizer.CommunitySummarizer(llm)
-        summary = await summarizer.summarize(comm_id, children)
-
-    logger.debug(f"Community {comm_id}: {children}, {summary}")
-    await upsert_chan.put(
-        (
-            util.upsert_vertex,  # func to call
-            (
-                conn,
-                "Community",  # v_type
-                comm_id,  # v_id
-                {  # attrs
-                    "description": summary,
-                    "iteration": i,
-                },
-            ),
-        )
-    )
-
-    # (v_id, content, index_name)
-    await embed_chan.put((comm_id, summary, "Community"))
+    # if loader is running, wait until it's done
+    if not util.loading_event.is_set():
+        logger.info("Process Community worker waiting for loading event to finish")
+        await util.loading_event.wait()
+
+    async with comm_sem:
+        logger.info(f"Processing Community: {comm_id}")
+        # get the children of the community
+        children = await util.get_commuinty_children(conn, i, comm_id)
+        comm_id = util.process_id(comm_id)
+
+        # if the community only has one child, use its description
+        if len(children) == 1:
+            summary = children[0]
+        else:
+            llm = ecc_util.get_llm_service()
+            summarizer = community_summarizer.CommunitySummarizer(llm)
+            summary = await summarizer.summarize(comm_id, children)
+
+        logger.debug(f"Community {comm_id}: {children}, {summary}")
+        await upsert_chan.put(
+            (
+                util.upsert_vertex,  # func to call
+                (
+                    conn,
+                    "Community",  # v_type
+                    comm_id,  # v_id
+                    {  # attrs
+                        "description": summary,
+                        "iteration": i,
+                    },
+                ),
+            )
+        )
+
+        # (v_id, content, index_name)
+        await embed_chan.put((comm_id, summary, "Community"))