Stability improvements
parkererickson-tg committed Aug 29, 2024
1 parent 8c96f26 commit 0604ecf
Showing 5 changed files with 36 additions and 18 deletions.
2 changes: 1 addition & 1 deletion common/config.py
@@ -187,7 +187,7 @@ def get_llm_service(llm_config) -> LLM_Model:
     ):
         doc_processing_config = {
             "chunker": "semantic",
-            "chunker_config": {"method": "percentile", "threshold": 0.95},
+            "chunker_config": {"method": "percentile", "threshold": 0.90},
             "extractor": "graphrag",
             "extractor_config": {},
         }
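For context: with percentile-based semantic chunking, a lower threshold (0.90 vs. 0.95) places breakpoints at a lower distance percentile, so documents split into smaller, more numerous chunks. A minimal sketch of the technique, assuming the threshold maps to a percentile over consecutive-sentence embedding distances (the helper below is illustrative, not the project's actual chunker):

```python
import numpy as np

def percentile_chunk(sentences: list[str], embeddings: list[np.ndarray],
                     threshold: float = 0.90) -> list[str]:
    """Split where the cosine distance between consecutive sentence
    embeddings exceeds the given percentile of all such distances."""
    if len(sentences) < 2:
        return [" ".join(sentences)]
    # cosine distance between each consecutive pair of sentence embeddings
    dists = [
        1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        for a, b in zip(embeddings, embeddings[1:])
    ]
    # 0.90 -> 90th percentile; a lower cutoff yields more breakpoints
    cutoff = np.percentile(dists, threshold * 100)
    chunks, current = [], [sentences[0]]
    for sent, d in zip(sentences[1:], dists):
        if d > cutoff:
            chunks.append(" ".join(current))
            current = []
        current.append(sent)
    chunks.append(" ".join(current))
    return chunks
```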
2 changes: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-CREATE DISTRIBUTED QUERY GraphRAG_CommunityRetriever(INT community_level=2) FOR GRAPH pyTigerGraphRAG {
+CREATE DISTRIBUTED QUERY GraphRAG_Community_Retriever(INT community_level=2) {
     comms = {Community.*};
 
     selected_comms = SELECT c FROM comms:c WHERE c.iteration == community_level;
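The rename to snake_case keeps the installed query consistent with the Python callers below, and dropping the hard-coded FOR GRAPH pyTigerGraphRAG clause lets the query install against whichever graph it is created on. A hedged sketch of installing and calling it via pyTigerGraph (host, graph name, and credentials are placeholders):

```python
from pyTigerGraph import TigerGraphConnection

# placeholder connection details -- substitute your own deployment
conn = TigerGraphConnection(
    host="https://your-tigergraph-host",
    graphname="YourGraph",
    username="tigergraph",
    password="tigergraph",
)

# install the renamed query, then call it at community level 2
conn.gsql("USE GRAPH YourGraph\nINSTALL QUERY GraphRAG_Community_Retriever")
res = conn.runInstalledQuery(
    "GraphRAG_Community_Retriever",
    {"community_level": 2},
    usePost=True,
)
print(res)
```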
4 changes: 2 additions & 2 deletions copilot/app/supportai/retrievers/GraphRAG.py
@@ -40,10 +40,10 @@ def __init__(
         connection: TigerGraphConnectionProxy,
     ):
         super().__init__(embedding_service, embedding_store, llm_service, connection)
-        self._check_query_install("GraphRAG_CommunityRetriever")
+        self._check_query_install("GraphRAG_Community_Retriever")
 
     def search(self, question, community_level: int):
-        res = self.conn.runInstalledQuery("GraphRAG_CommunityRetriever", {"community_level": community_level}, usePost=True)
+        res = self.conn.runInstalledQuery("GraphRAG_Community_Retriever", {"community_level": community_level}, usePost=True)
         return res
 
     async def _generate_candidate(self, question, context):
22 changes: 15 additions & 7 deletions eventual-consistency-service/app/graphrag/util.py
@@ -24,7 +24,7 @@
 logger = logging.getLogger(__name__)
 http_timeout = httpx.Timeout(15.0)
 
-tg_sem = asyncio.Semaphore(20)
+tg_sem = asyncio.Semaphore(2)
 load_q = reusable_channel.ReuseableChannel()
 
 # will pause workers until the event is false
@@ -270,17 +270,25 @@ async def get_commuinty_children(conn, i: int, c: str):
     headers = make_headers(conn)
     async with httpx.AsyncClient(timeout=None) as client:
         async with tg_sem:
-            resp = await client.get(
-                f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
-                params={"comm": c, "iter": i},
-                headers=headers,
-            )
+            try:
+                resp = await client.get(
+                    f"{conn.restppUrl}/query/{conn.graphname}/get_community_children",
+                    params={"comm": c, "iter": i},
+                    headers=headers,
+                )
+            except:
+                logger.error(f"Get Children err:\n{traceback.format_exc()}")
     try:
         resp.raise_for_status()
     except Exception as e:
         logger.error(f"Get Children err:\n{e}")
     descrs = []
-    for d in resp.json()["results"][0]["children"]:
+    try:
+        res = resp.json()["results"][0]["children"]
+    except Exception as e:
+        logger.error(f"Get Children err:\n{e}")
+        res = []
+    for d in res:
         desc = d["attributes"]["description"]
         # if it's the entity iteration
         if i == 1:
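Dropping tg_sem from 20 to 2 throttles concurrent REST++ calls, trading throughput for stability when the database is under load. A self-contained sketch of the semaphore pattern (the URL is a placeholder):

```python
import asyncio
import httpx

tg_sem = asyncio.Semaphore(2)  # at most 2 requests in flight, as in util.py

async def fetch(client: httpx.AsyncClient, url: str) -> int:
    async with tg_sem:  # tasks queue here until a slot frees up
        resp = await client.get(url)
        return resp.status_code

async def main() -> None:
    async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client:
        urls = ["https://example.com/"] * 20  # 20 tasks, but only 2 run at once
        print(await asyncio.gather(*(fetch(client, u) for u in urls)))

asyncio.run(main())
```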
24 changes: 17 additions & 7 deletions eventual-consistency-service/app/graphrag/workers.py
@@ -163,9 +163,15 @@ async def embed(
     if not util.loading_event.is_set():
         logger.info("Embed worker waiting for loading event to finish")
         await util.loading_event.wait()
-
-    vec = await embed_svc.aembed_query(content)
-    await embed_store.aadd_embeddings([(content, vec)], [{vertex_field: v_id}])
+    try:
+        vec = await embed_svc.aembed_query(content)
+    except Exception as e:
+        logger.error(f"Failed to embed {v_id}: {e}")
+        return
+    try:
+        await embed_store.aadd_embeddings([(content, vec)], [{vertex_field: v_id}])
+    except Exception as e:
+        logger.error(f"Failed to add embeddings for {v_id}: {e}")
 
 
 async def get_vert_desc(conn, v_id, node: Node):
@@ -196,10 +202,14 @@ async def extract(
         await util.loading_event.wait()
 
     async with extract_sem:
-        extracted: list[GraphDocument] = await extractor.aextract(chunk)
-        logger.info(
-            f"Extracting chunk: {chunk_id} ({len(extracted)} graph docs extracted)"
-        )
+        try:
+            extracted: list[GraphDocument] = await extractor.aextract(chunk)
+            logger.info(
+                f"Extracting chunk: {chunk_id} ({len(extracted)} graph docs extracted)"
+            )
+        except Exception as e:
+            logger.error(f"Failed to extract chunk {chunk_id}: {e}")
+            extracted = []
 
     # upsert nodes and edges to the graph
     for doc in extracted:
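Both worker changes apply the same hardening pattern: catch per-item failures, log them, and fall back to an empty result (or return early) so one bad chunk cannot take down the whole pipeline. A generic, runnable sketch of that pattern (process_item and the sample inputs are illustrative, not the project's names):

```python
import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def process_item(item: str) -> str:
    if item == "bad":
        raise ValueError("simulated extraction failure")
    return item.upper()

async def worker(item: str) -> str | None:
    try:
        return await process_item(item)
    except Exception as e:
        # log and substitute a safe fallback, mirroring `extracted = []`
        logger.error(f"Failed to process {item!r}: {e}")
        return None

async def main() -> None:
    results = await asyncio.gather(*(worker(i) for i in ["a", "bad", "c"]))
    print([r for r in results if r is not None])  # -> ['A', 'C']

asyncio.run(main())
```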
