Support vector index for SQLAlchemy #65

Merged (8 commits) on Nov 5, 2024
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -23,19 +23,19 @@ jobs:

- name: Run lint
run: |
tox -e lint
tox -e lint

tests:
strategy:
fail-fast: false
matrix:
python-version:
- '3.12'
- "3.12"
name: py${{ matrix.python-version }}_test
runs-on: ubuntu-latest
services:
tidb:
image: wangdi4zm/tind:v7.5.3-vector-index
image: wangdi4zm/tind:v8.4.0-vector-index
ports:
- 4000:4000
steps:
2 changes: 2 additions & 0 deletions .gitignore
@@ -141,3 +141,5 @@ cython_debug/
django_tests_dir

*.swp

.vscode/
114 changes: 59 additions & 55 deletions README.md
@@ -1,8 +1,6 @@
# tidb-vector-python

This is a Python client for TiDB Vector.

> Now only TiDB Cloud Serverless cluster support vector data type, see this [docs](https://docs.pingcap.com/tidbcloud/vector-search-overview?utm_source=github&utm_medium=tidb-vector-python) for more information.
Use TiDB Vector Search with Python.

## Installation

@@ -12,74 +10,79 @@ pip install tidb-vector

## Usage

TiDB vector supports below distance functions:
TiDB is a SQL database, so this package brings Vector Search capability to Python ORMs:

- `L1Distance`
- `L2Distance`
- `CosineDistance`
- `NegativeInnerProduct`
- [#SQLAlchemy](#sqlalchemy)
- [#Django](#django)
- [#Peewee](#peewee)

It also supports using hnsw index with l2 or cosine distance to speed up the search, for more details see [Vector Search Indexes in TiDB](https://docs.pingcap.com/tidbcloud/vector-search-index)
Pick one that you are familiar with to get started. If you are not using any of them, we recommend [#SQLAlchemy](#sqlalchemy).

Supports following orm or framework:
We also provide a Vector Search client for simple usage:

- [SQLAlchemy](#sqlalchemy)
- [Django](#django)
- [Peewee](#peewee)
- [TiDB Vector Client](#tidb-vector-client)
- [#TiDB Vector Client](#tidb-vector-client)

### SQLAlchemy

Learn how to connect to TiDB Serverless in the [TiDB Cloud documentation](https://docs.pingcap.com/tidbcloud/dev-guide-sample-application-python-sqlalchemy).

Define table with vector field
```bash
pip install tidb-vector sqlalchemy pymysql
```

```python
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import declarative_base
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy import Integer, Text, Column
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, declarative_base

engine = create_engine('mysql://****.root:******@gateway01.xxxxxx.shared.aws.tidbcloud.com:4000/test')
import tidb_vector
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor

engine = create_engine("mysql+pymysql://[email protected]:4000/test")
Base = declarative_base()

class Test(Base):
__tablename__ = 'test'
id = Column(Integer, primary_key=True)
embedding = Column(VectorType(3))

# or add hnsw index when creating table
class TestWithIndex(Base):
__tablename__ = 'test_with_index'
# Define table schema
class Doc(Base):
__tablename__ = "doc"
id = Column(Integer, primary_key=True)
embedding = Column(VectorType(3), comment="hnsw(distance=l2)")

Base.metadata.create_all(engine)
```

Insert vector data

```python
test = Test(embedding=[1, 2, 3])
session.add(test)
session.commit()
```

Get the nearest neighbors
embedding = Column(VectorType(3)) # Vector with 3 dimensions
content = Column(Text)

```python
session.scalars(select(Test).order_by(Test.embedding.l2_distance([1, 2, 3.1])).limit(5))
```

Get the distance

```python
session.scalars(select(Test.embedding.l2_distance([1, 2, 3.1])))
```
# Create empty table
Base.metadata.drop_all(engine) # clean data from last run
Base.metadata.create_all(engine)

Get within a certain distance
# Create index using L2 distance
adaptor = VectorAdaptor(engine)
adaptor.create_vector_index(
Doc.embedding, tidb_vector.DistanceMetric.L2, skip_existing=True
)

```python
session.scalars(select(Test).filter(Test.embedding.l2_distance([1, 2, 3.1]) < 0.2))
# Insert content with vectors
with Session(engine) as session:
session.add(Doc(id=1, content="dog", embedding=[1, 2, 1]))
session.add(Doc(id=2, content="fish", embedding=[1, 2, 4]))
session.add(Doc(id=3, content="tree", embedding=[1, 0, 0]))
session.commit()

# Perform Vector Search for Top K=1
with Session(engine) as session:
results = session.execute(
select(Doc.id, Doc.content)
.order_by(Doc.embedding.cosine_distance([1, 2, 3]))
.limit(1)
).all()
print(results)

# Perform filtered Vector Search by adding a Where Clause:
with Session(engine) as session:
results = session.execute(
select(Doc.id, Doc.content)
.where(Doc.id > 2)
.order_by(Doc.embedding.cosine_distance([1, 2, 3]))
.limit(1)
).all()
print(results)
```
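
To see why the Top K=1 query returns the row it does, the ranking can be reproduced without a database. The following is a minimal sketch that assumes cosine distance means `1 - cosine similarity`; the helpers `cosine_distance` and `top_k` are hypothetical names used only for illustration and are not part of `tidb_vector`.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cos(angle between a and b) (assumed definition)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


# Same rows as the example above: (id, content, embedding)
docs = [
    (1, "dog", [1, 2, 1]),
    (2, "fish", [1, 2, 4]),
    (3, "tree", [1, 0, 0]),
]
query = [1, 2, 3]


def top_k(rows, query, k=1):
    """Order rows by cosine distance to the query and keep the first k."""
    ranked = sorted(rows, key=lambda row: cosine_distance(row[2], query))
    return [(row[0], row[1]) for row in ranked[:k]]


# Unfiltered search, then the same search restricted to id > 2
nearest = top_k(docs, query)
filtered = top_k([row for row in docs if row[0] > 2], query)
print(nearest)   # [(2, 'fish')]
print(filtered)  # [(3, 'tree')]
```

Under that assumption, `fish` is the nearest row overall, and `tree` is the only candidate left once the `id > 2` filter is applied.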

### Django
@@ -165,7 +168,7 @@ TestModel.select().where(TestModel.embedding.l2_distance([1, 2, 3.1]) < 0.5)

### TiDB Vector Client

Within the framework, you can directly utilize the built-in `TiDBVectorClient`, as demonstrated by integrations like [Langchain](https://python.langchain.com/docs/integrations/vectorstores/tidb_vector) and [Llama index](https://docs.llamaindex.ai/en/stable/community/integrations/vector_stores.html#using-a-vector-store-as-an-index), to seamlessly interact with TiDB Vector. This approach abstracts away the need to manage the underlying ORM, simplifying your interaction with the vector store.
Within the framework, you can directly utilize the built-in `TiDBVectorClient`, as demonstrated by integrations like [Langchain](https://python.langchain.com/docs/integrations/vectorstores/tidb_vector) and [Llama index](https://docs.llamaindex.ai/en/stable/community/integrations/vector_stores.html#using-a-vector-store-as-an-index), to seamlessly interact with TiDB Vector. This approach abstracts away the need to manage the underlying ORM, simplifying your interaction with the vector store.

`TiDBVectorClient` is built on SQLAlchemy. Install it with `pip install tidb-vector[client]`.

@@ -252,4 +255,5 @@ There are some examples to show how to use the tidb-vector-python to interact wi
for more examples, see the [examples](./examples) directory.

## Contributing
Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.

Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.
90 changes: 86 additions & 4 deletions tests/sqlalchemy/test_sqlalchemy.py
@@ -3,7 +3,8 @@
from sqlalchemy import URL, create_engine, Column, Integer, select
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.exc import OperationalError
from tidb_vector.sqlalchemy import VectorType
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor
import tidb_vector
from ..config import TestConfig


@@ -14,9 +15,11 @@
host=TestConfig.TIDB_HOST,
port=TestConfig.TIDB_PORT,
database="test",
query={"ssl_verify_cert": True, "ssl_verify_identity": True}
if TestConfig.TIDB_SSL
else {},
query=(
{"ssl_verify_cert": True, "ssl_verify_identity": True}
if TestConfig.TIDB_SSL
else {}
),
)

engine = create_engine(db_url)
@@ -58,6 +61,15 @@ def test_insert_get_record(self):
            assert np.array_equal(item1.embedding, np.array([1, 2, 3]))
            assert item1.embedding.dtype == np.float32

    def test_insert_get_record_np(self):
        with Session() as session:
            item1 = Item1Model(embedding=np.array([1, 2, 3]))
            session.add(item1)
            session.commit()
            item1 = session.query(Item1Model).first()
            assert np.array_equal(item1.embedding, np.array([1, 2, 3]))
            assert item1.embedding.dtype == np.float32

    def test_empty_vector(self):
        with Session() as session:
            item1 = Item1Model(embedding=[])
@@ -303,3 +315,73 @@ def test_negative_inner_product(self):
)
assert len(items) == 2
assert items[1].distance == -14.0


class TestSQLAlchemyAdaptor:
    def setup_method(self):
        Item1Model.__table__.drop(bind=engine, checkfirst=True)
        Item1Model.__table__.create(bind=engine)
        Item2Model.__table__.drop(bind=engine, checkfirst=True)
        Item2Model.__table__.create(bind=engine)

    def teardown_method(self):
        Item1Model.__table__.drop(bind=engine, checkfirst=True)
        Item2Model.__table__.drop(bind=engine, checkfirst=True)

    def test_create_index_on_dyn_vector(self):
        adaptor = VectorAdaptor(engine)
        with pytest.raises(ValueError):
            adaptor.create_vector_index(
                Item1Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
            )
        assert adaptor.has_vector_index(Item1Model.embedding) is False

    def test_create_index_on_fixed_vector(self):
        adaptor = VectorAdaptor(engine)
        adaptor.create_vector_index(
            Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
        )
        assert adaptor.has_vector_index(Item2Model.embedding) is True

        with pytest.raises(Exception):
            adaptor.create_vector_index(
                Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
            )

        assert adaptor.has_vector_index(Item2Model.embedding) is True

        adaptor.create_vector_index(
            Item2Model.embedding,
            distance_metric=tidb_vector.DistanceMetric.L2,
            skip_existing=True,
        )

        adaptor.create_vector_index(
            Item2Model.embedding,
            distance_metric=tidb_vector.DistanceMetric.COSINE,
            skip_existing=True,
        )

    def test_index_and_search(self):
        adaptor = VectorAdaptor(engine)
        adaptor.create_vector_index(
            Item2Model.embedding, distance_metric=tidb_vector.DistanceMetric.L2
        )
        assert adaptor.has_vector_index(Item2Model.embedding) is True

        with Session() as session:
            session.add_all(
                [Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])]
            )
            session.commit()

            # cosine distance (the index above was created with the L2 metric)
            distance = Item2Model.embedding.cosine_distance([1, 2, 3])
            items = (
                session.query(Item2Model.id, distance.label("distance"))
                .order_by(distance)
                .limit(5)
                .all()
            )
            assert len(items) == 2
            assert items[0].distance == 0.0
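
The index-creation semantics these tests assert can be summarized with a small in-memory model. This is only a sketch: `FakeVectorIndexRegistry` is a hypothetical stand-in and not the library's `VectorAdaptor`, it ignores the distance metric, and the `RuntimeError` is an assumption (the tests only require that some `Exception` is raised on a duplicate index).

```python
class FakeVectorIndexRegistry:
    """Hypothetical in-memory stand-in; NOT the library's VectorAdaptor."""

    def __init__(self):
        self._indexed = set()  # names of columns that already have an index

    def has_vector_index(self, column_name):
        return column_name in self._indexed

    def create_vector_index(self, column_name, dim, skip_existing=False):
        # A vector column without a fixed dimension cannot be indexed.
        if dim is None:
            raise ValueError("vector index requires a fixed dimension")
        if column_name in self._indexed:
            if skip_existing:
                return  # keep the existing index and do nothing
            # Assumed error type; the tests only check for Exception.
            raise RuntimeError("vector index already exists")
        self._indexed.add(column_name)


registry = FakeVectorIndexRegistry()

# Fixed-dimension column: the first call creates the index.
registry.create_vector_index("item2.embedding", dim=3)
print(registry.has_vector_index("item2.embedding"))

# A repeated call is a no-op only when skip_existing=True.
registry.create_vector_index("item2.embedding", dim=3, skip_existing=True)

# Dynamic-dimension column: creation is rejected and no index is recorded.
try:
    registry.create_vector_index("item1.embedding", dim=None)
except ValueError as exc:
    print("rejected:", exc)
print(registry.has_vector_index("item1.embedding"))
```

The model mirrors three behaviors from the tests: a dynamic-dimension column is rejected with `ValueError`, a second plain call on an indexed column fails, and `skip_existing=True` turns that second call into a no-op.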
5 changes: 4 additions & 1 deletion tidb_vector/__init__.py
@@ -1 +1,4 @@
__version__ = "0.0.12"
from .constants import MAX_DIM, MIN_DIM, DistanceMetric

__version__ = "0.0.13"
__all__ = ["MAX_DIM", "MIN_DIM", "DistanceMetric"]
19 changes: 17 additions & 2 deletions tidb_vector/constants.py
@@ -1,3 +1,18 @@
import enum

# TiDB Vector has a limitation on the dimension length
MAX_DIMENSION_LENGTH = 16000
MIN_DIMENSION_LENGTH = 1
MAX_DIM = 16000
MIN_DIM = 1


class DistanceMetric(enum.Enum):
    L2 = "L2"
    COSINE = "COSINE"

    def to_sql_func(self):
        if self == DistanceMetric.L2:
            return "VEC_L2_DISTANCE"
        elif self == DistanceMetric.COSINE:
            return "VEC_COSINE_DISTANCE"
        else:
            raise ValueError("unsupported distance metric")
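
The enum-to-function mapping can be exercised on its own. The sketch below copies `DistanceMetric` from this diff so that it runs standalone; `distance_expr` is a hypothetical helper added purely for illustration (it is not part of `tidb_vector`) and it formats SQL text naively, so it should not be used with untrusted input.

```python
import enum


class DistanceMetric(enum.Enum):
    L2 = "L2"
    COSINE = "COSINE"

    def to_sql_func(self):
        if self == DistanceMetric.L2:
            return "VEC_L2_DISTANCE"
        elif self == DistanceMetric.COSINE:
            return "VEC_COSINE_DISTANCE"
        else:
            raise ValueError("unsupported distance metric")


def distance_expr(metric, column, vector):
    """Hypothetical helper: render a SQL distance expression as text."""
    literal = "[" + ",".join(str(v) for v in vector) + "]"
    return f"{metric.to_sql_func()}({column}, '{literal}')"


print(DistanceMetric.L2.to_sql_func())
# VEC_L2_DISTANCE
print(distance_expr(DistanceMetric.COSINE, "embedding", [1, 2, 3]))
# VEC_COSINE_DISTANCE(embedding, '[1,2,3]')

# Enum members can also be looked up by their string value.
print(DistanceMetric("L2") is DistanceMetric.L2)
# True
```

Keeping the SQL function name behind `to_sql_func` means callers pass a typed enum member rather than a raw string, so an unsupported metric fails early instead of producing invalid SQL.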