Skip to content

Commit

Permalink
Merge pull request #265 from tigergraph/GML-1828-load_tg_docs
Browse files Browse the repository at this point in the history
Gml 1828 load tg docs
  • Loading branch information
parkererickson-tg authored Aug 20, 2024
2 parents da05a4a + 7a74365 commit ef2ac8e
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 88 deletions.
4 changes: 0 additions & 4 deletions common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@
"MILVUS_CONFIG must be a .json file or a JSON string, failed with error: "
+ str(e)
)


if llm_config["embedding_service"]["embedding_model_service"].lower() == "openai":
embedding_service = OpenAI_Embedding(llm_config["embedding_service"])
elif llm_config["embedding_service"]["embedding_model_service"].lower() == "azure":
Expand Down Expand Up @@ -128,11 +126,9 @@ def get_llm_service(llm_config) -> LLM_Model:
else:
raise Exception("LLM Completion Service Not Supported")


LogWriter.info(
f"Milvus enabled for host {milvus_config['host']} at port {milvus_config['port']}"
)

if os.getenv("INIT_EMBED_STORE", "true")=="true":
LogWriter.info("Setting up Milvus embedding store for InquiryAI")
try:
Expand Down
8 changes: 7 additions & 1 deletion common/gsql/graphRAG/communities_have_desc.gsql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{
CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter, BOOL p=False) SYNTAX V2{
SumAccum<INT> @@descrs;
Comms = {Community.*};
Comms = SELECT c FROM Comms:c
Expand All @@ -11,4 +11,10 @@ CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{

PRINT (@@descrs == Comms.size()) as all_have_desc;
PRINT @@descrs, Comms.size();

IF p THEN
Comms = SELECT c FROM Comms:c
WHERE c.iteration == iter and length(c.description) == 0;
PRINT Comms;
END;
}
10 changes: 10 additions & 0 deletions common/gsql/graphRAG/entities_have_resolution.gsql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CREATE DISTRIBUTED QUERY entities_have_resolution() SYNTAX V2{
SumAccum<INT> @@resolved;
Ents = {Entity.*};
Ents = SELECT s FROM Ents:s -(RESOLVES_TO>)- ResolvedEntity
ACCUM @@resolved += 1;


PRINT (@@resolved >= Ents.size()) as all_resolved;
PRINT @@resolved, Ents.size();
}
13 changes: 11 additions & 2 deletions common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches
MaxAccum<Move> @best_move; // best move of the node with the highest delta Q to move the isolated node into the new community
MaxAccum<DOUBLE> @@min_double; // used to reset the @best_move
SumAccum<INT> @@move_cnt;
OrAccum @to_change_community;
OrAccum @to_change_community, @has_relations;
SumAccum<INT> @batch_id;
MinAccum<INT> @vid;

Expand Down Expand Up @@ -152,6 +152,8 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches
@@community_sum_total_map.clear();
Tmp = SELECT s FROM AllNodes:s -(_:e)-> ResolvedEntity:t
ACCUM
s.@has_relations += TRUE,
t.@has_relations += TRUE,
IF s.@community_id == t.@community_id THEN
// keep track of how many edges are within the community
@@community_sum_in_map += (s.@community_id -> wt)
Expand All @@ -165,7 +167,14 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches
INSERT INTO IN_COMMUNITY VALUES (s, s.@community_vid+"_1") // link entity to it's first community
;

PRINT @@source_target_k_in_map;
// Continue community hierarchy for unattached partitions
Tmp = SELECT s FROM AllNodes:s
WHERE NOT s.@has_relations
POST-ACCUM
// if s is a part of an unattached partition, add to its community hierarchy to maintain parity with rest of graph
INSERT INTO Community VALUES (s.id+"_1", 1, ""),
INSERT INTO IN_COMMUNITY VALUES (s, s.id+"_1"); // link entity to it's first community


@@community_sum_total_map.clear();
// link communities
Expand Down
106 changes: 86 additions & 20 deletions eventual-consistency-service/app/graphrag/graph_rag.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import asyncio
import json
import logging
import time
import traceback
from collections import defaultdict

import httpx
from aiochannel import Channel
from graphrag import workers
from graphrag.util import (
check_all_ents_resolved,
check_vertex_has_desc,
http_timeout,
init,
load_q,
make_headers,
stream_ids,
tg_sem,
upsert_batch,
)
from pyTigerGraph import TigerGraphConnection

Expand Down Expand Up @@ -83,12 +88,6 @@ async def chunk_docs(
doc_tasks = []
async with asyncio.TaskGroup() as grp:
async for content in docs_chan:
v_id = content["v_id"]
txt = content["attributes"]["text"]
# send the document to be embedded
logger.info("chunk writes to extract")
# await embed_chan.put((v_id, txt, "Document"))

task = grp.create_task(
workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan)
)
Expand Down Expand Up @@ -117,7 +116,66 @@ async def upsert(upsert_chan: Channel):
# execute the task
grp.create_task(func(*args))

logger.info(f"upsert done")
logger.info("upsert done")
logger.info("closing load_q chan")
load_q.close()


async def load(conn: TigerGraphConnection):
logger.info("Reading from load_q")
dd = lambda: defaultdict(dd) # infinite default dict
batch_size = 1000
# while the load q is still open or has contents
while not load_q.closed() or not load_q.empty():
if load_q.closed():
logger.info(
f"load queue closed. Flushing load queue (final load for this stage)"
)
# if there's `batch_size` entities in the channel, load it
# or if the channel is closed, flush it
if load_q.qsize() >= batch_size or load_q.closed() or load_q.should_flush():
batch = {
"vertices": defaultdict(dict[str, any]),
"edges": dd(),
}
n_verts = 0
n_edges = 0
size = (
load_q.qsize()
if load_q.closed() or load_q.should_flush()
else batch_size
)
for _ in range(size):
t, elem = await load_q.get()
if t == "FLUSH":
logger.debug(f"flush recieved: {t}")
load_q._should_flush = False
break
match t:
case "vertices":
vt, v_id, attr = elem
batch[t][vt][v_id] = attr
n_verts += 1
case "edges":
src_v_type, src_v_id, edge_type, tgt_v_type, tgt_v_id, attrs = (
elem
)
batch[t][src_v_type][src_v_id][edge_type][tgt_v_type][
tgt_v_id
] = attrs
n_edges += 1

data = json.dumps(batch)
logger.info(
f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)"
)
await upsert_batch(conn, data)
else:
await asyncio.sleep(1)

