Skip to content

Commit

Permalink
graph rag v0
Browse files Browse the repository at this point in the history
  • Loading branch information
RobRossmiller-TG committed Aug 20, 2024
1 parent 2b7099b commit a03188d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
5 changes: 2 additions & 3 deletions eventual-consistency-service/app/graphrag/graph_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def communities(conn: TigerGraphConnection, comm_process_chan: Channel):
res.raise_for_status()
mod = res.json()["results"][0]["mod"]
logger.info(f"mod pass {i+1}: {mod} (diff= {abs(prev_mod - mod)})")
if mod == 0 or mod - prev_mod < -0.05:
if mod == 0 or mod - prev_mod <= -0.05:
break

# write iter to chan for layer to be processed
Expand Down Expand Up @@ -500,9 +500,8 @@ async def run(graphname: str, conn: TigerGraphConnection):
load_q.reopen()
async with asyncio.TaskGroup() as grp:
# run louvain
grp.create_task(communities(conn, comm_process_chan))
# get the communities
# grp.create_task( stream_communities(conn, communities_chan, comm_process_chan))
grp.create_task(communities(conn, comm_process_chan))
# summarize each community
grp.create_task(
summarize_communities(conn, comm_process_chan, upsert_chan, embed_chan)
Expand Down
7 changes: 3 additions & 4 deletions eventual-consistency-service/app/graphrag/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from glob import glob

import httpx
from graphrag import reusable_channel, workers
from pyTigerGraph import TigerGraphConnection

from common.config import (
doc_processing_config,
embedding_service,
Expand All @@ -20,6 +17,8 @@
from common.extractors import GraphExtractor, LLMEntityRelationshipExtractor
from common.extractors.BaseExtractor import BaseExtractor
from common.logs.logwriter import LogWriter
from graphrag import reusable_channel, workers
from pyTigerGraph import TigerGraphConnection

logger = logging.getLogger(__name__)
http_timeout = httpx.Timeout(15.0)
Expand Down Expand Up @@ -182,6 +181,7 @@ def process_id(v_id: str):
v_id = has_func[0]
if v_id == "''" or v_id == '""':
return ""
v_id = v_id.replace("(", "").replace(")", "")

return v_id

Expand Down Expand Up @@ -287,7 +287,6 @@ async def get_commuinty_children(conn, i: int, c: str):
else:
descrs.append(desc)

print(f"Comm: {c} --> {descrs}", flush=True)
return descrs


Expand Down
9 changes: 4 additions & 5 deletions eventual-consistency-service/app/graphrag/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
import ecc_util
import httpx
from aiochannel import Channel
from graphrag import community_summarizer, util
from langchain_community.graphs.graph_document import GraphDocument, Node
from pyTigerGraph import TigerGraphConnection

from common.config import milvus_config
from common.embeddings.embedding_services import EmbeddingModel
from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore
from common.extractors.BaseExtractor import BaseExtractor
from common.logs.logwriter import LogWriter
from graphrag import community_summarizer, util
from langchain_community.graphs.graph_document import GraphDocument, Node
from pyTigerGraph import TigerGraphConnection

vertex_field = milvus_config.get("vertex_field", "vertex_id")

Expand Down Expand Up @@ -378,7 +377,7 @@ async def process_community(
summarizer = community_summarizer.CommunitySummarizer(llm)
summary = await summarizer.summarize(comm_id, children)

print(f"*******>{comm_id}: {children}, {summary}", flush=True)
logger.debug(f"*******>{comm_id}: {children}, {summary}")
await upsert_chan.put(
(
util.upsert_vertex, # func to call
Expand Down

0 comments on commit a03188d

Please sign in to comment.