Skip to content

Commit

Permalink
Merge pull request #280 from tigergraph/GML-1896-native-vector-store
Browse files Browse the repository at this point in the history
GML-1896: native vector store
  • Loading branch information
parkererickson-tg authored Dec 25, 2024
2 parents ce56f2f + 8db5eb1 commit d0ba503
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 348 deletions.
86 changes: 60 additions & 26 deletions common/db/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasicCredentials, HTTPAuthorizationCredentials
from pyTigerGraph import TigerGraphConnection
from pyTigerGraph import TigerGraphConnection, AsyncTigerGraphConnection
from pyTigerGraph.common.exception import TigerGraphException
from requests import HTTPError

Expand All @@ -21,14 +21,24 @@
def get_db_connection_id_token(
graphname: str,
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
async_conn: bool = False
) -> TigerGraphConnectionProxy:
conn = TigerGraphConnection(
host=db_config["hostname"],
graphname=graphname,
apiToken=credentials,
tgCloud=True,
sslPort=14240,
)
if async_conn:
conn = AsyncTigerGraphConnection(
host=db_config["hostname"],
graphname=graphname,
apiToken=credentials,
tgCloud=True,
sslPort=14240,
)
else:
conn = TigerGraphConnection(
host=db_config["hostname"],
graphname=graphname,
apiToken=credentials,
tgCloud=True,
sslPort=14240,
)
conn.customizeHeader(
timeout=db_config["default_timeout"] * 1000, responseSize=5000000
)
Expand All @@ -55,9 +65,10 @@ def get_db_connection_id_token(


def get_db_connection_pwd(
graphname, credentials: Annotated[HTTPBasicCredentials, Depends(security)]
graphname, credentials: Annotated[HTTPBasicCredentials, Depends(security)],
async_conn: bool = False
) -> TigerGraphConnectionProxy:
conn = elevate_db_connection_to_token(db_config["hostname"], credentials.username, credentials.password, graphname)
conn = elevate_db_connection_to_token(db_config["hostname"], credentials.username, credentials.password, graphname, async_conn)

conn.customizeHeader(
timeout=db_config["default_timeout"] * 1000, responseSize=5000000
Expand All @@ -70,12 +81,13 @@ def get_db_connection_pwd(

def get_db_connection_pwd_manual(
graphname, username: str, password: str,
async_conn: bool = False
) -> TigerGraphConnectionProxy:
"""
Manual auth - pass in user/pass not from basic auth
"""
conn = elevate_db_connection_to_token(
db_config["hostname"], username, password, graphname
db_config["hostname"], username, password, graphname, async_conn
)

conn.customizeHeader(
Expand All @@ -85,22 +97,19 @@ def get_db_connection_pwd_manual(
LogWriter.info("Connected to TigerGraph with password")
return conn

def elevate_db_connection_to_token(host, username, password, graphname) -> TigerGraphConnectionProxy:
def elevate_db_connection_to_token(host, username, password, graphname, async_conn: bool = False) -> TigerGraphConnectionProxy:
conn = TigerGraphConnection(
host=host,
username=username,
password=password,
graphname=graphname
graphname=graphname,
restppPort=db_config.get("restppPort", "9000"),
gsPort=db_config.get("gsPort", "14240")
)

if db_config["getToken"]:
try:
apiToken = conn._post(
conn.restppUrl + "/requesttoken",
authMode="pwd",
data=str({"graph": conn.graphname}),
resKey="results",
)["token"]
apiToken = conn.getToken()[0]
except HTTPError:
LogWriter.error("Failed to get token")
raise HTTPException(
Expand All @@ -115,13 +124,38 @@ def elevate_db_connection_to_token(host, username, password, graphname) -> Tiger
detail="Failed to get token - is the database running?"
)

if async_conn:
conn = AsyncTigerGraphConnection(
host=host,
username=username,
password=password,
graphname=graphname,
apiToken=apiToken,
restppPort=db_config.get("restppPort", "9000"),
gsPort=db_config.get("gsPort", "14240")
)
else:
conn = TigerGraphConnection(
host=db_config["hostname"],
username=username,
password=password,
graphname=graphname,
apiToken=apiToken,
restppPort=db_config.get("restppPort", "9000"),
gsPort=db_config.get("gsPort", "14240")
)
else:
if async_conn:
conn = AsyncTigerGraphConnection(
host=host,
username=username,
password=password,
graphname=graphname,
restppPort=db_config.get("restppPort", "9000"),
gsPort=db_config.get("gsPort", "14240")
)

conn = TigerGraphConnection(
host=db_config["hostname"],
username=username,
password=password,
graphname=graphname,
apiToken=apiToken
)
# temp fix for path
conn.restppUrl = conn.restppUrl+"/restpp"

return conn
36 changes: 36 additions & 0 deletions common/gsql/supportai/SupportAI_Schema_Native_Vector.gsql
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
CREATE SCHEMA_CHANGE JOB add_supportai_schema {
// Schema-change job that adds the SupportAI vertex and edge types, in the
// variant where embeddings are declared directly on the vertex types via
// `WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine)`.
// NOTE(review): dimension=1536 presumably has to match the output size of the
// configured embedding model -- confirm before changing the model.
// NOTE(review): the EMBEDDING ATTRIBUTE clause presumably requires a database
// version with native vector support -- confirm against the target release.

// --- Vertex types that carry an `embedding` attribute ---
// The epoch_added / epoch_processing / epoch_processed UINT attributes look
// like processing-state timestamps -- TODO confirm semantics with the loaders.
ADD VERTEX DocumentChunk(id STRING PRIMARY KEY, idx INT, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX Document(id STRING PRIMARY KEY, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX Concept(id STRING PRIMARY KEY, description STRING, concept_type STRING, human_curated BOOL, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX Entity(id STRING PRIMARY KEY, definition STRING, description SET<STRING>, entity_type STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX Relationship(id STRING PRIMARY KEY, definition STRING, short_name STRING, epoch_added UINT, epoch_processing UINT, epoch_processed UINT) WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";
// --- Vertex types without an embedding attribute ---
ADD VERTEX DocumentCollection(id STRING PRIMARY KEY, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX Content(id STRING PRIMARY KEY, text STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX EntityType(id STRING PRIMARY KEY, description STRING, epoch_added UINT) WITH STATS="OUTDEGREE_BY_EDGETYPE";
// --- Edge types ---
// Every directed edge below also declares a reverse edge named
// "reverse_<EDGE_NAME>". A `|` inside the FROM/TO list declares more than one
// allowed (source, target) vertex-type pair for the same edge type.
ADD DIRECTED EDGE HAS_CONTENT(FROM Document, TO Content|FROM DocumentChunk, TO Content) WITH REVERSE_EDGE="reverse_HAS_CONTENT";
ADD DIRECTED EDGE IS_CHILD_OF(FROM Concept, TO Concept) WITH REVERSE_EDGE="reverse_IS_CHILD_OF";
ADD DIRECTED EDGE IS_HEAD_OF(FROM Entity, TO Relationship) WITH REVERSE_EDGE="reverse_IS_HEAD_OF";
ADD DIRECTED EDGE HAS_TAIL(FROM Relationship, TO Entity) WITH REVERSE_EDGE="reverse_HAS_TAIL";
ADD DIRECTED EDGE DESCRIBES_RELATIONSHIP(FROM Concept, TO Relationship) WITH REVERSE_EDGE="reverse_DESCRIBES_RELATIONSHIP";
ADD DIRECTED EDGE DESCRIBES_ENTITY(FROM Concept, TO Entity) WITH REVERSE_EDGE="reverse_DESCRIBES_ENTITY";
ADD DIRECTED EDGE CONTAINS_ENTITY(FROM DocumentChunk, TO Entity|FROM Document, TO Entity) WITH REVERSE_EDGE="reverse_CONTAINS_ENTITY";
ADD DIRECTED EDGE MENTIONS_RELATIONSHIP(FROM DocumentChunk, TO Relationship|FROM Document, TO Relationship) WITH REVERSE_EDGE="reverse_MENTIONS_RELATIONSHIP";
ADD DIRECTED EDGE IS_AFTER(FROM DocumentChunk, TO DocumentChunk) WITH REVERSE_EDGE="reverse_IS_AFTER";
ADD DIRECTED EDGE HAS_CHILD(FROM Document, TO DocumentChunk) WITH REVERSE_EDGE="reverse_HAS_CHILD";
ADD DIRECTED EDGE HAS_RELATIONSHIP(FROM Concept, TO Concept, relation_type STRING) WITH REVERSE_EDGE="reverse_HAS_RELATIONSHIP";
ADD DIRECTED EDGE CONTAINS_DOCUMENT(FROM DocumentCollection, TO Document) WITH REVERSE_EDGE="reverse_CONTAINS_DOCUMENT";
ADD DIRECTED EDGE ENTITY_HAS_TYPE(FROM Entity, TO EntityType) WITH REVERSE_EDGE="reverse_ENTITY_HAS_TYPE";
// RELATIONSHIP_TYPE is the only edge declared with a DISCRIMINATOR; presumably
// this lets several edges between the same EntityType pair coexist, one per
// relation_type -- confirm against the GSQL discriminator documentation.
ADD DIRECTED EDGE RELATIONSHIP_TYPE(FROM EntityType, TO EntityType, DISCRIMINATOR(relation_type STRING), frequency INT) WITH REVERSE_EDGE="reverse_RELATIONSHIP_TYPE";

// GraphRAG
// Community and ResolvedEntity also carry the 1536-dimension cosine embedding.
ADD VERTEX Community (id STRING PRIMARY KEY, iteration UINT, description STRING) WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";
ADD VERTEX ResolvedEntity(id STRING PRIMARY KEY, entity_type STRING )WITH EMBEDDING ATTRIBUTE embedding(dimension=1536, metric=cosine) STATS="OUTDEGREE_BY_EDGETYPE";

ADD DIRECTED EDGE RELATIONSHIP(FROM Entity, TO Entity, relation_type STRING) WITH REVERSE_EDGE="reverse_RELATIONSHIP";
ADD DIRECTED EDGE RESOLVES_TO(FROM Entity, TO ResolvedEntity, relation_type STRING) WITH REVERSE_EDGE="reverse_RESOLVES_TO"; // Connect ResolvedEntities with their children entities
ADD DIRECTED EDGE RESOLVED_RELATIONSHIP(FROM ResolvedEntity, TO ResolvedEntity, relation_type STRING) WITH REVERSE_EDGE="reverse_RESOLVED_RELATIONSHIP"; // store edges between entities after they're resolved

// Community membership and community-to-community edges; LINKS_TO carries a
// DOUBLE `weight` attribute, HAS_PARENT carries no attributes.
ADD DIRECTED EDGE IN_COMMUNITY(FROM ResolvedEntity, TO Community) WITH REVERSE_EDGE="reverse_IN_COMMUNITY";
ADD DIRECTED EDGE LINKS_TO (from Community, to Community, weight DOUBLE) WITH REVERSE_EDGE="reverse_LINKS_TO";
ADD DIRECTED EDGE HAS_PARENT (from Community, to Community) WITH REVERSE_EDGE="reverse_HAS_PARENT";
}
2 changes: 1 addition & 1 deletion common/metrics/tg_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def hooked(*args, **kwargs):
else:
return original_attr

def _req(self, method: str, url: str, authMode: str, *args, **kwargs):
def _req(self, method: str, url: str, authMode: str = "token", *args, **kwargs):
# we always use token auth
# always use proxy endpoint in GUI for restpp and gsql
if self.auth_mode == "pwd":
Expand Down
6 changes: 5 additions & 1 deletion copilot/app/supportai/supportai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

def init_supportai(conn: TigerGraphConnection, graphname: str) -> tuple[dict, dict]:
# need to open the file using the absolute path
file_path = "common/gsql/supportai/SupportAI_Schema.gsql"
ver = conn.getVer().split(".")
if int(ver[0]) >= 4 and int(ver[1]) >= 2:
file_path = "common/gsql/supportai/SupportAI_Schema_Native_Vector.gsql"
else:
file_path = "common/gsql/supportai/SupportAI_Schema.gsql"
with open(file_path, "r") as f:
schema = f.read()
schema_res = conn.gsql(
Expand Down
Loading

0 comments on commit d0ba503

Please sign in to comment.