Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stability improvements #269

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def get_llm_service(llm_config) -> LLM_Model:
):
doc_processing_config = {
"chunker": "semantic",
"chunker_config": {"method": "percentile", "threshold": 0.95},
"chunker_config": {"method": "percentile", "threshold": 0.90},
"extractor": "graphrag",
"extractor_config": {},
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
CREATE DISTRIBUTED QUERY GraphRAG_Community_Retriever(INT community_level=2) {
comms = {Community.*};

selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;
Expand Down
4 changes: 2 additions & 2 deletions copilot/app/supportai/retrievers/GraphRAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def __init__(
connection: TigerGraphConnectionProxy,
):
super().__init__(embedding_service, embedding_store, llm_service, connection)
self._check_query_install("GraphRAG_CommunityRetriever")
self._check_query_install("GraphRAG_Community_Retriever")

def search(self, question, community_level: int):
res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
res = self.conn.runInstalledQuery("GraphRAG_Community_Retriever", {"community_level": community_level}, usePost=True)
return res

async def _generate_candidate(self, question, context):
Expand Down
27 changes: 16 additions & 11 deletions eventual-consistency-service/app/graphrag/graph_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
http_timeout,
init,
load_q,
loading_event,
make_headers,
stream_ids,
tg_sem,
Expand Down Expand Up @@ -124,7 +125,7 @@ async def upsert(upsert_chan: Channel):
async def load(conn: TigerGraphConnection):
logger.info("Reading from load_q")
dd = lambda: defaultdict(dd) # infinite default dict
batch_size = 1000
batch_size = 500
# while the load q is still open or has contents
while not load_q.closed() or not load_q.empty():
if load_q.closed():
Expand Down Expand Up @@ -169,7 +170,11 @@ async def load(conn: TigerGraphConnection):
logger.info(
f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)"
)

loading_event.clear()
await upsert_batch(conn, data)
await asyncio.sleep(5)
loading_event.set()
else:
await asyncio.sleep(1)

Expand Down Expand Up @@ -435,12 +440,12 @@ async def run(graphname: str, conn: TigerGraphConnection):
if doc_process_switch:
logger.info("Doc Processing Start")
docs_chan = Channel(1)
embed_chan = Channel(100)
upsert_chan = Channel(100)
extract_chan = Channel(100)
embed_chan = Channel()
upsert_chan = Channel()
extract_chan = Channel()
async with asyncio.TaskGroup() as grp:
# get docs
grp.create_task(stream_docs(conn, docs_chan, 10))
grp.create_task(stream_docs(conn, docs_chan, 100))
# process docs
grp.create_task(
chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
Expand All @@ -462,8 +467,8 @@ async def run(graphname: str, conn: TigerGraphConnection):

if entity_resolution_switch:
logger.info("Entity Processing Start")
entities_chan = Channel(100)
upsert_chan = Channel(100)
entities_chan = Channel()
upsert_chan = Channel()
load_q.reopen()
async with asyncio.TaskGroup() as grp:
grp.create_task(stream_entities(conn, entities_chan, 50))
Expand All @@ -487,10 +492,10 @@ async def run(graphname: str, conn: TigerGraphConnection):
community_start = time.perf_counter()
if community_detection_switch:
logger.info("Community Processing Start")
upsert_chan = Channel(10)
comm_process_chan = Channel(100)
upsert_chan = Channel(100)
embed_chan = Channel(100)
upsert_chan = Channel()
comm_process_chan = Channel()
upsert_chan = Channel()
embed_chan = Channel()
load_q.reopen()
async with asyncio.TaskGroup() as grp:
# run louvain
Expand Down
37 changes: 25 additions & 12 deletions eventual-consistency-service/app/graphrag/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from glob import glob

import httpx
from graphrag import reusable_channel, workers
from pyTigerGraph import TigerGraphConnection

from common.config import (
doc_processing_config,
embedding_service,
Expand All @@ -17,15 +20,17 @@
from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor
from common.extractors.BaseExtractor import BaseExtractor
from common.logs.logwriter import LogWriter
from graphrag import reusable_channel, workers
from pyTigerGraph import TigerGraphConnection

logger = logging.getLogger(__name__)
http_timeout = httpx.Timeout(15.0)

tg_sem = asyncio.Semaphore(20)
tg_sem = asyncio.Semaphore(2)
load_q = reusable_channel.ReuseableChannel()

# workers pause while this event is cleared (false) and resume once it is set (true)
loading_event = asyncio.Event()
loading_event.set() # set the event to true to allow the workers to run

async def install_queries(
requried_queries: list[str],
conn: TigerGraphConnection,
Expand Down Expand Up @@ -107,7 +112,7 @@ async def init(
vector_field=milvus_config.get("vector_field", "document_vector"),
text_field=milvus_config.get("text_field", "document_content"),
vertex_field=vertex_field,
drop_old=False,
drop_old=True,
parkererickson-tg marked this conversation as resolved.
Show resolved Hide resolved
)

LogWriter.info(f"Initializing {name}")
Expand Down Expand Up @@ -207,7 +212,6 @@ async def upsert_batch(conn: TigerGraphConnection, data: str):
res.raise_for_status()



async def check_vertex_exists(conn, v_id: str):
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=http_timeout) as client:
Expand All @@ -219,7 +223,8 @@ async def check_vertex_exists(conn, v_id: str):
)

except Exception as e:
logger.error(f"Check err:\n{e}")
err = traceback.format_exc()
logger.error(f"Check err:\n{err}")
return {"error": True}

try:
Expand Down Expand Up @@ -264,17 +269,25 @@ async def get_commuinty_children(conn, i: int, c: str):
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=None) as client:
async with tg_sem:
resp = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
params={"comm": c, "iter": i},
headers=headers,
)
try:
resp = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
params={"comm": c, "iter": i},
headers=headers,
)
except:
logger.error(f"Get Children err:\n{traceback.format_exc()}")
try:
resp.raise_for_status()
except Exception as e:
logger.error(f"Get Children err:\n{e}")
descrs = []
for d in resp.json()["results"][0]["children"]:
try:
res = resp.json()["results"][0]["children"]
except Exception as e:
logger.error(f"Get Children err:\n{e}")
res = []
for d in res:
desc = d["attributes"]["description"]
# if it's the entity iteration
if i == 1:
Expand Down
Loading
Loading