
Gml 1820 graph rag init #259

Merged Aug 14, 2024 · 55 commits

The file changes below reflect the first 13 of the 55 commits.

Commits
All commits are by RobRossmiller-TG.

7a43943  Jul 22, 2024  init
8e0ed55  Jul 23, 2024  save: docs handled concurrently -- writing upsert_edge
ec299a2  Jul 23, 2024  save: docs handled concurrently -- writing upsert_edge
fce72c4  Jul 24, 2024  changing queues for channels
46d73dc  Jul 24, 2024  graphrag etl with channels
7501a37  Jul 29, 2024  pytg in 175 seconds
bb37198  Jul 30, 2024  docs processing done -- start community passes
e9f178e  Aug 1, 2024   save
8ab8774  Aug 9, 2024   starting to upsert community summaries
ef842ba  Aug 12, 2024  graphrag pipeline done
08aca04  Aug 12, 2024  cleanup
c50928c  Aug 12, 2024  Merge branch 'dev' into GML-1820-graph_rag_init
f282840  Aug 12, 2024  fmt after merge conflicts
50a4fd5  Aug 13, 2024  rm clang dotfiles
f007c8a  Aug 13, 2024  final cleanup
2d1e98b  Aug 13, 2024  reqs to fix unit tests
e0065ee  Aug 13, 2024  reqs to fix unit test
2a5434a  Aug 13, 2024  reqs to fix unit test
a43490a  Aug 13, 2024  reqs to fix unit test
4b76e73  Aug 13, 2024  reqs to fix unit test
115b1b3  Aug 13, 2024  reqs to fix unit test
58b5cbe  Aug 13, 2024  reqs to fix unit test
fa96039  Aug 13, 2024  reqs to fix unit test
905d5cf  Aug 13, 2024  reqs to fix unit test
5e8b0ae  Aug 13, 2024  reqs to fix unit test
be0177e  Aug 13, 2024  reqs to fix unit test
cb43815  Aug 13, 2024  reqs to fix unit test
ac6d3fe  Aug 13, 2024  reqs to fix unit test
60aa569  Aug 13, 2024  reqs to fix unit test
2d37756  Aug 13, 2024  reqs to fix unit test
1929aa2  Aug 13, 2024  reqs to fix unit test
f33ddef  Aug 13, 2024  reqs to fix unit test
1a97181  Aug 13, 2024  langchain-openai conflicts
e9f7468  Aug 13, 2024  reqs to fix unit test
c8248d7  Aug 13, 2024  reqs to fix unit test
210d0fc  Aug 13, 2024  reqs to fix unit test
9c8b183  Aug 13, 2024  reqs to fix unit test
e4d8168  Aug 13, 2024  reqs to fix unit test
538653f  Aug 13, 2024  reqs to fix unit tests
a63d376  Aug 13, 2024  reqs to fix unit tests
fe6643c  Aug 13, 2024  smoke test
64b3998  Aug 14, 2024  smoke test
e08d42a  Aug 14, 2024  smoke test
17b09df  Aug 14, 2024  smoke test
6ce885f  Aug 14, 2024  smoke test
442564b  Aug 14, 2024  smoke test
2d8675e  Aug 14, 2024  smoke test
e9f5e9d  Aug 14, 2024  smoke test
0ca73a3  Aug 14, 2024  smoke test
8252c1e  Aug 14, 2024  smoke test
8777b3c  Aug 14, 2024  smoke test
69a7db4  Aug 14, 2024  smoke test
4dfa51c  Aug 14, 2024  smoke test
56f8e16  Aug 14, 2024  working
11896e8  Aug 14, 2024  Merge pull request #262 from tigergraph/GML-1856-unit_test_deps
11 changes: 6 additions & 5 deletions common/config.py

@@ -16,15 +16,16 @@
     AWSBedrock,
     AzureOpenAI,
     GoogleVertexAI,
-    OpenAI,
     Groq,
-    Ollama,
     HuggingFaceEndpoint,
+    LLM_Model,
+    Ollama,
+    OpenAI,
     IBMWatsonX
 )
+from common.logs.logwriter import LogWriter
 from common.session import SessionHandler
 from common.status import StatusManager
-from common.logs.logwriter import LogWriter

 security = HTTPBasic()
 session_handler = SessionHandler()
@@ -105,7 +106,7 @@
     raise Exception("Embedding service not implemented")


-def get_llm_service(llm_config):
+def get_llm_service(llm_config) -> LLM_Model:
     if llm_config["completion_service"]["llm_service"].lower() == "openai":
         return OpenAI(llm_config["completion_service"])
     elif llm_config["completion_service"]["llm_service"].lower() == "azure":
@@ -191,7 +192,7 @@ def get_llm_service(llm_config):
     doc_processing_config = {
         "chunker": "semantic",
         "chunker_config": {"method": "percentile", "threshold": 0.95},
-        "extractor": "llm",
+        "extractor": "graphrag",
         "extractor_config": {},
     }
 elif DOC_PROCESSING_CONFIG.endswith(".json"):
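Editor's note: the hunk above flips the in-code default extractor from "llm" to "graphrag". A minimal sketch of how such a config-driven switch is typically consumed downstream — the extractor_kind helper is illustrative only, not the repo's real factory (which returns an extractor instance):

doc_processing_config = {
    "chunker": "semantic",
    "chunker_config": {"method": "percentile", "threshold": 0.95},
    "extractor": "graphrag",  # was "llm" before this PR
    "extractor_config": {},
}

def extractor_kind(cfg: dict) -> str:
    # Validate the switch the way a factory would before instantiating
    # the matching extractor class.
    if cfg["extractor"] not in {"graphrag", "llm"}:
        raise ValueError(f"unknown extractor: {cfg['extractor']}")
    return cfg["extractor"]

assert extractor_kind(doc_processing_config) == "graphrag"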
37 changes: 33 additions & 4 deletions common/embeddings/embedding_services.py

@@ -1,11 +1,13 @@
+import logging
 import os
+import time
 from typing import List

 from langchain.schema.embeddings import Embeddings
-import logging
-import time

 from common.logs.log import req_id_cv
-from common.metrics.prometheus_metrics import metrics
 from common.logs.logwriter import LogWriter
+from common.metrics.prometheus_metrics import metrics

 logger = logging.getLogger(__name__)
@@ -87,6 +89,33 @@ def embed_query(self, question: str) -> List[float]:
                 duration
             )

+    async def aembed_query(self, question: str) -> List[float]:
+        """Embed Query Async.
+        Embed a string.
+
+        Args:
+            question (str):
+                A string to embed.
+        """
+        # start_time = time.time()
+        # metrics.llm_inprogress_requests.labels(self.model_name).inc()
+
+        # try:
+        logger.debug_pii(f"aembed_query() embedding question={question}")
+        query_embedding = await self.embeddings.aembed_query(question)
+        # metrics.llm_success_response_total.labels(self.model_name).inc()
+        return query_embedding
+        # except Exception as e:
+        #     # metrics.llm_query_error_total.labels(self.model_name).inc()
+        #     raise e
+        # finally:
+        #     metrics.llm_request_total.labels(self.model_name).inc()
+        #     metrics.llm_inprogress_requests.labels(self.model_name).dec()
+        #     duration = time.time() - start_time
+        #     metrics.llm_request_duration_seconds.labels(self.model_name).observe(
+        #         duration
+        #     )
+

 class AzureOpenAI_Ada002(EmbeddingModel):
     """Azure OpenAI Ada-002 Embedding Model"""
@@ -124,8 +153,8 @@ class AWS_Bedrock_Embedding(EmbeddingModel):
     """AWS Bedrock Embedding Model"""

     def __init__(self, config):
-        from langchain_community.embeddings import BedrockEmbeddings
         import boto3
+        from langchain_community.embeddings import BedrockEmbeddings

         super().__init__(config=config, model_name=config["embedding_model"])
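Editor's note: the new aembed_query is what lets the GraphRAG ETL embed many chunks concurrently rather than serially. A minimal usage sketch, assuming an already-configured EmbeddingModel instance (here named embedder — a placeholder, not a name from the diff):

import asyncio

async def embed_all(embedder, texts: list[str]) -> list[list[float]]:
    # Schedule one aembed_query per string and await them together;
    # the event loop overlaps the underlying HTTP calls.
    return await asyncio.gather(*(embedder.aembed_query(t) for t in texts))

# e.g. asyncio.run(embed_all(embedder, ["what is a graph?", "define RAG"]))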
152 changes: 142 additions & 10 deletions common/embeddings/milvus_embedding_store.py

@@ -1,18 +1,21 @@
 import logging
 import traceback
 from time import sleep, time
 from typing import Iterable, List, Optional, Tuple

-from langchain_milvus.vectorstores import Milvus
+import Levenshtein as lev
+from asyncer import asyncify
+from langchain_community.vectorstores import Milvus
 from langchain_core.documents.base import Document
-from pymilvus import connections, utility
+from langchain_milvus.vectorstores import Milvus
+from pymilvus import MilvusException, connections, utility
+from pymilvus.exceptions import MilvusException

 from common.embeddings.base_embedding_store import EmbeddingStore
 from common.embeddings.embedding_services import EmbeddingModel
 from common.logs.log import req_id_cv
-from common.metrics.prometheus_metrics import metrics
 from common.logs.logwriter import LogWriter
-from pymilvus import MilvusException
+from common.metrics.prometheus_metrics import metrics

 logger = logging.getLogger(__name__)
@@ -33,6 +36,7 @@ def __init__(
         alias: str = "alias",
         retry_interval: int = 2,
         max_retry_attempts: int = 10,
+        drop_old=False,
     ):
         self.embedding_service = embedding_service
         self.vector_field = vector_field
@@ -43,6 +47,7 @@ def __init__(
         self.milvus_alias = alias
         self.retry_interval = retry_interval
         self.max_retry_attempts = max_retry_attempts
+        self.drop_old = drop_old

         if host.startswith("http"):
             if host.endswith(str(port)):
@@ -77,7 +82,7 @@ def connect_to_milvus(self):
         while retry_attempt < self.max_retry_attempts:
             try:
                 connections.connect(**self.milvus_connection)
-                metrics.milvus_active_connections.labels(self.collection_name).inc
+                # metrics.milvus_active_connections.labels(self.collection_name).inc
                 LogWriter.info(
                     f"""Initializing Milvus with host={self.milvus_connection.get("host", self.milvus_connection.get("uri", "unknown host"))},
                     port={self.milvus_connection.get('port', 'unknown')}, username={self.milvus_connection.get('user', 'unknown')}, collection={self.collection_name}"""
@@ -88,7 +93,7 @@ def connect_to_milvus(self):
                     collection_name=self.collection_name,
                     connection_args=self.milvus_connection,
                     auto_id=True,
-                    drop_old=False,
+                    drop_old=self.drop_old,
                     text_field=self.text_field,
                     vector_field=self.vector_field,
                 )
@@ -120,6 +125,9 @@ def metadata_func(record: dict, metadata: dict) -> dict:
                 return metadata

             LogWriter.info("Milvus add initial load documents init()")
+            import os
+
+            logger.info(f"*******{os.path.exists('tg_documents')}")
             loader = DirectoryLoader(
                 "./common/tg_documents/",
                 glob="*.json",
@@ -216,6 +224,76 @@ def add_embeddings(
             error_message = f"An error occurred while registering document: {str(e)}"
             LogWriter.error(error_message)

+    async def aadd_embeddings(
+        self,
+        embeddings: Iterable[Tuple[str, List[float]]],
+        metadatas: List[dict] = None,
+    ):
+        """Async Add Embeddings.
+        Add embeddings to the Embedding store.
+        Args:
+            embeddings (Iterable[Tuple[str, List[float]]]):
+                Iterable of content and embedding of the document.
+            metadatas (List[Dict]):
+                List of dictionaries containing the metadata for each document.
+                The embeddings and metadatas list need to have identical indexing.
+        """
+        try:
+            if metadatas is None:
+                metadatas = []
+
+            # add fields required by Milvus if they do not exist
+            if self.support_ai_instance:
+                for metadata in metadatas:
+                    if self.vertex_field not in metadata:
+                        metadata[self.vertex_field] = ""
+            else:
+                for metadata in metadatas:
+                    if "seq_num" not in metadata:
+                        metadata["seq_num"] = 1
+                    if "source" not in metadata:
+                        metadata["source"] = ""
+
+            LogWriter.info(
+                f"request_id={req_id_cv.get()} Milvus ENTRY aadd_embeddings()"
+            )
+            texts = [text for text, _ in embeddings]
+
+            # operation_type = "add_texts"
+            # metrics.milvus_query_total.labels(
+            #     self.collection_name, operation_type
+            # ).inc()
+            # start_time = time()

+            added = await self.milvus.aadd_texts(texts=texts, metadatas=metadatas)
+
+            # duration = time() - start_time
+            # metrics.milvus_query_duration_seconds.labels(
+            #     self.collection_name, operation_type
+            # ).observe(duration)
+
+            LogWriter.info(
+                f"request_id={req_id_cv.get()} Milvus EXIT aadd_embeddings()"
+            )
+
+            # Check if registration was successful
+            if added:
+                success_message = f"Document registered with id: {added[0]}"
+                LogWriter.info(success_message)
+                return success_message
+            else:
+                error_message = f"Failed to register document {added}"
+                LogWriter.error(error_message)
+                raise Exception(error_message)
+
+        except Exception as e:
+            error_message = f"An error occurred while registering document:{metadatas} ({len(texts)},{len(metadatas)})\nErr: {str(e)}"
+            LogWriter.error(error_message)
+            exc = traceback.format_exc()
+            LogWriter.error(exc)
+            LogWriter.error(f"{texts}")
+            raise e
+
     def get_pks(
         self,
         expr: str,
@@ -509,14 +587,68 @@ def query(self, expr: str, output_fields: List[str]):
             return None

         try:
-            query_result = self.milvus.col.query(
-                expr=expr, output_fields=output_fields
-            )
+            query_result = self.milvus.col.query(expr=expr, output_fields=output_fields)
         except MilvusException as exc:
-            LogWriter.error(f"Failed to get outputs: {self.milvus.collection_name} error: {exc}")
+            LogWriter.error(
+                f"Failed to get outputs: {self.milvus.collection_name} error: {exc}"
+            )
             raise exc

         return query_result

+    def edit_dist_check(self, a: str, b: str, edit_dist_threshold: float, p=False):
+        a = a.lower()
+        b = b.lower()
+        # if the words are short, they should be the same
+        if len(a) < 5 and len(b) < 5:
+            return a == b
+
+        # edit_dist_threshold (as a percent) of word must match
+        threshold = int(min(len(a), len(b)) * (1 - edit_dist_threshold))
+        if p:
+            print(a, b, threshold, lev.distance(a, b))
+        return lev.distance(a, b) < threshold
+
+    async def aget_k_closest(
+        self, v_id: str, k=15, threshold_similarity=0.90, edit_dist_threshold_pct=0.75
+    ) -> list[Document]:
+        threshold_dist = 1 - threshold_similarity
+
+        # asyncify necessary funcs
+        query = asyncify(self.milvus.col.query)
+        search = asyncify(self.milvus.similarity_search_with_score_by_vector)
+
+        # Get all vectors with this ID
+        verts = await query(
+            f'{self.vertex_field} == "{v_id}"',
+            output_fields=[self.vertex_field, self.vector_field],
+        )
+        result = []
+        for v in verts:
+            # get the k closest verts
+            sim = await search(
+                v["document_vector"],
+                k=k,
+            )
+            # filter verts using similiarity threshold and leven_dist
+            similar_verts = [
+                doc.metadata["vertex_id"]
+                for doc, dist in sim
+                # check semantic similarity
+                if dist < threshold_dist
+                # check name similarity (won't merge Apple and Google if they're semantically similar)
+                and self.edit_dist_check(
+                    doc.metadata["vertex_id"],
+                    v_id,
+                    edit_dist_threshold_pct,
+                    # v_id == "Dataframe",
+                )
+                # don't have to merge verts with the same id (they're the same)
+                and doc.metadata["vertex_id"] != v_id
+            ]
+            result.extend(similar_verts)
+        result.append(v_id)
+        return set(result)

     def __del__(self):
         metrics.milvus_active_connections.labels(self.collection_name).dec
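Editor's note: to make the entity-deduplication thresholds in edit_dist_check/aget_k_closest concrete, here is a standalone copy of the distance check (same formula as the method above, minus self and the debug-print parameter) with worked numbers; example strings are chosen for illustration:

import Levenshtein as lev  # provided by the python-Levenshtein package

def edit_dist_check(a: str, b: str, edit_dist_threshold: float) -> bool:
    a, b = a.lower(), b.lower()
    if len(a) < 5 and len(b) < 5:  # short words must match exactly
        return a == b
    # at most (1 - edit_dist_threshold) of the shorter name may differ
    threshold = int(min(len(a), len(b)) * (1 - edit_dist_threshold))
    return lev.distance(a, b) < threshold

# With the default edit_dist_threshold_pct=0.75:
print(edit_dist_check("DataFrame", "Dataframes", 0.75))  # True:  threshold=int(9*0.25)=2, distance=1
print(edit_dist_check("Apple", "Google", 0.75))          # False: threshold=int(5*0.25)=1, distance=4

Note that for names shorter than eight characters the computed threshold is at most 1, so only exact (case-insensitive) matches pass; aget_k_closest additionally requires vector distance below 1 - threshold_similarity before two vertex IDs are treated as the same entity.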
13 changes: 10 additions & 3 deletions common/extractors/BaseExtractor.py

@@ -1,6 +1,13 @@
-class BaseExtractor:
-    def __init__():
+from abc import ABC, abstractmethod
+
+from langchain_community.graphs.graph_document import GraphDocument
+
+
+class BaseExtractor(ABC):
+    @abstractmethod
+    def extract(self, text:str):
         pass

-    def extract(self, text):
+    @abstractmethod
+    async def aextract(self, text:str) -> list[GraphDocument]:
         pass
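Editor's note: with BaseExtractor now an ABC, concrete extractors must implement both methods. A minimal sketch of a conforming subclass — a no-op stand-in for illustration, not one of the PR's real extractors:

from langchain_community.graphs.graph_document import GraphDocument
from langchain_core.documents import Document

class NoOpExtractor(BaseExtractor):
    def extract(self, text: str):
        return []

    async def aextract(self, text: str) -> list[GraphDocument]:
        # Returns an empty graph that still records its source text; a real
        # extractor would populate nodes and relationships, e.g. via an LLM.
        return [
            GraphDocument(nodes=[], relationships=[], source=Document(page_content=text))
        ]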