Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(graphrag): add descriptions to all upserts, cooccurrence edges #270

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def get_llm_service(llm_config) -> LLM_Model:
doc_processing_config = {
"chunker": "semantic",
"chunker_config": {"method": "percentile", "threshold": 0.90},
"extractor": "graphrag",
"extractor": "llm",
"extractor_config": {},
}
elif DOC_PROCESSING_CONFIG.endswith(".json"):
Expand Down
198 changes: 165 additions & 33 deletions common/extractors/LLMEntityRelationshipExtractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from common.extractors.BaseExtractor import BaseExtractor
from common.llm_services import LLM_Model
from common.py_schemas import KnowledgeGraph

from langchain_community.graphs.graph_document import Node, Relationship, GraphDocument
from langchain_core.documents import Document

class LLMEntityRelationshipExtractor(BaseExtractor):
def __init__(
Expand All @@ -19,42 +20,116 @@ def __init__(
self.allowed_edge_types = allowed_relationship_types
self.strict_mode = strict_mode

def _extract_kg_from_doc(self, doc, chain, parser):
"""
returns:
{
"nodes": [
{
"id": "str",
"type": "string",
"definition": "string"
}
],
"rels": [
{
"source":{
"id": "str",
"type": "string",
"definition": "string"
}
"target":{
"id": "str",
"type": "string",
"definition": "string"
async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
    """Asynchronously run the LLM chain over one document chunk and parse the
    output into LangChain ``GraphDocument`` objects.

    Args:
        doc: Raw text of the document chunk to extract from.
        chain: A runnable (prompt | model) exposing ``ainvoke``.
        parser: Output parser providing ``get_format_instructions()``.

    Returns:
        A single-element list with one ``GraphDocument``. On any LLM or
        parsing failure the GraphDocument is empty (no nodes/relationships)
        but still carries the source text, so callers never see an exception.
    """

    def _empty() -> list[GraphDocument]:
        # Uniform best-effort fallback: never raise out of extraction.
        return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]

    def _endpoint_id(endpoint):
        # The LLM may emit an endpoint as a bare id string or as a
        # {"id": ..., ...} object; normalize both to the id string.
        if isinstance(endpoint, str):
            return endpoint
        if isinstance(endpoint, dict):
            return endpoint["id"]
        raise ValueError("Relationship parsing error")

    try:
        out = await chain.ainvoke(
            {"input": doc, "format_instructions": parser.get_format_instructions()}
        )
    except Exception:
        return _empty()

    try:
        # The model sometimes wraps JSON in a ```json fenced block; unwrap it.
        if "```json" not in out.content:
            json_out = json.loads(out.content.strip("content="))
        else:
            json_out = json.loads(
                out.content.split("```")[1].strip("```").strip("json").strip()
            )

        formatted_rels = [
            {
                "source": _endpoint_id(rel["source"]),
                "target": _endpoint_id(rel["target"]),
                "type": rel["relation_type"].replace(" ", "_").upper(),
                "definition": rel["definition"],
            }
            for rel in json_out["rels"]
        ]
        formatted_nodes = [
            {
                "id": node["id"],
                "type": node["node_type"].replace(" ", "_").capitalize(),
                "definition": node["definition"],
            }
            for node in json_out["nodes"]
        ]

        # Filter relationships and nodes based on allowed types.
        if self.strict_mode:
            if self.allowed_vertex_types:
                formatted_nodes = [
                    node
                    for node in formatted_nodes
                    if node["type"] in self.allowed_vertex_types
                ]
            if self.allowed_edge_types:
                formatted_rels = [
                    rel
                    for rel in formatted_rels
                    if rel["type"] in self.allowed_edge_types
                ]

        nodes = [
            Node(
                id=node["id"],
                type=node["type"],
                properties={"description": node["definition"]},
            )
            for node in formatted_nodes
        ]
        # NOTE(review): endpoint Nodes are built with type=rel["source"]/rel["target"]
        # (the id, not a vertex type) — preserved from the original; confirm intended.
        relationships = [
            Relationship(
                source=Node(
                    id=rel["source"],
                    type=rel["source"],
                    properties={"description": rel["definition"]},
                ),
                target=Node(
                    id=rel["target"],
                    type=rel["target"],
                    properties={"description": rel["definition"]},
                ),
                type=rel["type"],
            )
            for rel in formatted_rels
        ]

        return [
            GraphDocument(
                nodes=nodes,
                relationships=relationships,
                source=Document(page_content=doc),
            )
        ]
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate.
        return _empty()

def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
try:
out = chain.invoke(
{"input": doc, "format_instructions": parser.get_format_instructions()}
)
except Exception as e:
print("Error: ", e)
return {"nodes": [], "rels": []}
return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]
try:
if "```json" not in out.content:
json_out = json.loads(out.content.strip("content="))
Expand Down Expand Up @@ -133,15 +208,67 @@ def _extract_kg_from_doc(self, doc, chain, parser):
for rel in formatted_rels
if rel["type"] in self.allowed_edge_types
]
return {"nodes": formatted_nodes, "rels": formatted_rels}

nodes = []
for node in formatted_nodes:
nodes.append(Node(id=node["id"],
type=node["type"],
properties={"description": node["definition"]}))
relationships = []
for rel in formatted_rels:
relationships.append(Relationship(source=Node(id=rel["source"], type=rel["source"],
properties={"description": rel["definition"]}),
target=Node(id=rel["target"], type=rel["target"],
properties={"description": rel["definition"]}), type=rel["type"]))

return [GraphDocument(nodes=nodes, relationships=relationships, source=Document(page_content=doc))]

except:
print("Error Processing: ", out)
return {"nodes": [], "rels": []}
return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]

async def adocument_er_extraction(self, document):
    """Asynchronously extract entities/relationships from one document.

    Builds the chat prompt (optionally constrained to the allowed vertex
    and edge types), pipes it into the LLM, and delegates parsing to
    ``_aextract_kg_from_doc``. Returns that method's result.
    """
    from langchain.output_parsers import PydanticOutputParser
    from langchain.prompts import ChatPromptTemplate

    parser = PydanticOutputParser(pydantic_object=KnowledgeGraph)

    messages = [
        ("system", self.llm_service.entity_relationship_extraction_prompt),
        (
            "human",
            "Tip: Make sure to answer in the correct format and do "
            "not include any explanations. "
            "Use the given format to extract information from the "
            "following input: {input}",
        ),
        (
            "human",
            "Mandatory: Make sure to answer in the correct format, specified here: {format_instructions}",
        ),
    ]

    # Only mention type constraints when the caller configured any.
    if self.allowed_vertex_types or self.allowed_edge_types:
        messages.append(
            (
                "human",
                "Tip: Make sure to use the following types if they are applicable. "
                "If the input does not contain any of the types, you may create your own.",
            )
        )
    if self.allowed_vertex_types:
        messages.append(("human", f"Allowed Node Types: {self.allowed_vertex_types}"))
    if self.allowed_edge_types:
        messages.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}"))

    chain = ChatPromptTemplate.from_messages(messages) | self.llm_service.model  # | parser
    return await self._aextract_kg_from_doc(document, chain, parser)


def document_er_extraction(self, document):
from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser


parser = PydanticOutputParser(pydantic_object=KnowledgeGraph)
prompt = [
("system", self.llm_service.entity_relationship_extraction_prompt),
Expand Down Expand Up @@ -176,3 +303,8 @@ def document_er_extraction(self, document):

def extract(self, text):
    """Synchronously extract a knowledge graph from *text*.

    Thin wrapper that delegates to ``document_er_extraction``.
    """
    return self.document_er_extraction(text)

async def aextract(self, text) -> list[GraphDocument]:
    """Asynchronously extract a knowledge graph from *text*.

    Thin wrapper that awaits ``adocument_er_extraction`` and returns its
    list of ``GraphDocument`` objects.
    """
    return await self.adocument_er_extraction(text)


2 changes: 1 addition & 1 deletion eventual-consistency-service/app/graphrag/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def map_attrs(attributes: dict):


def process_id(v_id: str):
v_id = v_id.replace(" ", "_").replace("/", "")
v_id = v_id.replace(" ", "_").replace("/", "").replace("%", "percent")

has_func = re.compile(r"(.*)\(").findall(v_id)
if len(has_func) > 0:
Expand Down
18 changes: 17 additions & 1 deletion eventual-consistency-service/app/graphrag/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def extract(

# upsert nodes and edges to the graph
for doc in extracted:
for node in doc.nodes:
for i, node in enumerate(doc.nodes):
logger.info(f"extract writes entity vert to upsert\nNode: {node.id}")
v_id = util.process_id(str(node.id))
if len(v_id) == 0:
Expand Down Expand Up @@ -259,6 +259,22 @@ async def extract(
),
)
)
for node2 in doc.nodes[i + 1:]:
v_id2 = util.process_id(str(node2.id))
await upsert_chan.put(
(
util.upsert_edge,
(
conn,
"Entity", # src_type
v_id, # src_id
"RELATIONSHIP", # edgeType
"Entity", # tgt_type
v_id2, # tgt_id
{"relation_type": "DOC_CHUNK_COOCCURRENCE"}, # attributes
parkererickson-tg marked this conversation as resolved.
Show resolved Hide resolved
),
)
)

for edge in doc.relationships:
logger.info(
Expand Down
Loading