Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gml 1828 load tg docs #265

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@
"MILVUS_CONFIG must be a .json file or a JSON string, failed with error: "
+ str(e)
)


if llm_config["embedding_service"]["embedding_model_service"].lower() == "openai":
embedding_service = OpenAI_Embedding(llm_config["embedding_service"])
elif llm_config["embedding_service"]["embedding_model_service"].lower() == "azure":
Expand Down Expand Up @@ -128,11 +126,9 @@ def get_llm_service(llm_config) -> LLM_Model:
else:
raise Exception("LLM Completion Service Not Supported")


LogWriter.info(
f"Milvus enabled for host {milvus_config['host']} at port {milvus_config['port']}"
)

if os.getenv("INIT_EMBED_STORE", "true")=="true":
LogWriter.info("Setting up Milvus embedding store for InquiryAI")
try:
Expand Down
8 changes: 7 additions & 1 deletion common/gsql/graphRAG/communities_have_desc.gsql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{
CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter, BOOL p=False) SYNTAX V2{
SumAccum<INT> @@descrs;
Comms = {Community.*};
Comms = SELECT c FROM Comms:c
Expand All @@ -11,4 +11,10 @@ CREATE DISTRIBUTED QUERY communities_have_desc(UINT iter) SYNTAX V2{

PRINT (@@descrs == Comms.size()) as all_have_desc;
PRINT @@descrs, Comms.size();

IF p THEN
Comms = SELECT c FROM Comms:c
WHERE c.iteration == iter and length(c.description) == 0;
PRINT Comms;
END;
}
10 changes: 10 additions & 0 deletions common/gsql/graphRAG/entities_have_resolution.gsql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CREATE DISTRIBUTED QUERY entities_have_resolution() SYNTAX V2{
SumAccum<INT> @@resolved;
Ents = {Entity.*};
Ents = SELECT s FROM Ents:s -(RESOLVES_TO>)- ResolvedEntity
ACCUM @@resolved += 1;


PRINT (@@resolved >= Ents.size()) as all_resolved;
PRINT @@resolved, Ents.size();
}
13 changes: 11 additions & 2 deletions common/gsql/graphRAG/louvain/graphrag_louvain_init.gsql
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches
MaxAccum<Move> @best_move; // best move of the node with the highest delta Q to move the isolated node into the new community
MaxAccum<DOUBLE> @@min_double; // used to reset the @best_move
SumAccum<INT> @@move_cnt;
OrAccum @to_change_community;
OrAccum @to_change_community, @has_relations;
SumAccum<INT> @batch_id;
MinAccum<INT> @vid;

Expand Down Expand Up @@ -152,6 +152,8 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches
@@community_sum_total_map.clear();
Tmp = SELECT s FROM AllNodes:s -(_:e)-> ResolvedEntity:t
ACCUM
s.@has_relations += TRUE,
t.@has_relations += TRUE,
IF s.@community_id == t.@community_id THEN
// keep track of how many edges are within the community
@@community_sum_in_map += (s.@community_id -> wt)
Expand All @@ -165,7 +167,14 @@ CREATE DISTRIBUTED QUERY graphrag_louvain_init(UINT max_hop = 10, UINT n_batches
INSERT INTO IN_COMMUNITY VALUES (s, s.@community_vid+"_1") // link entity to it's first community
;

PRINT @@source_target_k_in_map;
// Continue community hierarchy for unattached partitions
Tmp = SELECT s FROM AllNodes:s
WHERE NOT s.@has_relations
POST-ACCUM
// if s is a part of an unattached partition, add to its community hierarchy to maintain parity with rest of graph
INSERT INTO Community VALUES (s.id+"_1", 1, ""),
INSERT INTO IN_COMMUNITY VALUES (s, s.id+"_1"); // link entity to it's first community


@@community_sum_total_map.clear();
// link communities
Expand Down
3 changes: 2 additions & 1 deletion common/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiochannel==1.2.1
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
Expand Down Expand Up @@ -100,7 +101,7 @@ minio==7.2.7
multidict==6.0.5
mypy-extensions==1.0.0
nest-asyncio==1.6.0
nltk==3.8.2
nltk==3.8.1
numpy==1.26.4
openai==1.40.6
ordered-set==4.1.0
Expand Down
2 changes: 1 addition & 1 deletion copilot/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ minio==7.2.7
multidict==6.0.5
mypy-extensions==1.0.0
nest-asyncio==1.6.0
nltk==3.8.2
nltk==3.8.1
numpy==1.26.4
openai==1.40.6
ordered-set==4.1.0
Expand Down
106 changes: 86 additions & 20 deletions eventual-consistency-service/app/graphrag/graph_rag.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import asyncio
import json
import logging
import time
import traceback
from collections import defaultdict

import httpx
from aiochannel import Channel
from graphrag import workers
from graphrag.util import (
check_all_ents_resolved,
check_vertex_has_desc,
http_timeout,
init,
load_q,
make_headers,
stream_ids,
tg_sem,
upsert_batch,
)
from pyTigerGraph import TigerGraphConnection

Expand Down Expand Up @@ -83,12 +88,6 @@ async def chunk_docs(
doc_tasks = []
async with asyncio.TaskGroup() as grp:
async for content in docs_chan:
v_id = content["v_id"]
txt = content["attributes"]["text"]
# send the document to be embedded
logger.info("chunk writes to extract")
# await embed_chan.put((v_id, txt, "Document"))

task = grp.create_task(
workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan)
)
Expand Down Expand Up @@ -117,7 +116,66 @@ async def upsert(upsert_chan: Channel):
# execute the task
grp.create_task(func(*args))

logger.info(f"upsert done")
logger.info("upsert done")
logger.info("closing load_q chan")
load_q.close()


async def load(conn: TigerGraphConnection):
logger.info("Reading from load_q")
dd = lambda: defaultdict(dd) # infinite default dict
batch_size = 1000
# while the load q is still open or has contents
while not load_q.closed() or not load_q.empty():
if load_q.closed():
logger.info(
f"load queue closed. Flushing load queue (final load for this stage)"
)
# if there's `batch_size` entities in the channel, load it
# or if the channel is closed, flush it
if load_q.qsize() >= batch_size or load_q.closed() or load_q.should_flush():
batch = {
"vertices": defaultdict(dict[str, any]),
"edges": dd(),
}
n_verts = 0
n_edges = 0
size = (
load_q.qsize()
if load_q.closed() or load_q.should_flush()
else batch_size
)
for _ in range(size):
t, elem = await load_q.get()
if t == "FLUSH":
logger.debug(f"flush recieved: {t}")
load_q._should_flush = False
break
match t:
case "vertices":
vt, v_id, attr = elem
batch[t][vt][v_id] = attr
n_verts += 1
case "edges":
src_v_type, src_v_id, edge_type, tgt_v_type, tgt_v_id, attrs = (
elem
)
batch[t][src_v_type][src_v_id][edge_type][tgt_v_type][
tgt_v_id
] = attrs
n_edges += 1

data = json.dumps(batch)
logger.info(
f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)"
)
await upsert_batch(conn, data)
else:
await asyncio.sleep(1)

# TODO: flush q if it's not empty
if not load_q.empty():
raise Exception(f"load_q not empty: {load_q.qsize()}", flush=True)


async def embed(
Expand Down Expand Up @@ -287,7 +345,9 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
)
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"*** mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
if mod == 0 or mod - prev_mod <= -0.05:
break

# write iter to chan for layer to be processed
await stream_communities(conn, i + 1, comm_process_chan)
Expand All @@ -308,8 +368,6 @@ async def stream_communities(
logger.info("streaming communities")

headers = make_headers(conn)
# TODO:
# can only do one layer at a time to ensure that every child community has their descriptions

# async for i in community_chan:
# get the community from that layer
Expand All @@ -329,14 +387,16 @@ async def stream_communities(
# Wait for all communities for layer i to be processed before doing next layer
# all community descriptions must be populated before the next layer can be processed
if len(comms) > 0:
n_waits = 0
while not await check_vertex_has_desc(conn, i):
logger.info(f"Waiting for layer{i} to finish processing")
await asyncio.sleep(5)
n_waits += 1
if n_waits > 3:
logger.info("Flushing load_q")
await load_q.flush(("FLUSH", None))
await asyncio.sleep(3)

logger.info("stream_communities done")
logger.info("closing comm_process_chan")


async def summarize_communities(
conn: TigerGraphConnection,
Expand All @@ -353,7 +413,7 @@ async def summarize_communities(
embed_chan.close()


async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
async def run(graphname: str, conn: TigerGraphConnection):
"""
Set up GraphRAG:
- Install necessary queries.
Expand All @@ -370,8 +430,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
init_start = time.perf_counter()

doc_process_switch = True
entity_resolution_switch =True
community_detection_switch =True
entity_resolution_switch = True
community_detection_switch = True
if doc_process_switch:
logger.info("Doc Processing Start")
docs_chan = Channel(1)
Expand All @@ -386,7 +446,8 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
)
# upsert chunks
grp.create_task(upsert( upsert_chan))
grp.create_task(upsert(upsert_chan))
grp.create_task(load(conn))
# embed
grp.create_task(embed(embed_chan, index_stores, graphname))
# extract entities
Expand All @@ -403,6 +464,7 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
logger.info("Entity Processing Start")
entities_chan = Channel(100)
upsert_chan = Channel(100)
load_q.reopen()
async with asyncio.TaskGroup() as grp:
grp.create_task(stream_entities(conn, entities_chan, 50))
grp.create_task(
Expand All @@ -414,8 +476,12 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
)
)
grp.create_task(upsert(upsert_chan))
grp.create_task(load(conn))
entity_end = time.perf_counter()
logger.info("Entity Processing End")
while not await check_all_ents_resolved(conn):
logger.info(f"Waiting for resolved entites to finish loading")
await asyncio.sleep(1)

# Community Detection
community_start = time.perf_counter()
Expand All @@ -425,17 +491,17 @@ async def run(graphname: str, conn: TigerGraphConnection, upsert_limit=100):
comm_process_chan = Channel(100)
upsert_chan = Channel(100)
embed_chan = Channel(100)
load_q.reopen()
async with asyncio.TaskGroup() as grp:
# run louvain
# grp.create_task(communities(conn, communities_chan))
grp.create_task(communities(conn, comm_process_chan))
# get the communities
# grp.create_task( stream_communities(conn, communities_chan, comm_process_chan))
grp.create_task(communities(conn, comm_process_chan))
# summarize each community
grp.create_task(
summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan)
)
grp.create_task(upsert(upsert_chan))
grp.create_task(load(conn))
grp.create_task(embed(embed_chan, index_stores, graphname))

community_end = time.perf_counter()
Expand Down
37 changes: 37 additions & 0 deletions eventual-consistency-service/app/graphrag/reusable_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from asyncio import Queue


class ReuseableChannel:
def __init__(self, maxsize=0) -> None:
self.maxsize = maxsize
self.q = Queue(maxsize)
self._closed = False
self._should_flush = False

async def put(self, item: any) -> None:
await self.q.put(item)

async def get(self) -> any:
return await self.q.get()

def closed(self):
return self._closed

def should_flush(self):
return self._should_flush

async def flush(self, flush_msg=None):
self._should_flush = True
await self.put(flush_msg)

def empty(self):
return self.q.empty()

def close(self):
self._closed = True

def qsize(self) -> int:
return self.q.qsize()

def reopen(self):
self._closed = False
Loading
Loading