-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
example: add semantic cache demo built with jinaai and tidb vector (#39)
- Loading branch information
1 parent
1b8c290
commit 4a2c0fb
Showing
5 changed files
with
224 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Semantic Cache with Jina AI and TiDB Vector | ||
Semantic cache is a cache that stores the semantic information of the data. It can be used to speed up the search process by storing the embeddings of the data and searching for similar embeddings. This example demonstrates how to use Jina AI to generate embeddings for text data and store the embeddings in TiDB Vector Storage. It also shows how to search for similar embeddings in TiDB Vector Storage. | ||
|
||
## Prerequisites | ||
|
||
- A running TiDB Serverless cluster with vector search enabled | ||
- Python 3.8 or later | ||
- Jina AI API key | ||
|
||
## Run the example | ||
|
||
### Clone this repo | ||
|
||
```bash | ||
git clone https://github.com/pingcap/tidb-vector-python.git | ||
``` | ||
|
||
### Create a virtual environment | ||
|
||
```bash | ||
cd tidb-vector-python/examples/semantic-cache | ||
python3 -m venv .venv | ||
source .venv/bin/activate | ||
``` | ||
|
||
### Install dependencies | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Set the environment variables | ||
|
||
Get the `HOST`, `PORT`, `USERNAME`, `PASSWORD`, and `DATABASE` from the TiDB Cloud console, as described in the [Prerequisites](../README.md#prerequisites) section. Then set the following environment variables: | ||
|
||
```bash | ||
export DATABASE_URI="mysql+pymysql://34u7xMnnDLSkjV1.root:<PASSWORD>@gateway01.eu-central-1.prod.aws.tidbcloud.com:4000/test?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true" | ||
``` | ||
or create a `.env` file with the above environment variables. | ||
|
||
|
||
### Run this example | ||
|
||
|
||
#### Start the semantic cache server | ||
|
||
```bash | ||
fastapi dev cache.py | ||
``` | ||
|
||
#### Test the API | ||
|
||
Get the Jina AI API key from the [Jina AI Embedding API](https://jina.ai/embeddings/) page, and save it somewhere safe for later use. | ||
|
||
`POST /set` | ||
|
||
```bash | ||
curl --location ':8000/set' \ | ||
--header 'Content-Type: application/json' \ | ||
--header 'Authorization: Bearer <your jina token>' \ | ||
--data '{ | ||
"key": "what is tidb", | ||
"value": "tidb is a mysql-compatible and htap database" | ||
}' | ||
``` | ||
|
||
`GET /get/<key>` | ||
|
||
```bash | ||
curl --location ':8000/get/what%27s%20tidb%20and%20tikv?max_distance=0.5' \ | ||
--header 'Content-Type: application/json' \ | ||
--header 'Authorization: Bearer <your jina token>' | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import os | ||
from datetime import datetime | ||
from typing import Optional, Annotated | ||
|
||
import requests | ||
import dotenv | ||
from fastapi import Depends, FastAPI | ||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | ||
from sqlmodel import ( | ||
SQLModel, | ||
Session, | ||
create_engine, | ||
select, | ||
Field, | ||
Column, | ||
String, | ||
Text, | ||
DateTime, | ||
) | ||
from sqlalchemy import func | ||
from tidb_vector.sqlalchemy import VectorType | ||
dotenv.load_dotenv() | ||
|
||
|
||
# Configuration from .env | ||
# Example: "mysql+pymysql://<username>:<password>@<host>:<port>/<database>?ssl_mode=VERIFY_IDENTITY&ssl_ca=/etc/ssl/cert.pem" | ||
DATABASE_URI = os.getenv('DATABASE_URI') | ||
# Ref: https://docs.pingcap.com/tidb/stable/time-to-live | ||
# Default: 604800 SECOND (1 week) | ||
TIME_TO_LIVE = os.getenv('TIME_TO_LIVE') | ||
|
||
|
||
# Get Embeddings from Jina AI | ||
def generate_embeddings(jinaai_api_key: str, text: str): | ||
JINAAI_API_URL = 'https://api.jina.ai/v1/embeddings' | ||
JINAAI_HEADERS = { | ||
'Content-Type': 'application/json', | ||
'Authorization': f'Bearer {jinaai_api_key}' | ||
} | ||
JINAAI_REQUEST_DATA = { | ||
'input': [text], | ||
'model': 'jina-embeddings-v2-base-en' # with dimisions 768 | ||
} | ||
response = requests.post(JINAAI_API_URL, headers=JINAAI_HEADERS, json=JINAAI_REQUEST_DATA) | ||
return response.json()['data'][0]['embedding'] | ||
|
||
|
||
class Cache(SQLModel, table=True): | ||
__table_args__ = { | ||
# Ref: https://docs.pingcap.com/tidb/stable/time-to-live | ||
'mysql_TTL': f'created_at + INTERVAL {TIME_TO_LIVE} SECOND', | ||
} | ||
|
||
id: Optional[int] = Field(default=None, primary_key=True) | ||
key: str = Field(sa_column=Column(String(255), unique=True, nullable=False)) | ||
key_vec: Optional[list[float]]= Field( | ||
sa_column=Column( | ||
VectorType(768), | ||
default=None, | ||
comment="hnsw(distance=l2)", | ||
nullable=False, | ||
) | ||
) | ||
value: Optional[str] = Field(sa_column=Column(Text)) | ||
created_at: datetime = Field( | ||
sa_column=Column(DateTime, server_default=func.now(), nullable=False) | ||
) | ||
updated_at: datetime = Field( | ||
sa_column=Column( | ||
DateTime, server_default=func.now(), onupdate=func.now(), nullable=False | ||
) | ||
) | ||
|
||
engine = create_engine(DATABASE_URI) | ||
SQLModel.metadata.create_all(engine) | ||
|
||
app = FastAPI() | ||
security = HTTPBearer() | ||
|
||
@app.get("/") | ||
def index(): | ||
return { | ||
"message": "Welcome to Semantic Cache API, it is built using Jina AI Embeddings API and TiDB Vector", | ||
"docs": "/docs", | ||
"redoc": "/redoc", | ||
"about": "https://github.com/pingcap/tidb-vector-python/blob/main/examples/semantic-cache/README.md", | ||
"config": { | ||
"TIME_TO_LIVE": int(TIME_TO_LIVE), | ||
"EMBEDDING_DIMENSIONS": 768, | ||
"EMBEDDING_PROVIDER": "Jina AI", | ||
"EMBEDDING_MODEL": "jina-embeddings-v2-base-en", | ||
} | ||
} | ||
|
||
|
||
# /set method of Semantic Cache | ||
@app.post("/set") | ||
def set( | ||
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)], | ||
cache: Cache, | ||
): | ||
cache.key_vec = generate_embeddings(credentials.credentials, cache.key) | ||
|
||
with Session(engine) as session: | ||
session.add(cache) | ||
session.commit() | ||
|
||
return {'message': 'Cache has been set'} | ||
|
||
|
||
@app.get("/get/{key}") | ||
def get( | ||
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)], | ||
key: str, | ||
max_distance: Optional[float] = 0.1, | ||
): | ||
key_vec = generate_embeddings(credentials.credentials, key) | ||
# The max value of distance is 0.3 | ||
max_distance = min(max_distance, 0.3) | ||
|
||
with Session(engine) as session: | ||
result = session.exec( | ||
select( | ||
Cache, | ||
Cache.key_vec.cosine_distance(key_vec).label('distance') | ||
).order_by( | ||
'distance' | ||
).limit(1) | ||
).first() | ||
|
||
if result is None: | ||
return {"message": "Cache not found"}, 404 | ||
|
||
cache, distance = result | ||
if distance > max_distance: | ||
return {"message": "Cache not found"}, 404 | ||
|
||
return { | ||
"key": cache.key, | ||
"value": cache.value, | ||
"distance": distance | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
requests | ||
PyMySQL | ||
sqlmodel==0.0.19 | ||
tidb-vector>=0.0.9 | ||
python-dotenv | ||
fastapi |