From 6f2f8d6eb7b74dd28463f669e9460dbb2b10af17 Mon Sep 17 00:00:00 2001 From: Parker Erickson Date: Thu, 29 Aug 2024 15:06:49 -0500 Subject: [PATCH 1/2] feat(graphrag): add descriptions to all upserts, cooccurrence edges --- common/config.py | 2 +- .../LLMEntityRelationshipExtractor.py | 35 +++++++++++++++---- .../app/graphrag/util.py | 2 +- .../app/graphrag/workers.py | 18 +++++++++- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/common/config.py b/common/config.py index 4548aa3a..ee7a850b 100644 --- a/common/config.py +++ b/common/config.py @@ -188,7 +188,7 @@ def get_llm_service(llm_config) -> LLM_Model: doc_processing_config = { "chunker": "semantic", "chunker_config": {"method": "percentile", "threshold": 0.90}, - "extractor": "graphrag", + "extractor": "llm", "extractor_config": {}, } elif DOC_PROCESSING_CONFIG.endswith(".json"): diff --git a/common/extractors/LLMEntityRelationshipExtractor.py b/common/extractors/LLMEntityRelationshipExtractor.py index 415c3235..959ce644 100644 --- a/common/extractors/LLMEntityRelationshipExtractor.py +++ b/common/extractors/LLMEntityRelationshipExtractor.py @@ -4,7 +4,8 @@ from common.extractors.BaseExtractor import BaseExtractor from common.llm_services import LLM_Model from common.py_schemas import KnowledgeGraph - +from langchain_community.graphs.graph_document import Node, Relationship, GraphDocument +from langchain_core.documents import Document class LLMEntityRelationshipExtractor(BaseExtractor): def __init__( @@ -19,7 +20,7 @@ def __init__( self.allowed_edge_types = allowed_relationship_types self.strict_mode = strict_mode - def _extract_kg_from_doc(self, doc, chain, parser): + async def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]: """ returns: { "nodes": [ { "id": "str", "type": "string", "definition": "string" } ], "rels": [ { "source":{ "id": "str", "type": "string", "definition": "string" } "target":{ "id": "str", "type": "string", "definition": "string" } "definition" } ] } """ try: - out = chain.invoke( + out = await chain.ainvoke( {"input": doc, "format_instructions": parser.get_format_instructions()} ) except Exception as e: @@ 
-133,15 +134,30 @@ def _extract_kg_from_doc(self, doc, chain, parser): for rel in formatted_rels if rel["type"] in self.allowed_edge_types ] - return {"nodes": formatted_nodes, "rels": formatted_rels} + + nodes = [] + for node in formatted_nodes: + nodes.append(Node(id=node["id"], + type=node["type"], + properties={"description": node["definition"]})) + relationships = [] + for rel in formatted_rels: + relationships.append(Relationship(source=Node(id=rel["source"], type=rel["source"], + properties={"description": rel["definition"]}), + target=Node(id=rel["target"], type=rel["target"], + properties={"description": rel["definition"]}), type=rel["type"])) + + return [GraphDocument(nodes=nodes, relationships=relationships, source=Document(page_content=doc))] + except: print("Error Processing: ", out) - return {"nodes": [], "rels": []} + return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))] - def document_er_extraction(self, document): + async def document_er_extraction(self, document): from langchain.prompts import ChatPromptTemplate from langchain.output_parsers import PydanticOutputParser + parser = PydanticOutputParser(pydantic_object=KnowledgeGraph) prompt = [ ("system", self.llm_service.entity_relationship_extraction_prompt), @@ -171,8 +187,13 @@ def document_er_extraction(self, document): prompt.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}")) prompt = ChatPromptTemplate.from_messages(prompt) chain = prompt | self.llm_service.model # | parser - er = self._extract_kg_from_doc(document, chain, parser) + er = await self._extract_kg_from_doc(document, chain, parser) return er def extract(self, text): return self.document_er_extraction(text) + + async def aextract(self, text) -> list[GraphDocument]: + return await self.document_er_extraction(text) + + diff --git a/eventual-consistency-service/app/graphrag/util.py b/eventual-consistency-service/app/graphrag/util.py index 35f5bcdf..3911fd56 100644 --- 
a/eventual-consistency-service/app/graphrag/util.py +++ b/eventual-consistency-service/app/graphrag/util.py @@ -178,7 +178,7 @@ def map_attrs(attributes: dict): def process_id(v_id: str): - v_id = v_id.replace(" ", "_").replace("/", "") + v_id = v_id.replace(" ", "_").replace("/", "").replace("%", "percent") has_func = re.compile(r"(.*)\(").findall(v_id) if len(has_func) > 0: diff --git a/eventual-consistency-service/app/graphrag/workers.py b/eventual-consistency-service/app/graphrag/workers.py index c00b9187..8967d120 100644 --- a/eventual-consistency-service/app/graphrag/workers.py +++ b/eventual-consistency-service/app/graphrag/workers.py @@ -213,7 +213,7 @@ async def extract( # upsert nodes and edges to the graph for doc in extracted: - for node in doc.nodes: + for i, node in enumerate(doc.nodes): logger.info(f"extract writes entity vert to upsert\nNode: {node.id}") v_id = util.process_id(str(node.id)) if len(v_id) == 0: @@ -259,6 +259,22 @@ async def extract( ), ) ) + for node2 in doc.nodes[i + 1:]: + v_id2 = util.process_id(str(node2.id)) + await upsert_chan.put( + ( + util.upsert_edge, + ( + conn, + "Entity", # src_type + v_id, # src_id + "RELATIONSHIP", # edgeType + "Entity", # tgt_type + v_id2, # tgt_id + {"relation_type": "DOC_CHUNK_COOCCURRENCE"}, # attributes + ), + ) + ) for edge in doc.relationships: logger.info( From e76fbd3f873c54d0ff74bb9e7b98c907f4e7b6c4 Mon Sep 17 00:00:00 2001 From: Parker Erickson Date: Thu, 29 Aug 2024 15:44:19 -0500 Subject: [PATCH 2/2] update async/sync llm relationship extraction --- .../LLMEntityRelationshipExtractor.py | 181 ++++++++++++++---- 1 file changed, 146 insertions(+), 35 deletions(-) diff --git a/common/extractors/LLMEntityRelationshipExtractor.py b/common/extractors/LLMEntityRelationshipExtractor.py index 959ce644..6ee999e9 100644 --- a/common/extractors/LLMEntityRelationshipExtractor.py +++ b/common/extractors/LLMEntityRelationshipExtractor.py @@ -20,42 +20,116 @@ def __init__( self.allowed_edge_types = 
allowed_relationship_types self.strict_mode = strict_mode - async def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]: - """ - returns: - { - "nodes": [ - { - "id": "str", - "type": "string", - "definition": "string" - } - ], - "rels": [ - { - "source":{ - "id": "str", - "type": "string", - "definition": "string" - } - "target":{ - "id": "str", - "type": "string", - "definition": "string" + async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]: + try: + out = await chain.ainvoke( + {"input": doc, "format_instructions": parser.get_format_instructions()} + ) + except Exception as e: + return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))] + try: + if "```json" not in out.content: + json_out = json.loads(out.content.strip("content=")) + else: + json_out = json.loads( + out.content.split("```")[1].strip("```").strip("json").strip() + ) + + formatted_rels = [] + for rels in json_out["rels"]: + if isinstance(rels["source"], str) and isinstance(rels["target"], str): + formatted_rels.append( + { + "source": rels["source"], + "target": rels["target"], + "type": rels["relation_type"].replace(" ", "_").upper(), + "definition": rels["definition"], + } + ) + elif isinstance(rels["source"], dict) and isinstance( + rels["target"], str + ): + formatted_rels.append( + { + "source": rels["source"]["id"], + "target": rels["target"], + "type": rels["relation_type"].replace(" ", "_").upper(), + "definition": rels["definition"], + } + ) + elif isinstance(rels["source"], str) and isinstance( + rels["target"], dict + ): + formatted_rels.append( + { + "source": rels["source"], + "target": rels["target"]["id"], + "type": rels["relation_type"].replace(" ", "_").upper(), + "definition": rels["definition"], + } + ) + elif isinstance(rels["source"], dict) and isinstance( + rels["target"], dict + ): + formatted_rels.append( + { + "source": rels["source"]["id"], + "target": rels["target"]["id"], + "type": 
rels["relation_type"].replace(" ", "_").upper(), + "definition": rels["definition"], + } + ) + else: + raise Exception("Relationship parsing error") + formatted_nodes = [] + for node in json_out["nodes"]: + formatted_nodes.append( + { + "id": node["id"], + "type": node["node_type"].replace(" ", "_").capitalize(), + "definition": node["definition"], } - "definition" - } - ] - } - """ + ) + # filter relationships and nodes based on allowed types + if self.strict_mode: + if self.allowed_vertex_types: + formatted_nodes = [ + node + for node in formatted_nodes + if node["type"] in self.allowed_vertex_types + ] + if self.allowed_edge_types: + formatted_rels = [ + rel + for rel in formatted_rels + if rel["type"] in self.allowed_edge_types + ] + + nodes = [] + for node in formatted_nodes: + nodes.append(Node(id=node["id"], + type=node["type"], + properties={"description": node["definition"]})) + relationships = [] + for rel in formatted_rels: + relationships.append(Relationship(source=Node(id=rel["source"], type=rel["source"], + properties={"description": rel["definition"]}), + target=Node(id=rel["target"], type=rel["target"], + properties={"description": rel["definition"]}), type=rel["type"])) + + return [GraphDocument(nodes=nodes, relationships=relationships, source=Document(page_content=doc))] + + except: + return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))] + + def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]: try: - out = await chain.ainvoke( + out = chain.invoke( {"input": doc, "format_instructions": parser.get_format_instructions()} ) except Exception as e: - print("Error: ", e) - return {"nodes": [], "rels": []} + return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))] try: if "```json" not in out.content: json_out = json.loads(out.content.strip("content=")) @@ -150,10 +224,47 @@ async def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]: return 
[GraphDocument(nodes=nodes, relationships=relationships, source=Document(page_content=doc))] except: - print("Error Processing: ", out) - return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))] + return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))] + + async def adocument_er_extraction(self, document): + from langchain.prompts import ChatPromptTemplate + from langchain.output_parsers import PydanticOutputParser + + + parser = PydanticOutputParser(pydantic_object=KnowledgeGraph) + prompt = [ + ("system", self.llm_service.entity_relationship_extraction_prompt), + ( + "human", + "Tip: Make sure to answer in the correct format and do " + "not include any explanations. " + "Use the given format to extract information from the " + "following input: {input}", + ), + ( + "human", + "Mandatory: Make sure to answer in the correct format, specified here: {format_instructions}", + ), + ] + if self.allowed_vertex_types or self.allowed_edge_types: + prompt.append( + ( + "human", + "Tip: Make sure to use the following types if they are applicable. 
" + "If the input does not contain any of the types, you may create your own.", + ) + ) + if self.allowed_vertex_types: + prompt.append(("human", f"Allowed Node Types: {self.allowed_vertex_types}")) + if self.allowed_edge_types: + prompt.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}")) + prompt = ChatPromptTemplate.from_messages(prompt) + chain = prompt | self.llm_service.model # | parser + er = await self._aextract_kg_from_doc(document, chain, parser) + return er + - async def document_er_extraction(self, document): + def document_er_extraction(self, document): from langchain.prompts import ChatPromptTemplate from langchain.output_parsers import PydanticOutputParser @@ -187,13 +298,13 @@ async def document_er_extraction(self, document): prompt.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}")) prompt = ChatPromptTemplate.from_messages(prompt) chain = prompt | self.llm_service.model # | parser - er = await self._extract_kg_from_doc(document, chain, parser) + er = self._extract_kg_from_doc(document, chain, parser) return er def extract(self, text): return self.document_er_extraction(text) async def aextract(self, text) -> list[GraphDocument]: - return await self.document_er_extraction(text) + return await self.adocument_er_extraction(text)