# TODO: flush q if it's not empty
if not load_q.empty():
raise Exception(f"load_q not empty: {load_q.qsize()}", flush=True)


async def embed(
Expand Down Expand Up @@ -287,7 +345,9 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
)
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
if mod == 0 or mod - prev_mod <= -0.05:
break

# write iter to chan for layer to be processed
await stream_communities(conn, i + 1, comm_process_chan)
Expand All @@ -308,8 +368,6 @@ async def stream_communities(
logger.info("streaming communities")

headers = make_headers(conn)
# TODO:
# can only do one layer at a time to ensure that every child community has their descriptions

# async for i in community_chan:
# get the community from that layer
Expand All @@ -329,14 +387,16 @@ async def stream_communities(
# Wait for all communities for layer i to be processed before doing next layer
# all community descriptions must be populated before the next layer can be processed
if len(comms) > 0:
n_waits = 0
while not await check_vertex_has_desc(conn, i):
logger.info(f"Waiting for layer{i} to finish processing")
await asyncio.sleep(5)
n_waits += 1
if n_waits > 3:
logger.info("Flushing load_q")
await load_q.flush(("FLUSH", None))
await asyncio.sleep(3)

logger.info("stream_communities done")
logger.info("closing comm_process_chan")


async def summarize_communities(
conn: TigerGraphConnection,
Expand All @@ -353,7 +413,7 @@ async def summarize_communities(
embed_chan.close()


async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
async def run(graphname: str, conn: TigerGraphConnection):
"""
Set up GraphRAG:
- Install necessary queries.
Expand All @@ -370,8 +430,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
init_start = time.perf_counter()

doc_process_switch = True
entity_resolution_switch =True
community_detection_switch =True
entity_resolution_switch = True
community_detection_switch = True
if doc_process_switch:
logger.info("Doc Processing Start")
docs_chan = Channel(1)
Expand All @@ -386,7 +446,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
)
# upsert chunks
grp.create_task(upsert( upsert_chan))
grp.create_task(upsert(upsert_chan))
grp.create_task(load(conn))
# embed
grp.create_task(embed(embed_chan, index_stores, graphname))
# extract entities
Expand All @@ -403,6 +464,7 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
logger.info("Entity Processing Start")
entities_chan = Channel(100)
upsert_chan = Channel(100)
load_q.reopen()
async with asyncio.TaskGroup() as grp:
grp.create_task(stream_entities(conn, entities_chan, 50))
grp.create_task(
Expand All @@ -414,8 +476,12 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
)
)
grp.create_task(upsert(upsert_chan))
grp.create_task(load(conn))
entity_end = time.perf_counter()
logger.info("Entity Processing End")
while not await check_all_ents_resolved(conn):
logger.info(f"Waiting for resolved entites to finish loading")
await asyncio.sleep(1)

# Community Detection
community_start = time.perf_counter()
Expand All @@ -425,17 +491,17 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
comm_process_chan = Channel(100)
upsert_chan = Channel(100)
embed_chan = Channel(100)
load_q.reopen()
async with asyncio.TaskGroup() as grp:
# run louvain
# grp.create_task(communities(conn, communities_chan))
grp.create_task(communities(conn, comm_process_chan))
# get the communities
# grp.create_task( stream_communities(conn, communities_chan, comm_process_chan))
grp.create_task(communities(conn, comm_process_chan))
# summarize each community
grp.create_task(
summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan)
)
grp.create_task(upsert(upsert_chan))
grp.create_task(load(conn))
grp.create_task(embed(embed_chan, index_stores, graphname))

community_end = time.perf_counter()
Expand Down
37 changes: 37 additions & 0 deletions eventual-consistency-service/app/graphrag/reusable_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from asyncio import Queue


class ReuseableChannel:
def __init__(self, maxsize=0) -> None:
self.maxsize = maxsize
self.q = Queue(maxsize)
self._closed = False
self._should_flush = False

async def put(self, item: any) -> None:
await self.q.put(item)

async def get(self) -> any:
return await self.q.get()

def closed(self):
return self._closed

def should_flush(self):
return self._should_flush

async def flush(self, flush_msg=None):
self._should_flush = True
await self.put(flush_msg)

def empty(self):
return self.q.empty()

def close(self):
self._closed = True

def qsize(self) -> int:
return self.q.qsize()

def reopen(self):
self._closed = False
Loading

0 comments on commit ef2ac8e

Please sign in to comment.