From e80d882342f6c1537b20406490866a957d39186c Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Mon, 19 Aug 2024 18:03:06 -0400 Subject: [PATCH 1/7] batched loader --- common/config.py | 4 - .../gsql/graphRAG/communities_have_desc.gsql | 8 +- .../graphRAG/entities_have_resolution.gsql | 10 ++ common/gsql/graphRAG/loaders/tmp.gsql | 26 +++ .../louvain/graphrag_louvain_init.gsql | 13 +- common/requirements.txt | 5 +- copilot/requirements.txt | 4 +- .../app/graphrag/graph_rag.py | 99 ++++++++++-- .../app/graphrag/reusable_channel.py | 37 +++++ .../app/graphrag/util.py | 148 +++++++++++++----- .../app/graphrag/workers.py | 12 +- 11 files changed, 299 insertions(+), 67 deletions(-) create mode 100644 common/gsql/graphRAG/entities_have_resolution.gsql create mode 100644 common/gsql/graphRAG/loaders/tmp.gsql create mode 100644 eventual-consistency-service/app/graphrag/reusable_channel.py diff --git a/common/config.py b/common/config.py index 8812016c..9f1d3cab 100644 --- a/common/config.py +++ b/common/config.py @@ -92,8 +92,6 @@ "MILVUS_CONFIG must be a .json file or a JSON string, failed with error: " + str(e) ) - - if llm_config["embedding_service"]["embedding_model_service"].lower() == "openai": embedding_service = OpenAI_Embedding(llm_config["embedding_service"]) elif llm_config["embedding_service"]["embedding_model_service"].lower() == "azure": @@ -128,11 +126,9 @@ def get_llm_service(llm_config) -> LLM_Model: else: raise Exception("LLM Completion Service Not Supported") - LogWriter.info( f"Milvus enabled for host {milvus_config['host']} at port {milvus_config['port']}" ) - if os.getenv("INIT_EMBED_STORE", "true")=="true": LogWriter.info("Setting up Milvus embedding store for InquiryAI") try: diff --git a/common/gsql/graphRAG/communities_have_desc.gsql b/common/gsql/graphRAG/communities_have_desc.gsql index f5cda70e..75abcef5 100644 --- a/common/gsql/graphRAG/communities_have_desc.gsql +++ b/common/gsql/graphRAG/communities_have_desc.gsql @@ -1,4 +1,4 @@ -CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{ +CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter, BOOL p=False) SYNTAX V2{ SumAccum @@descrs; Comms = {Community.*}; Comms = SELECT c FROM Comms:c @@ -11,4 +11,10 @@ CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{ PRINT (@@descrs == Comms.size()) as all_have_desc; PRINT @@descrs, Comms.size(); + + IF p THEN + Comms = SELECT c FROM Comms:c + WHERE c.iteration == iter and length(c.description) == 0; + PRINT Comms; + END; } diff --git a/common/gsql/graphRAG/entities_have_resolution.gsql b/common/gsql/graphRAG/entities_have_resolution.gsql new file mode 100644 index 00000000..ebc8442f --- /dev/null +++ b/common/gsql/graphRAG/entities_have_resolution.gsql @@ -0,0 +1,10 @@ +CREATE DISTRIBUTED QUERY entities_have_resolution() SYNTAX V2{ + SumAccum @@resolved; + Ents = {Entity.*}; + Ents = SELECT s FROM Ents:s -(RESOLVES_TO>)- ResolvedEntity + ACCUM @@resolved += 1; + + + PRINT (@@resolved >= Ents.size()) as all_resolved; + PRINT @@resolved, Ents.size(); +} diff --git a/common/gsql/graphRAG/loaders/tmp.gsql b/common/gsql/graphRAG/loaders/tmp.gsql new file mode 100644 index 00000000..e8d8d417 --- /dev/null +++ b/common/gsql/graphRAG/loaders/tmp.gsql @@ -0,0 +1,26 @@ +CREATE LOADING load_entity@uuid@ { + DEFINE FILENAME Content; + LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; +} + + +CREATE LOADING load_ResolvedEntity@uuid@ { + 
DEFINE FILENAME Content; + LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; +} +CREATE LOADING load_ asdfasdf @uuid@ { + DEFINE FILENAME Content; + LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; +} +CREATE LOADING load_ asdfasdf @uuid@ { + DEFINE FILENAME Content; + LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; +} +CREATE LOADING load_ asdfasdf @uuid@ { + DEFINE FILENAME Content; + LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; +} +CREATE LOADING load_ asdfasdf @uuid@ { + DEFINE FILENAME Content; + LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; +} diff --git a/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql b/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql index 42e9108d..a22c3565 100644 --- a/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql +++ b/common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql @@ -18,7 +18,7 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches MaxAccum @best_move; // best move of the node with the highest delta Q to move the isolated node into the new community MaxAccum @@min_double; // used to reset the @best_move SumAccum @@move_cnt; - OrAccum @to_change_community; + OrAccum @to_change_community, @has_relations; SumAccum @batch_id; MinAccum @vid; @@ -152,6 +152,8 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches @@community_sum_total_map.clear(); Tmp = SELECT s FROM AllNodes:s -(_:e)-> ResolvedEntity:t ACCUM + s.@has_relations += TRUE, + t.@has_relations += TRUE, IF s.@community_id == t.@community_id THEN // keep track of how many edges are within the community @@community_sum_in_map += (s.@community_id -> wt) @@ -165,7 +167,14 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches INSERT INTO IN_COMMUNITY VALUES (s, s.@community_vid+"_1") // link entity to it's first community ; - PRINT @@source_target_k_in_map; + // Continue community hierarchy for unattached partitions + Tmp = SELECT s FROM AllNodes:s + WHERE NOT s.@has_relations + POST-ACCUM + // if s is a part of an unattached partition, add to its community hierarchy to maintain parity with rest of graph + INSERT INTO Community VALUES (s.id+"_1", 1, ""), + INSERT INTO IN_COMMUNITY VALUES (s, s.id+"_1"); // link entity to it's first community + @@community_sum_total_map.clear(); // link communities diff --git a/common/requirements.txt b/common/requirements.txt index af45c357..9912b4a8 100644 --- a/common/requirements.txt +++ b/common/requirements.txt @@ -1,3 +1,4 @@ +aiochannel==1.2.1 aiohappyeyeballs==2.3.5 aiohttp==3.10.3 aiosignal==1.3.1 @@ -83,7 +84,7 @@ langchain-experimental==0.0.64 langchain-groq==0.1.9 langchain-ibm==0.1.12 langchain-milvus==0.1.4 -langchain-openai==0.1.21 +langchain-openai==0.1.22 langchain-text-splitters==0.2.2 langchainhub==0.1.21 langdetect==1.0.9 @@ -100,7 +101,7 @@ minio==7.2.7 multidict==6.0.5 mypy-extensions==1.0.0 nest-asyncio==1.6.0 -nltk==3.8.2 +nltk==3.8.1 numpy==1.26.4 openai==1.40.6 ordered-set==4.1.0 diff --git a/copilot/requirements.txt b/copilot/requirements.txt index 4a5ac3d1..cd9bb7bc 100644 --- a/copilot/requirements.txt +++ b/copilot/requirements.txt @@ -84,7 +84,7 @@ langchain-experimental==0.0.64 langchain-groq==0.1.9 
langchain-ibm==0.1.12 langchain-milvus==0.1.4 -langchain-openai==0.1.21 +langchain-openai==0.1.22 langchain-text-splitters==0.2.2 langchainhub==0.1.21 langdetect==1.0.9 @@ -101,7 +101,7 @@ minio==7.2.7 multidict==6.0.5 mypy-extensions==1.0.0 nest-asyncio==1.6.0 -nltk==3.8.2 +nltk==3.8.1 numpy==1.26.4 openai==1.40.6 ordered-set==4.1.0 diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py index 29f03dce..54e47f26 100644 --- a/eventual-consistency-service/app/graphrag/graph_rag.py +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -2,17 +2,21 @@ import logging import time import traceback +from collections import defaultdict import httpx from aiochannel import Channel from graphrag import workers from graphrag.util import ( + check_all_ents_resolved, check_vertex_has_desc, http_timeout, init, + load_q, make_headers, stream_ids, tg_sem, + upsert_batch, ) from pyTigerGraph import TigerGraphConnection @@ -83,8 +87,8 @@ async def chunk_docs( doc_tasks = [] async with asyncio.TaskGroup() as grp: async for content in docs_chan: - v_id = content["v_id"] - txt = content["attributes"]["text"] + # v_id = content["v_id"] + # txt = content["attributes"]["text"] # send the document to be embedded logger.info("chunk writes to extract") # await embed_chan.put((v_id, txt, "Document")) @@ -117,7 +121,65 @@ async def upsert(upsert_chan: Channel): # execute the task grp.create_task(func(*args)) - logger.info(f"upsert done") + logger.info("upsert done") + logger.info("closing load_q chan") + load_q.close() + + +async def load(conn: TigerGraphConnection): + logger.info("Reading from load_q") + dd = lambda: defaultdict(dd) # infinite default dict + batch_size = 250 + # while the load q is still open or has contents + while not load_q.closed() or not load_q.empty(): + if load_q.closed(): + logger.info( + f"load queue closed. Flushing load queue (final load for this stage)" + ) + # if there's `batch_size` entities in the channel, load it + # or if the channel is closed, flush it + if load_q.qsize() >= batch_size or load_q.closed() or load_q.should_flush(): + batch = { + "vertices": defaultdict(dict[str, any]), + "edges": dd(), + } + n_verts = 0 + n_edges = 0 + size = ( + load_q.qsize() + if load_q.closed() or load_q.should_flush() + else batch_size + ) + for _ in range(size): + t, elem = await load_q.get() + if t == "FLUSH": + logger.debug(f"flush recieved: {t}") + load_q._should_flush = False + break + match t: + case "vertices": + vt, v_id, attr = elem + batch[t][vt][v_id] = attr + n_verts += 1 + case "edges": + src_v_type, src_v_id, edge_type, tgt_v_type, tgt_v_id, attrs = ( + elem + ) + batch[t][src_v_type][src_v_id][edge_type][tgt_v_type][ + tgt_v_id + ] = attrs + n_edges += 1 + + logger.info( + f"Upserting batch size of {size}. 
({n_verts} verts | {n_edges} edges)" + ) + await upsert_batch(conn, batch) + else: + await asyncio.sleep(1) + + # TODO: flush q if it's not empty + if not load_q.empty(): + raise Exception(f"load_q not empty: {load_q.qsize()}", flush=True) async def embed( @@ -132,6 +194,7 @@ async def embed( async with asyncio.TaskGroup() as grp: # consume task queue async for v_id, content, index_name in embed_chan: + continue embedding_store = index_stores[f"{graphname}_{index_name}"] logger.info(f"Embed to {graphname}_{index_name}: {v_id}") grp.create_task( @@ -288,6 +351,8 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel): res.raise_for_status() mod = res.json()["results"][0]["mod"] logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") + if mod == 0: + break # write iter to chan for layer to be processed await stream_communities(conn, i + 1, comm_process_chan) @@ -329,13 +394,18 @@ async def stream_communities( # Wait for all communities for layer i to be processed before doing next layer # all community descriptions must be populated before the next layer can be processed if len(comms) > 0: + n_waits = 0 while not await check_vertex_has_desc(conn, i): logger.info(f"Waiting for layer{i} to finish processing") await asyncio.sleep(5) + n_waits += 1 + if n_waits > 3: + logger.info("Flushing load_q") + await load_q.flush(("FLUSH", None)) await asyncio.sleep(3) - logger.info("stream_communities done") - logger.info("closing comm_process_chan") + # logger.info("stream_communities done") + # logger.info("closing comm_process_chan") async def summarize_communities( @@ -353,7 +423,7 @@ async def summarize_communities( embed_chan.close() -async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): +async def run(graphname: str, conn: TigerGraphConnection): """ Set up GraphRAG: - Install necessary queries. 
@@ -369,9 +439,9 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): extractor, index_stores = await init(conn) init_start = time.perf_counter() - doc_process_switch = True - entity_resolution_switch =True - community_detection_switch =True + doc_process_switch = False + entity_resolution_switch = False + community_detection_switch = True if doc_process_switch: logger.info("Doc Processing Start") docs_chan = Channel(1) @@ -386,7 +456,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan) ) # upsert chunks - grp.create_task(upsert( upsert_chan)) + grp.create_task(upsert(upsert_chan)) + grp.create_task(load(conn)) # embed grp.create_task(embed(embed_chan, index_stores, graphname)) # extract entities @@ -403,6 +474,7 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): logger.info("Entity Processing Start") entities_chan = Channel(100) upsert_chan = Channel(100) + load_q.reopen() async with asyncio.TaskGroup() as grp: grp.create_task(stream_entities(conn, entities_chan, 50)) grp.create_task( @@ -414,8 +486,12 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): ) ) grp.create_task(upsert(upsert_chan)) + grp.create_task(load(conn)) entity_end = time.perf_counter() logger.info("Entity Processing End") + while not await check_all_ents_resolved(conn): + logger.info(f"Waiting for resolved entites to finish loading") + await asyncio.sleep(1) # Community Detection community_start = time.perf_counter() @@ -425,9 +501,9 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): comm_process_chan = Channel(100) upsert_chan = Channel(100) embed_chan = Channel(100) + load_q.reopen() async with asyncio.TaskGroup() as grp: # run louvain - # grp.create_task(communities(conn, communities_chan)) grp.create_task(communities(conn, comm_process_chan)) # get the communities # grp.create_task( stream_communities(conn, communities_chan, comm_process_chan)) @@ -436,6 +512,7 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100): summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan) ) grp.create_task(upsert(upsert_chan)) + grp.create_task(load(conn)) grp.create_task(embed(embed_chan, index_stores, graphname)) community_end = time.perf_counter() diff --git a/eventual-consistency-service/app/graphrag/reusable_channel.py b/eventual-consistency-service/app/graphrag/reusable_channel.py new file mode 100644 index 00000000..54ec62c9 --- /dev/null +++ b/eventual-consistency-service/app/graphrag/reusable_channel.py @@ -0,0 +1,37 @@ +from asyncio import Queue + + +class ReuseableChannel: + def __init__(self, maxsize=0) -> None: + self.maxsize = maxsize + self.q = Queue(maxsize) + self._closed = False + self._should_flush = False + + async def put(self, item: any) -> None: + await self.q.put(item) + + async def get(self) -> any: + return await self.q.get() + + def closed(self): + return self._closed + + def should_flush(self): + return self._should_flush + + async def flush(self, flush_msg=None): + self._should_flush = True + await self.put(flush_msg) + + def empty(self): + return self.q.empty() + + def close(self): + self._closed = True + + def qsize(self) -> int: + return self.q.qsize() + + def reopen(self): + self._closed = False diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py index a934f272..2c4dce98 100644 --- 
a/eventual-consistency-service/app/graphrag/util.py +++ b/eventual-consistency-service/app/graphrag/util.py @@ -5,10 +5,9 @@ import re import traceback from glob import glob -from typing import Callable import httpx -from graphrag import workers +from graphrag import reusable_channel, workers from pyTigerGraph import TigerGraphConnection from common.config import ( @@ -24,9 +23,10 @@ from common.logs.logwriter import LogWriter logger = logging.getLogger(__name__) -http_timeout = httpx.Timeout(15.0) +http_timeout = httpx.Timeout(15.0) tg_sem = asyncio.Semaphore(100) +load_q = reusable_channel.ReuseableChannel() async def install_queries( @@ -114,7 +114,7 @@ async def init( vector_field=milvus_config.get("vector_field", "document_vector"), text_field=milvus_config.get("text_field", "document_content"), vertex_field=vertex_field, - drop_old=True, + drop_old=False, ) LogWriter.info(f"Initializing {name}") @@ -200,15 +200,37 @@ async def upsert_vertex( logger.info(f"Upsert vertex: {vertex_type} {vertex_id}") vertex_id = vertex_id.replace(" ", "_") attrs = map_attrs(attributes) - data = json.dumps({"vertices": {vertex_type: {vertex_id: attrs}}}) + await load_q.put(("vertices", (vertex_type, vertex_id, attrs))) + # data = json.dumps({"vertices": {vertex_type: {vertex_id: attrs}}}) + # headers = make_headers(conn) + # async with httpx.AsyncClient(timeout=http_timeout) as client: + # async with tg_sem: + # res = await client.post( + # f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers + # ) + # + # try: + # res.raise_for_status() + # except Exception as e: + # logger.error(f"Upsert err: {vertex_type} {vertex_id}\n{e}") + + +async def upsert_batch(conn: TigerGraphConnection, batch): + # logger.info(f"Upsert vertex: {vertex_type} {vertex_id}") + # vertex_id = vertex_id.replace(" ", "_") + # attrs = map_attrs(attributes) + # await load_q.put(('vertices')) + data = json.dumps(batch) headers = make_headers(conn) async with httpx.AsyncClient(timeout=http_timeout) as client: async with tg_sem: res = await client.post( f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers ) - + # try: res.raise_for_status() + # except Exception as e: + # logger.error(f"Upsert err: {vertex_type} {vertex_id}\n{e}") async def check_vertex_exists(conn, v_id: str): @@ -220,8 +242,12 @@ async def check_vertex_exists(conn, v_id: str): headers=headers, ) - res.raise_for_status() - return res.json() + try: + res.raise_for_status() + return res.json() + except Exception as e: + logger.error(f"Check err:\n{e}\n{res.text}") + return {"error": True} async def upsert_edge( @@ -239,28 +265,45 @@ async def upsert_edge( attrs = map_attrs(attributes) src_v_id = src_v_id.replace(" ", "_") tgt_v_id = tgt_v_id.replace(" ", "_") - data = json.dumps( - { - "edges": { - src_v_type: { - src_v_id: { - edge_type: { - tgt_v_type: { - tgt_v_id: attrs, - } - } - }, - } - } - } + # data = json.dumps( + # { + # "edges": { + # src_v_type: { + # src_v_id: { + # edge_type: { + # tgt_v_type: { + # tgt_v_id: attrs, + # } + # } + # }, + # } + # } + # } + # ) + await load_q.put( + ( + "edges", + ( + src_v_type, + src_v_id, + edge_type, + tgt_v_type, + tgt_v_id, + attrs, + ), + ) ) - headers = make_headers(conn) - async with httpx.AsyncClient(timeout=http_timeout) as client: - async with tg_sem: - res = await client.post( - f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers - ) - res.raise_for_status() + + # headers = make_headers(conn) + # async with httpx.AsyncClient(timeout=http_timeout) as client: + # 
async with tg_sem: + # res = await client.post( + # f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers + # ) + # try: + # res.raise_for_status() + # except Exception as e: + # logger.error(f"Upsert Edge err:\n{e}") async def get_commuinty_children(conn, i: int, c: str): @@ -272,20 +315,47 @@ async def get_commuinty_children(conn, i: int, c: str): params={"comm": c, "iter": i}, headers=headers, ) - resp.raise_for_status() + try: + resp.raise_for_status() + except Exception as e: + logger.error(f"Get Children err:\n{e}") descrs = [] for d in resp.json()["results"][0]["children"]: desc = d["attributes"]["description"] - if i == 1 and all(len(x) == 0 for x in desc): - desc = [d["v_id"]] - elif len(desc) == 0: - desc = d["v_id"] - - descrs.append(desc) + # if it's the entity iteration + if i == 1: + # filter out empty strings + desc = list(filter(lambda x: len(x) > 0, desc)) + # if there are no descriptions, make it the v_id + if len(desc) == 0: + desc.append(d["v_id"]) + descrs.extend(desc) + else: + descrs.append(desc) + print(f"Comm: {c} --> {descrs}", flush=True) return descrs +async def check_all_ents_resolved(conn): + headers = make_headers(conn) + async with httpx.AsyncClient(timeout=None) as client: + async with tg_sem: + resp = await client.get( + f"{conn.restppUrl}/query/{conn.graphname}/entities_have_resolution", + headers=headers, + ) + try: + resp.raise_for_status() + except Exception as e: + logger.error(f"Check Vert Desc err:\n{e}") + + res = resp.json()["results"][0]["all_resolved"] + logger.info(resp.json()["results"]) + + return res + + async def check_vertex_has_desc(conn, i: int): headers = make_headers(conn) async with httpx.AsyncClient(timeout=None) as client: @@ -295,8 +365,12 @@ async def check_vertex_has_desc(conn, i: int): params={"iter": i}, headers=headers, ) - resp.raise_for_status() + try: + resp.raise_for_status() + except Exception as e: + logger.error(f"Check Vert Desc err:\n{e}") res = resp.json()["results"][0]["all_have_desc"] + logger.info(resp.json()["results"]) return res diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index 9d8df3c8..1fb5a743 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -79,7 +79,7 @@ async def chunk_doc( # send chunks to be embedded logger.info("chunk writes to embed_chan") - await embed_chan.put((v_id, chunk, "DocumentChunk")) + await embed_chan.put((chunk_id, chunk, "DocumentChunk")) # send chunks to have entities extracted logger.info("chunk writes to extract_chan") @@ -179,7 +179,7 @@ async def extract( # embed the entity # embed with the v_id if the description is blank - if len(desc[0]): + if len(desc[0]) == 0: await embed_chan.put((v_id, v_id, "Entity")) else: # (v_id, content, index_name) @@ -219,7 +219,7 @@ async def extract( for edge in doc.relationships: logger.info( - f"extract writes relates edge to upsert\n{edge.source.id} -({edge.type})-> {edge.target.id}" + f"extract writes relates edge to upsert:{edge.source.id} -({edge.type})-> {edge.target.id}" ) # upsert verts first to make sure their ID becomes an attr v_id = util.process_id(edge.source.id) # src_id @@ -359,11 +359,6 @@ async def process_community( logger.info(f"Processing Community: {comm_id}") # get the children of the community children = await util.get_commuinty_children(conn, i, comm_id) - if i == 1: - tmp = [] - for c in children: - tmp.extend(c) - children = list(filter(lambda x: len(x) > 
0, tmp)) comm_id = util.process_id(comm_id) # if the community only has one child, use its description @@ -374,6 +369,7 @@ async def process_community( summarizer = community_summarizer.CommunitySummarizer(llm) summary = await summarizer.summarize(comm_id, children) + print(f"*******>{comm_id}: {children}, {summary}", flush=True) await upsert_chan.put( ( util.upsert_vertex, # func to call From b0b833eee133b0d5e6d2e3b592fba57450dd2121 Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Tue, 20 Aug 2024 11:34:22 -0400 Subject: [PATCH 2/7] starting cleanup --- .../app/graphrag/graph_rag.py | 20 +++++++----------- .../app/graphrag/util.py | 21 +++++++++---------- .../app/graphrag/workers.py | 10 ++++++++- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py index 54e47f26..a99fc1f7 100644 --- a/eventual-consistency-service/app/graphrag/graph_rag.py +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import time import traceback @@ -129,7 +130,7 @@ async def upsert(upsert_chan: Channel): async def load(conn: TigerGraphConnection): logger.info("Reading from load_q") dd = lambda: defaultdict(dd) # infinite default dict - batch_size = 250 + batch_size = 1000 # while the load q is still open or has contents while not load_q.closed() or not load_q.empty(): if load_q.closed(): @@ -170,10 +171,11 @@ async def load(conn: TigerGraphConnection): ] = attrs n_edges += 1 + data = json.dumps(batch) logger.info( - f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges)" + f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. 
{len(data.encode())/1000:,} kb)" ) - await upsert_batch(conn, batch) + await upsert_batch(conn, data) else: await asyncio.sleep(1) @@ -194,7 +196,6 @@ async def embed( async with asyncio.TaskGroup() as grp: # consume task queue async for v_id, content, index_name in embed_chan: - continue embedding_store = index_stores[f"{graphname}_{index_name}"] logger.info(f"Embed to {graphname}_{index_name}: {v_id}") grp.create_task( @@ -350,7 +351,7 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel): ) res.raise_for_status() mod = res.json()["results"][0]["mod"] - logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") + logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") if mod == 0: break @@ -373,8 +374,6 @@ async def stream_communities( logger.info("streaming communities") headers = make_headers(conn) - # TODO: - # can only do one layer at a time to ensure that every child community has their descriptions # async for i in community_chan: # get the community from that layer @@ -404,9 +403,6 @@ async def stream_communities( await load_q.flush(("FLUSH", None)) await asyncio.sleep(3) - # logger.info("stream_communities done") - # logger.info("closing comm_process_chan") - async def summarize_communities( conn: TigerGraphConnection, @@ -439,8 +435,8 @@ async def run(graphname: str, conn: TigerGraphConnection): extractor, index_stores = await init(conn) init_start = time.perf_counter() - doc_process_switch = False - entity_resolution_switch = False + doc_process_switch = True + entity_resolution_switch = True community_detection_switch = True if doc_process_switch: logger.info("Doc Processing Start") diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py index 2c4dce98..5c73008e 100644 --- a/eventual-consistency-service/app/graphrag/util.py +++ b/eventual-consistency-service/app/graphrag/util.py @@ -10,13 +10,8 @@ from graphrag import reusable_channel, workers from pyTigerGraph import TigerGraphConnection -from common.config import ( - doc_processing_config, - embedding_service, - get_llm_service, - llm_config, - milvus_config, -) +from common.config import (doc_processing_config, embedding_service, + get_llm_service, llm_config, milvus_config) from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor from common.extractors.BaseExtractor import BaseExtractor @@ -25,7 +20,7 @@ logger = logging.getLogger(__name__) http_timeout = httpx.Timeout(15.0) -tg_sem = asyncio.Semaphore(100) +tg_sem = asyncio.Semaphore(20) load_q = reusable_channel.ReuseableChannel() @@ -215,12 +210,11 @@ async def upsert_vertex( # logger.error(f"Upsert err: {vertex_type} {vertex_id}\n{e}") -async def upsert_batch(conn: TigerGraphConnection, batch): +async def upsert_batch(conn: TigerGraphConnection, data: str): # logger.info(f"Upsert vertex: {vertex_type} {vertex_id}") # vertex_id = vertex_id.replace(" ", "_") # attrs = map_attrs(attributes) # await load_q.put(('vertices')) - data = json.dumps(batch) headers = make_headers(conn) async with httpx.AsyncClient(timeout=http_timeout) as client: async with tg_sem: @@ -237,11 +231,16 @@ async def check_vertex_exists(conn, v_id: str): headers = make_headers(conn) async with httpx.AsyncClient(timeout=http_timeout) as client: async with tg_sem: - res = await client.get( + try: + res = await client.get( f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}", 
headers=headers, ) + except Exception as e: + logger.error(f"Check err:\n{e}") + return {"error": True} + try: res.raise_for_status() return res.json() diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index 1fb5a743..9d8267b2 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -1,6 +1,7 @@ import base64 import logging import time +import traceback from urllib.parse import quote_plus import ecc_util @@ -294,7 +295,14 @@ async def resolve_entity( mark as processed """ - results = await emb_store.aget_k_closest(entity_id) + try: + results = await emb_store.aget_k_closest(entity_id) + + except Exception: + err = traceback.format_exc() + logger.error(err) + return + if len(results) == 0: logger.error( f"aget_k_closest should, minimally, return the entity itself.\n{results}" From 2b7099b060936822cc33febc8b17e4d01a919b44 Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:42:38 -0400 Subject: [PATCH 3/7] starting cleanup --- .../app/graphrag/graph_rag.py | 2 +- .../app/graphrag/util.py | 67 +++---------------- .../app/graphrag/workers.py | 1 + 3 files changed, 13 insertions(+), 57 deletions(-) diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py index a99fc1f7..b69198c4 100644 --- a/eventual-consistency-service/app/graphrag/graph_rag.py +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -352,7 +352,7 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel): res.raise_for_status() mod = res.json()["results"][0]["mod"] logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") - if mod == 0: + if mod == 0 or mod - prev_mod < -0.05: break # write iter to chan for layer to be processed diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py index 5c73008e..ccf00cb2 100644 --- a/eventual-consistency-service/app/graphrag/util.py +++ b/eventual-consistency-service/app/graphrag/util.py @@ -1,6 +1,5 @@ import asyncio import base64 -import json import logging import re import traceback @@ -10,8 +9,13 @@ from graphrag import reusable_channel, workers from pyTigerGraph import TigerGraphConnection -from common.config import (doc_processing_config, embedding_service, - get_llm_service, llm_config, milvus_config) +from common.config import ( + doc_processing_config, + embedding_service, + get_llm_service, + llm_config, + milvus_config, +) from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor from common.extractors.BaseExtractor import BaseExtractor @@ -55,15 +59,12 @@ async def init( ) -> tuple[BaseExtractor, dict[str, MilvusEmbeddingStore]]: # install requried queries requried_queries = [ - # "common/gsql/supportai/Scan_For_Updates", - # "common/gsql/supportai/Update_Vertices_Processing_Status", - # "common/gsql/supportai/ECC_Status", - # "common/gsql/supportai/Check_Nonexistent_Vertices", "common/gsql/graphRAG/StreamIds", "common/gsql/graphRAG/StreamDocContent", "common/gsql/graphRAG/SetEpochProcessing", "common/gsql/graphRAG/ResolveRelationships", "common/gsql/graphRAG/get_community_children", + "common/gsql/graphRAG/entities_have_resolution", "common/gsql/graphRAG/communities_have_desc", 
"common/gsql/graphRAG/louvain/graphrag_louvain_init", "common/gsql/graphRAG/louvain/graphrag_louvain_communities", @@ -90,7 +91,6 @@ async def init( "DocumentChunk", "Entity", "Relationship", - # "Concept", "Community", ], ) @@ -196,35 +196,16 @@ async def upsert_vertex( vertex_id = vertex_id.replace(" ", "_") attrs = map_attrs(attributes) await load_q.put(("vertices", (vertex_type, vertex_id, attrs))) - # data = json.dumps({"vertices": {vertex_type: {vertex_id: attrs}}}) - # headers = make_headers(conn) - # async with httpx.AsyncClient(timeout=http_timeout) as client: - # async with tg_sem: - # res = await client.post( - # f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers - # ) - # - # try: - # res.raise_for_status() - # except Exception as e: - # logger.error(f"Upsert err: {vertex_type} {vertex_id}\n{e}") async def upsert_batch(conn: TigerGraphConnection, data: str): - # logger.info(f"Upsert vertex: {vertex_type} {vertex_id}") - # vertex_id = vertex_id.replace(" ", "_") - # attrs = map_attrs(attributes) - # await load_q.put(('vertices')) headers = make_headers(conn) async with httpx.AsyncClient(timeout=http_timeout) as client: async with tg_sem: res = await client.post( f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers ) - # try: res.raise_for_status() - # except Exception as e: - # logger.error(f"Upsert err: {vertex_type} {vertex_id}\n{e}") async def check_vertex_exists(conn, v_id: str): @@ -233,9 +214,9 @@ async def check_vertex_exists(conn, v_id: str): async with tg_sem: try: res = await client.get( - f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}", - headers=headers, - ) + f"{conn.restppUrl}/graph/{conn.graphname}/vertices/Entity/{v_id}", + headers=headers, + ) except Exception as e: logger.error(f"Check err:\n{e}") @@ -264,21 +245,6 @@ async def upsert_edge( attrs = map_attrs(attributes) src_v_id = src_v_id.replace(" ", "_") tgt_v_id = tgt_v_id.replace(" ", "_") - # data = json.dumps( - # { - # "edges": { - # src_v_type: { - # src_v_id: { - # edge_type: { - # tgt_v_type: { - # tgt_v_id: attrs, - # } - # } - # }, - # } - # } - # } - # ) await load_q.put( ( "edges", @@ -293,17 +259,6 @@ async def upsert_edge( ) ) - # headers = make_headers(conn) - # async with httpx.AsyncClient(timeout=http_timeout) as client: - # async with tg_sem: - # res = await client.post( - # f"{conn.restppUrl}/graph/{conn.graphname}", data=data, headers=headers - # ) - # try: - # res.raise_for_status() - # except Exception as e: - # logger.error(f"Upsert Edge err:\n{e}") - async def get_commuinty_children(conn, i: int, c: str): headers = make_headers(conn) diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index 9d8267b2..c1b355f2 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -277,6 +277,7 @@ async def extract( ) # embed "Relationship", # (v_id, content, index_name) + # right now, we're not embedding relationships in graphrag async def resolve_entity( From a03188d4b7276463597c82922ae90586fb443646 Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:25:07 -0400 Subject: [PATCH 4/7] graph rag v0 --- eventual-consistency-service/app/graphrag/graph_rag.py | 5 ++--- eventual-consistency-service/app/graphrag/util.py | 7 +++---- eventual-consistency-service/app/graphrag/workers.py | 9 ++++----- 3 files changed, 9 insertions(+), 12 
deletions(-) diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py index b69198c4..4f8ccc61 100644 --- a/eventual-consistency-service/app/graphrag/graph_rag.py +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -352,7 +352,7 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel): res.raise_for_status() mod = res.json()["results"][0]["mod"] logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})") - if mod == 0 or mod - prev_mod < -0.05: + if mod == 0 or mod - prev_mod <= -0.05: break # write iter to chan for layer to be processed @@ -500,9 +500,8 @@ async def run(graphname: str, conn: TigerGraphConnection): load_q.reopen() async with asyncio.TaskGroup() as grp: # run louvain - grp.create_task(communities(conn, comm_process_chan)) # get the communities - # grp.create_task( stream_communities(conn, communities_chan, comm_process_chan)) + grp.create_task(communities(conn, comm_process_chan)) # summarize each community grp.create_task( summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan) diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py index ccf00cb2..ca99cdde 100644 --- a/eventual-consistency-service/app/graphrag/util.py +++ b/eventual-consistency-service/app/graphrag/util.py @@ -6,9 +6,6 @@ from glob import glob import httpx -from graphrag import reusable_channel, workers -from pyTigerGraph import TigerGraphConnection - from common.config import ( doc_processing_config, embedding_service, @@ -20,6 +17,8 @@ from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor from common.extractors.BaseExtractor import BaseExtractor from common.logs.logwriter import LogWriter +from graphrag import reusable_channel, workers +from pyTigerGraph import TigerGraphConnection logger = logging.getLogger(__name__) http_timeout = httpx.Timeout(15.0) @@ -182,6 +181,7 @@ def process_id(v_id: str): v_id = has_func[0] if v_id == "''" or v_id == '""': return "" + v_id = v_id.replace("(", "").replace(")", "") return v_id @@ -287,7 +287,6 @@ async def get_commuinty_children(conn, i: int, c: str): else: descrs.append(desc) - print(f"Comm: {c} --> {descrs}", flush=True) return descrs diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index c1b355f2..d696df8b 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -7,15 +7,14 @@ import ecc_util import httpx from aiochannel import Channel -from graphrag import community_summarizer, util -from langchain_community.graphs.graph_document import GraphDocument, Node -from pyTigerGraph import TigerGraphConnection - from common.config import milvus_config from common.embeddings.embedding_services import EmbeddingModel from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore from common.extractors.BaseExtractor import BaseExtractor from common.logs.logwriter import LogWriter +from graphrag import community_summarizer, util +from langchain_community.graphs.graph_document import GraphDocument, Node +from pyTigerGraph import TigerGraphConnection vertex_field = milvus_config.get("vertex_field", "vertex_id") @@ -378,7 +377,7 @@ async def process_community( summarizer = community_summarizer.CommunitySummarizer(llm) summary = await summarizer.summarize(comm_id, children) - print(f"*******>{comm_id}: {children}, 
{summary}", flush=True) + logger.debug(f"*******>{comm_id}: {children}, {summary}") await upsert_chan.put( ( util.upsert_vertex, # func to call From a3c6dfbdf8e96c88f2208aaaa861021b1aa77ac6 Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:28:22 -0400 Subject: [PATCH 5/7] rm loader --- common/gsql/graphRAG/loaders/tmp.gsql | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 common/gsql/graphRAG/loaders/tmp.gsql diff --git a/common/gsql/graphRAG/loaders/tmp.gsql b/common/gsql/graphRAG/loaders/tmp.gsql deleted file mode 100644 index e8d8d417..00000000 --- a/common/gsql/graphRAG/loaders/tmp.gsql +++ /dev/null @@ -1,26 +0,0 @@ -CREATE LOADING load_entity@uuid@ { - DEFINE FILENAME Content; - LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; -} - - -CREATE LOADING load_ResolvedEntity@uuid@ { - DEFINE FILENAME Content; - LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; -} -CREATE LOADING load_ asdfasdf @uuid@ { - DEFINE FILENAME Content; - LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; -} -CREATE LOADING load_ asdfasdf @uuid@ { - DEFINE FILENAME Content; - LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; -} -CREATE LOADING load_ asdfasdf @uuid@ { - DEFINE FILENAME Content; - LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; -} -CREATE LOADING load_ asdfasdf @uuid@ { - DEFINE FILENAME Content; - LOAD DocumentContent TO VERTEX Document VALUES() USING SEPARATOR="|", HEADER="true", EOL="\n", QUOTE="double"; -} From 9d286e089ea6d8ad3ab79a93cd31ad1119f4608e Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:31:18 -0400 Subject: [PATCH 6/7] final cleanup --- eventual-consistency-service/app/graphrag/graph_rag.py | 6 ------ eventual-consistency-service/app/graphrag/workers.py | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/eventual-consistency-service/app/graphrag/graph_rag.py b/eventual-consistency-service/app/graphrag/graph_rag.py index 4f8ccc61..1d8f6084 100644 --- a/eventual-consistency-service/app/graphrag/graph_rag.py +++ b/eventual-consistency-service/app/graphrag/graph_rag.py @@ -88,12 +88,6 @@ async def chunk_docs( doc_tasks = [] async with asyncio.TaskGroup() as grp: async for content in docs_chan: - # v_id = content["v_id"] - # txt = content["attributes"]["text"] - # send the document to be embedded - logger.info("chunk writes to extract") - # await embed_chan.put((v_id, txt, "Document")) - task = grp.create_task( workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan) ) diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index d696df8b..98b3e69c 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -377,7 +377,7 @@ async def process_community( summarizer = community_summarizer.CommunitySummarizer(llm) summary = await summarizer.summarize(comm_id, children) - logger.debug(f"*******>{comm_id}: {children}, {summary}") + logger.debug(f"Community {comm_id}: {children}, {summary}") await upsert_chan.put( ( util.upsert_vertex, # func to 
call From 7a7436522c84e2f917a195402426462c643ec14c Mon Sep 17 00:00:00 2001 From: RobRossmiller-TG <165701656+RobRossmiller-TG@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:56:41 -0400 Subject: [PATCH 7/7] reset langchain openai version --- common/requirements.txt | 2 +- copilot/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/requirements.txt b/common/requirements.txt index 9912b4a8..86bdc50c 100644 --- a/common/requirements.txt +++ b/common/requirements.txt @@ -84,7 +84,7 @@ langchain-experimental==0.0.64 langchain-groq==0.1.9 langchain-ibm==0.1.12 langchain-milvus==0.1.4 -langchain-openai==0.1.22 +langchain-openai==0.1.21 langchain-text-splitters==0.2.2 langchainhub==0.1.21 langdetect==1.0.9 diff --git a/copilot/requirements.txt b/copilot/requirements.txt index cd9bb7bc..e057eb90 100644 --- a/copilot/requirements.txt +++ b/copilot/requirements.txt @@ -84,7 +84,7 @@ langchain-experimental==0.0.64 langchain-groq==0.1.9 langchain-ibm==0.1.12 langchain-milvus==0.1.4 -langchain-openai==0.1.22 +langchain-openai==0.1.21 langchain-text-splitters==0.2.2 langchainhub==0.1.21 langdetect==1.0.9
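Taken together, these patches replace per-element RESTPP upserts with a batched loader: workers enqueue ("vertices", ...) / ("edges", ...) tuples onto a reusable channel, a single load() task drains them into POST payloads of up to batch_size elements, and a ("FLUSH", None) sentinel forces a partial batch out when a stage (e.g. a community layer waiting on descriptions) needs its writes landed now. The following is a minimal self-contained sketch of that pattern, not the codebase itself — ReusableQueue, post_batch, and clear_flush are illustrative names standing in for ReuseableChannel, upsert_batch, and the patch's direct _should_flush access.

    import asyncio
    import json
    from collections import defaultdict


    class ReusableQueue:
        # Stand-in for the patch's ReuseableChannel: an asyncio.Queue that can be
        # closed, asked to flush early via a sentinel, and reopened between stages.
        def __init__(self, maxsize: int = 0) -> None:
            self._q: asyncio.Queue = asyncio.Queue(maxsize)
            self._closed = False
            self._should_flush = False

        async def put(self, item) -> None:
            await self._q.put(item)

        async def get(self):
            return await self._q.get()

        async def flush(self, flush_msg=("FLUSH", None)) -> None:
            # Wake the consumer and make it post whatever it has buffered so far.
            self._should_flush = True
            await self._q.put(flush_msg)

        def clear_flush(self) -> None:
            self._should_flush = False

        def close(self) -> None:
            self._closed = True

        def reopen(self) -> None:
            self._closed = False

        def closed(self) -> bool:
            return self._closed

        def should_flush(self) -> bool:
            return self._should_flush

        def empty(self) -> bool:
            return self._q.empty()

        def qsize(self) -> int:
            return self._q.qsize()


    async def load(q: ReusableQueue, post_batch, batch_size: int = 1000) -> None:
        # Drain the queue into batched upsert payloads, as in the patch's load().
        dd = lambda: defaultdict(dd)  # arbitrarily nested dict for the edge payload
        while not q.closed() or not q.empty():
            if q.qsize() >= batch_size or q.closed() or q.should_flush():
                batch = {"vertices": defaultdict(dict), "edges": dd()}
                # On close or flush, drain everything queued; otherwise take one batch.
                size = q.qsize() if q.closed() or q.should_flush() else batch_size
                for _ in range(size):
                    kind, elem = await q.get()
                    if kind == "FLUSH":
                        q.clear_flush()
                        break
                    if kind == "vertices":
                        v_type, v_id, attrs = elem
                        batch[kind][v_type][v_id] = attrs
                    else:  # "edges"
                        src_t, src_id, e_type, tgt_t, tgt_id, attrs = elem
                        batch[kind][src_t][src_id][e_type][tgt_t][tgt_id] = attrs
                await post_batch(json.dumps(batch))
            else:
                await asyncio.sleep(1)

The reopen()/close() pair is what lets the pipeline reuse one queue across the document, entity-resolution, and community stages: each stage reopens the queue, runs its own load() task alongside upsert(), and closes the queue when its producers finish, so the final partial batch is flushed before the next stage starts.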
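The posting side pairs with a concurrency cap: PATCH 2/7 lowers the module-wide tg_sem from Semaphore(100) to Semaphore(20) so that batch loading cannot flood RESTPP. A sketch of that request path under stated assumptions — restpp_url, graphname, and token are placeholder parameters standing in for the TigerGraphConnection fields and make_headers() used by the real code.

    import asyncio

    import httpx

    # One module-wide semaphore caps in-flight RESTPP requests, as in util.py.
    tg_sem = asyncio.Semaphore(20)
    http_timeout = httpx.Timeout(15.0)


    async def upsert_batch(restpp_url: str, graphname: str, data: str, token: str) -> None:
        # Post one pre-serialized {"vertices": ..., "edges": ...} batch payload.
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient(timeout=http_timeout) as client:
            async with tg_sem:  # at most 20 concurrent requests across all workers
                res = await client.post(
                    f"{restpp_url}/graph/{graphname}", data=data, headers=headers
                )
                res.raise_for_status()

One large POST per batch amortizes connection and auth overhead that the old per-vertex and per-edge upserts paid on every element, which is why the loader logs payload size in kilobytes per batch rather than counting individual requests.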