diff --git a/common/config.py b/common/config.py index 8812016c..9f1d3cab 100644 --- a/common/config.py +++ b/common/config.py @@ -92,8 +92,6 @@ "MILVUS_CONFIG must be a .json file or a JSON string, failed with error: " + str(e) ) - - if llm_config["embedding_service"]["embedding_model_service"].lower() == "openai": embedding_service = OpenAI_Embedding(llm_config["embedding_service"]) elif llm_config["embedding_service"]["embedding_model_service"].lower() == "azure": @@ -128,11 +126,9 @@ def get_llm_service(llm_config) -> LLM_Model: else: raise Exception("LLM Completion Service Not Supported") - LogWriter.info( f"Milvus enabled for host {milvus_config['host']} at port {milvus_config['port']}" ) - if os.getenv("INIT_EMBED_STORE", "true")=="true": LogWriter.info("Setting up Milvus embedding store for InquiryAI") try: diff --git a/common/gsql/graphRAG/communities_have_desc.gsql b/common/gsql/graphRAG/communities_have_desc.gsql index f5cda70e..75abcef5 100644 --- a/common/gsql/graphRAG/communities_have_desc.gsql +++ b/common/gsql/graphRAG/communities_have_desc.gsql @@ -1,4 +1,4 @@ -CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{ +CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter, BOOL p=False) SYNTAX V2{ SumAccum @@descrs; Comms = {Community.*}; Comms = SELECT c FROM Comms:c @@ -11,4 +11,10 @@ CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{ PRINT (@@descrs == Comms.size()) as all_have_desc; PRINT @@descrs, Comms.size(); + + IF p THEN + Comms = SELECT c FROM Comms:c + WHERE c.iteration == iter and length(c.description) == 0; + PRINT Comms; + END; } diff --git a/common/gsql/graphRAG/entities_have_resolution.gsql b/common/gsql/graphRAG/entities_have_resolution.gsql new file mode 100644 index 00000000..ebc8442f --- /dev/null +++ b/common/gsql/graphRAG/entities_have_resolution.gsql @@ -0,0 +1,10 @@ +CREATE DISTRIBUTED QUERY entities_have_resolution() SYNTAX V2{ + SumAccum @@resolved; + Ents = {Entity.*}; + Ents = SELECT s FROM Ents:s -(RESOLVES_TO>)- ResolvedEntity + ACCUM @@resolved += 1; + + + PRINT (@@resolved >= Ents.size()) as all_resolved; + PRINT @@resolved, Ents.size(); +} diff --git a/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql b/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql index 42e9108d..a22c3565 100644 --- a/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql +++ b/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql @@ -18,7 +18,7 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches MaxAccum @best_move; // best move of the node with the highest delta Q to move the isolated node into the new community MaxAccum @@min_double; // used to reset the @best_move SumAccum @@move_cnt; - OrAccum @to_change_community; + OrAccum @to_change_community, @has_relations; SumAccum @batch_id; MinAccum @vid; @@ -152,6 +152,8 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches @@community_sum_total_map.clear(); Tmp = SELECT s FROM AllNodes:s -(_:e)-> ResolvedEntity:t ACCUM + s.@has_relations += TRUE, + t.@has_relations += TRUE, IF s.@community_id == t.@community_id THEN // keep track of how many edges are within the community @@community_sum_in_map += (s.@community_id -> wt) @@ -165,7 +167,14 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches INSERT INTO IN_COMMUNITY VALUES (s, s.@community_vid+"_1") // link entity to it's first community ; - PRINT @@source_target_k_in_map; + // Continue community hierarchy for unattached partitions + Tmp = SELECT s FROM AllNodes:s + WHERE NOT s.@has_relations + POST-ACCUM + // if s is a part of an unattached partition, add to its community hierarchy to maintain parity with rest of graph + INSERT INTO Community VALUES (s.id+"_1", 1, ""), + INSERT INTO IN_COMMUNITY VALUES (s, s.id+"_1"); // link entity to it's first community + @@community_sum_total_map.clear(); // link communities diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py index f9a58d44..e4457a77 100644 --- a/eventual-consistency-service/app/graphrag/graph_rag.py +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -1,18 +1,23 @@ import asyncio +import json import logging import time import traceback +from collections import defaultdict import httpx from aiochannel import Channel from graphrag import workers from graphrag.util import ( + check_all_ents_resolved, check_vertex_has_desc, http_timeout, init, + load_q, make_headers, stream_ids, tg_sem, + upsert_batch, ) from pyTigerGraph import TigerGraphConnection @@ -83,12 +88,6 @@ async def chunk_docs( doc_tasks = [] async with asyncio.TaskGroup() as grp: async for content in docs_chan: - v_id = content["v_id"] - txt = content["attributes"]["text"] - # send the document to be embedded - logger.info("chunk writes to extract") - # await embed_chan.put((v_id, txt, "Document")) - task = grp.create_task( workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan) ) @@ -117,7 +116,66 @@ async def upsert(upsert_chan: Channel): # execute the task grp.create_task(func(*args)) - logger.info(f"upsert done") + logger.info("upsert done") + logger.info("closing load_q chan") + load_q.close() + + +async def load(conn: TigerGraphConnection): + logger.info("Reading from load_q") + dd = lambda: defaultdict(dd) # infinite default dict + batch_size = 1000 + # while the load q is still open or has contents + while not load_q.closed() or not load_q.empty(): + if load_q.closed(): + logger.info( + f"load queue closed. Flushing load queue (final load for this stage)" + ) + # if there's `batch_size` entities in the channel, load it + # or if the channel is closed, flush it + if load_q.qsize() >= batch_size or load_q.closed() or load_q.should_flush(): + batch = { + "vertices": defaultdict(dict[str, any]), + "edges": dd(), + } + n_verts = 0 + n_edges = 0 + size = ( + load_q.qsize() + if load_q.closed() or load_q.should_flush() + else batch_size + ) + for _ in range(size): + t, elem = await load_q.get() + if t == "FLUSH": + logger.debug(f"flush recieved: {t}") + load_q._should_flush = False + break + match t: + case "vertices": + vt, v_id, attr = elem + batch[t][vt][v_id] = attr + n_verts += 1 + case "edges": + src_v_type, src_v_id, edge_type, tgt_v_type, tgt_v_id, attrs = ( + elem + ) + batch[t][src_v_type][src_v_id][edge_type][tgt_v_type][ + tgt_v_id + ] = attrs + n_edges += 1 + + data = json.dumps(batch) + logger.info( + f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)" + ) + await upsert_batch(conn, data) + else: + await asyncio.sleep(1) + + # TODO: flush q if it's not empty + if not load_q.empty(): + raise Exception(f"load_q not empty: {load_q.qsize()}", flush=True) async def embed( @@ -287,7 +345,9 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel): ) res.raise_for_status() mod = res.json()["results"][0]["mod"] - logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") + logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") + if mod == 0 or mod - prev_mod <= -0.05: + break # write iter to chan for layer to be processed await stream_communities(conn, i + 1, comm_process_chan) @@ -308,8 +368,6 @@ async def stream_communities( logger.info("streaming communities") headers = make_headers(conn) - # TODO: - # can only do one layer at a time to ensure that every child community has their descriptions # async for i in community_chan: # get the community from that layer @@ -329,14 +387,16 @@ async def stream_communities( # Wait for all communities for layer i to be processed before doing next layer # all community descriptions must be populated before the next layer can be processed if len(comms) > 0: + n_waits = 0 while not await check_vertex_has_desc(conn, i): logger.info(f"Waiting for layer{i} to finish processing") await asyncio.sleep(5) + n_waits += 1 + if n_waits > 3: + logger.info("Flushing load_q") + await load_q.flush(("FLUSH", None)) await asyncio.sleep(3) - logger.info("stream_communities done") - logger.info("closing comm_process_chan") - async def summarize_communities( conn: TigerGraphConnection, @@ -353,7 +413,7 @@ async def summarize_communities( embed_chan.close() -async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): +async def run(graphname: str, conn: TigerGraphConnection): """ Set up GraphRAG: - Install necessary queries. @@ -370,8 +430,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): init_start = time.perf_counter() doc_process_switch = True - entity_resolution_switch =True - community_detection_switch =True + entity_resolution_switch = True + community_detection_switch = True if doc_process_switch: logger.info("Doc Processing Start") docs_chan = Channel(1) @@ -386,7 +446,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan) ) # upsert chunks - grp.create_task(upsert( upsert_chan)) + grp.create_task(upsert(upsert_chan)) + grp.create_task(load(conn)) # embed grp.create_task(embed(embed_chan, index_stores, graphname)) # extract entities @@ -403,6 +464,7 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): logger.info("Entity Processing Start") entities_chan = Channel(100) upsert_chan = Channel(100) + load_q.reopen() async with asyncio.TaskGroup() as grp: grp.create_task(stream_entities(conn, entities_chan, 50)) grp.create_task( @@ -414,8 +476,12 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): ) ) grp.create_task(upsert(upsert_chan)) + grp.create_task(load(conn)) entity_end = time.perf_counter() logger.info("Entity Processing End") + while not await check_all_ents_resolved(conn): + logger.info(f"Waiting for resolved entites to finish loading") + await asyncio.sleep(1) # Community Detection community_start = time.perf_counter() @@ -425,17 +491,17 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): comm_process_chan = Channel(100) upsert_chan = Channel(100) embed_chan = Channel(100) + load_q.reopen() async with asyncio.TaskGroup() as grp: # run louvain - # grp.create_task(communities(conn, communities_chan)) - grp.create_task(communities(conn, comm_process_chan)) # get the communities - # grp.create_task( stream_communities(conn, communities_chan, comm_process_chan)) + grp.create_task(communities(conn, comm_process_chan)) # summarize each community grp.create_task( summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan) ) grp.create_task(upsert(upsert_chan)) + grp.create_task(load(conn)) grp.create_task(embed(embed_chan, index_stores, graphname)) community_end = time.perf_counter() diff --git a/eventual-consistency-service/app/graphrag/reusable_channel.py b/eventual-consistency-service/app/graphrag/reusable_channel.py new file mode 100644 index 00000000..54ec62c9 --- /dev/null +++ b/eventual-consistency-service/app/graphrag/reusable_channel.py @@ -0,0 +1,37 @@ +from asyncio import Queue + + +class ReuseableChannel: + def __init__(self, maxsize=0) -> None: + self.maxsize = maxsize + self.q = Queue(maxsize) + self._closed = False + self._should_flush = False + + async def put(self, item: any) -> None: + await self.q.put(item) + + async def get(self) -> any: + return await self.q.get() + + def closed(self): + return self._closed + + def should_flush(self): + return self._should_flush + + async def flush(self, flush_msg=None): + self._should_flush = True + await self.put(flush_msg) + + def empty(self): + return self.q.empty() + + def close(self): + self._closed = True + + def qsize(self) -> int: + return self.q.qsize() + + def reopen(self): + self._closed = False diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py index 84785e7a..5f288fa2 100644 --- a/eventual-consistency-service/app/graphrag/util.py +++ b/eventual-consistency-service/app/graphrag/util.py @@ -1,16 +1,11 @@ import asyncio import base64 -import json import logging import re import traceback from glob import glob -from typing import Callable import httpx -from graphrag import workers -from pyTigerGraph import TigerGraphConnection - from common.config import ( doc_processing_config, embedding_service, @@ -22,11 +17,14 @@ from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor from common.extractors.BaseExtractor import BaseExtractor from common.logs.logwriter import LogWriter +from graphrag import reusable_channel, workers +from pyTigerGraph import TigerGraphConnection logger = logging.getLogger(__name__) -http_timeout = httpx.Timeout(15.0) +http_timeout = httpx.Timeout(15.0) -tg_sem = asyncio.Semaphore(100) +tg_sem = asyncio.Semaphore(20) +load_q = reusable_channel.ReuseableChannel() tg_sem = asyncio.Semaphore(100) @@ -61,15 +59,12 @@ async def init( ) -> tuple[BaseExtractor, dict[str, MilvusEmbeddingStore]]: # install requried queries requried_queries = [ - # "common/gsql/supportai/Scan_For_Updates", - # "common/gsql/supportai/Update_Vertices_Processing_Status", - # "common/gsql/supportai/ECC_Status", - # "common/gsql/supportai/Check_Nonexistent_Vertices", "common/gsql/graphRAG/StreamIds", "common/gsql/graphRAG/StreamDocContent", "common/gsql/graphRAG/SetEpochProcessing", "common/gsql/graphRAG/ResolveRelationships", "common/gsql/graphRAG/get_community_children", + "common/gsql/graphRAG/entities_have_resolution", "common/gsql/graphRAG/communities_have_desc", "common/gsql/graphRAG/louvain/graphrag_louvain_init", "common/gsql/graphRAG/louvain/graphrag_louvain_communities", @@ -96,7 +91,6 @@ async def init( "DocumentChunk", "Entity", "Relationship", - # "Concept", "Community", ], ) @@ -188,6 +182,7 @@ def process_id(v_id: str): v_id = has_func[0] if v_id == "''" or v_id == '""': return "" + v_id = v_id.replace("(", "").replace(")", "") return v_id @@ -201,14 +196,16 @@ async def upsert_vertex( logger.info(f"Upsert vertex: {vertex_type} {vertex_id}") vertex_id = vertex_id.replace(" ", "_") attrs = map_attrs(attributes) - data = json.dumps({"vertices": {vertex_type: {vertex_id: attrs}}}) + await load_q.put(("vertices", (vertex_type, vertex_id, attrs))) + + +async def upsert_batch(conn: TigerGraphConnection, data: str): headers = make_headers(conn) async with httpx.AsyncClient(timeout=http_timeout) as client: async with tg_sem: res = await client.post( f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers ) - res.raise_for_status() @@ -217,13 +214,22 @@ async def check_vertex_exists(conn, v_id: str): headers = make_headers(conn) async with httpx.AsyncClient(timeout=http_timeout) as client: async with tg_sem: - res = await client.get( - f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}", - headers=headers, - ) + try: + res = await client.get( + f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}", + headers=headers, + ) - res.raise_for_status() - return res.json() + except Exception as e: + logger.error(f"Check err:\n{e}") + return {"error": True} + + try: + res.raise_for_status() + return res.json() + except Exception as e: + logger.error(f"Check err:\n{e}\n{res.text}") + return {"error": True} async def upsert_edge( @@ -241,28 +247,19 @@ async def upsert_edge( attrs = map_attrs(attributes) src_v_id = src_v_id.replace(" ", "_") tgt_v_id = tgt_v_id.replace(" ", "_") - data = json.dumps( - { - "edges": { - src_v_type: { - src_v_id: { - edge_type: { - tgt_v_type: { - tgt_v_id: attrs, - } - } - }, - } - } - } + await load_q.put( + ( + "edges", + ( + src_v_type, + src_v_id, + edge_type, + tgt_v_type, + tgt_v_id, + attrs, + ), + ) ) - headers = make_headers(conn) - async with httpx.AsyncClient(timeout=http_timeout) as client: - async with tg_sem: - res = await client.post( - f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers - ) - res.raise_for_status() async def get_commuinty_children(conn, i: int, c: str): @@ -274,20 +271,46 @@ async def get_commuinty_children(conn, i: int, c: str): params={"comm": c, "iter": i}, headers=headers, ) - resp.raise_for_status() + try: + resp.raise_for_status() + except Exception as e: + logger.error(f"Get Children err:\n{e}") descrs = [] for d in resp.json()["results"][0]["children"]: desc = d["attributes"]["description"] - if i == 1 and all(len(x) == 0 for x in desc): - desc = [d["v_id"]] - elif len(desc) == 0: - desc = d["v_id"] - - descrs.append(desc) + # if it's the entity iteration + if i == 1: + # filter out empty strings + desc = list(filter(lambda x: len(x) > 0, desc)) + # if there are no descriptions, make it the v_id + if len(desc) == 0: + desc.append(d["v_id"]) + descrs.extend(desc) + else: + descrs.append(desc) return descrs +async def check_all_ents_resolved(conn): + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=None) as client: + async with tg_sem: + resp = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/entities_have_resolution", + headers=headers, + ) + try: + resp.raise_for_status() + except Exception as e: + logger.error(f"Check Vert Desc err:\n{e}") + + res = resp.json()["results"][0]["all_resolved"] + logger.info(resp.json()["results"]) + + return res + + async def check_vertex_has_desc(conn, i: int): headers = make_headers(conn) async with httpx.AsyncClient(timeout=None) as client: @@ -297,8 +320,12 @@ async def check_vertex_has_desc(conn, i: int): params={"iter": i}, headers=headers, ) - resp.raise_for_status() + try: + resp.raise_for_status() + except Exception as e: + logger.error(f"Check Vert Desc err:\n{e}") res = resp.json()["results"][0]["all_have_desc"] + logger.info(resp.json()["results"]) return res diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index 37786aee..98b3e69c 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -1,20 +1,20 @@ import base64 import logging import time +import traceback from urllib.parse import quote_plus import ecc_util import httpx from aiochannel import Channel -from graphrag import community_summarizer, util -from langchain_community.graphs.graph_document import GraphDocument, Node -from pyTigerGraph import TigerGraphConnection - from common.config import milvus_config from common.embeddings.embedding_services import EmbeddingModel from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore from common.extractors.BaseExtractor import BaseExtractor from common.logs.logwriter import LogWriter +from graphrag import community_summarizer, util +from langchain_community.graphs.graph_document import GraphDocument, Node +from pyTigerGraph import TigerGraphConnection vertex_field = milvus_config.get("vertex_field", "vertex_id") @@ -179,7 +179,7 @@ async def extract( # embed the entity # embed with the v_id if the description is blank - if len(desc[0]): + if len(desc[0]) == 0: await embed_chan.put((v_id, v_id, "Entity")) else: # (v_id, content, index_name) @@ -219,7 +219,7 @@ async def extract( for edge in doc.relationships: logger.info( - f"extract writes relates edge to upsert\n{edge.source.id} -({edge.type})-> {edge.target.id}" + f"extract writes relates edge to upsert:{edge.source.id} -({edge.type})-> {edge.target.id}" ) # upsert verts first to make sure their ID becomes an attr v_id = util.process_id(edge.source.id) # src_id @@ -276,6 +276,7 @@ async def extract( ) # embed "Relationship", # (v_id, content, index_name) + # right now, we're not embedding relationships in graphrag async def resolve_entity( @@ -294,7 +295,14 @@ async def resolve_entity( mark as processed """ - results = await emb_store.aget_k_closest(entity_id) + try: + results = await emb_store.aget_k_closest(entity_id) + + except Exception: + err = traceback.format_exc() + logger.error(err) + return + if len(results) == 0: logger.error( f"aget_k_closest should, minimally, return the entity itself.\n{results}" @@ -359,11 +367,6 @@ async def process_community( logger.info(f"Processing Community: {comm_id}") # get the children of the community children = await util.get_commuinty_children(conn, i, comm_id) - if i == 1: - tmp = [] - for c in children: - tmp.extend(c) - children = list(filter(lambda x: len(x) > 0, tmp)) comm_id = util.process_id(comm_id) # if the community only has one child, use its description @@ -374,6 +377,7 @@ async def process_community( summarizer = community_summarizer.CommunitySummarizer(llm) summary = await summarizer.summarize(comm_id, children) + logger.debug(f"Community {comm_id}: {children}, {summary}") await upsert_chan.put( ( util.upsert_vertex, # func to call