Skip to content

Commit

Permalink
Add post filter workaround logic for tidb vector search performance? (#…
Browse files Browse the repository at this point in the history
…46)

It's a workaround for TiDB Vector search execution plan limitation: the
where condition will cause the vector index disabled while searching.

No recommand to set post_filter_enabled and post_filter_multiplier, but
It can be used to workaround while encountering some performance issue.

---------

Co-authored-by: Ian Zhai <[email protected]>
Co-authored-by: WD <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2024
1 parent 1419d84 commit 9f824c2
Showing 1 changed file with 65 additions and 20 deletions.
85 changes: 65 additions & 20 deletions tidb_vector/integrations/vector_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,16 @@ def query(
k (int, optional): The number of results to return. Defaults to 5.
filter (dict, optional): A filter to apply to the search results.
Defaults to None.
post_filter_enabled (bool, optional): Whether to apply the post-filtering.
TiDB cannot utilize Vector Index when query contains a pre-filter.
post_filter_multiplier (int, optional): A multiplier to increase the initial
number of results fetched before applying the filter. Defaults to 1.
**kwargs: Additional keyword arguments.
Returns:
A list of tuples containing relevant documents and their similarity scores.
"""
relevant_docs = self._vector_search(query_vector, k, filter)
relevant_docs = self._vector_search(query_vector, k, filter, **kwargs)

return [
QueryResult(
Expand All @@ -296,26 +300,62 @@ def _vector_search(
query_embedding: List[float],
k: int = 5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Any]:
"""vector search from table."""

filter_by = self._build_filter_clause(filter)
post_filter_enabled = kwargs.get("post_filter_enabled", False)
post_filter_multiplier = kwargs.get("post_filter_multiplier", 1)
with Session(self._bind) as session:
results: List[Any] = (
session.query(
self._table_model.id,
self._table_model.meta,
self._table_model.document,
self.distance_strategy(query_embedding).label("distance"),
if post_filter_enabled is False or not filter:
filter_by = self._build_filter_clause(filter)
results: List[Any] = (
session.query(
self._table_model.id,
self._table_model.meta,
self._table_model.document,
self.distance_strategy(query_embedding).label("distance"),
)
.filter(filter_by)
.order_by(sqlalchemy.asc("distance"))
.limit(k)
.all()
)
else:
# Caused by the tidb vector search plan limited, this post_filter_multiplier is used to
# improved the search performance temporarily.
# Notice the return count may be less than k in this situation.
subquery = (
session.query(
self._table_model.id,
self._table_model.meta,
self._table_model.document,
self.distance_strategy(query_embedding).label("distance"),
)
.order_by(sqlalchemy.asc("distance"))
.limit(post_filter_multiplier * k * 10)
.subquery()
)
filter_by = self._build_filter_clause(filter, subquery.c)
results: List[Any] = (
session.query(
subquery.c.id,
subquery.c.meta,
subquery.c.document,
subquery.c.distance,
)
.filter(filter_by)
.order_by(sqlalchemy.asc(subquery.c.distance))
.limit(k)
.all()
)
.filter(filter_by)
.order_by(sqlalchemy.asc("distance"))
.limit(k)
.all()
)
return results

def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any:
def _build_filter_clause(
self,
filters: Optional[Dict[str, Any]] = None,
table_model: Optional[Any] = None,
) -> Any:
"""
Builds the filter clause for querying based on the provided filters.
Expand All @@ -326,22 +366,25 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any:
Any: The filter clause to be used in the query on TiDB.
"""

if table_model is None:
table_model = self._table_model

filter_by = sqlalchemy.true()
if filters is not None:
filter_clauses = []

for key, value in filters.items():
if key.lower() == "$and":
and_clauses = [
self._build_filter_clause(condition)
self._build_filter_clause(condition, table_model)
for condition in value
if isinstance(condition, dict) and condition is not None
]
filter_by_metadata = sqlalchemy.and_(*and_clauses)
filter_clauses.append(filter_by_metadata)
elif key.lower() == "$or":
or_clauses = [
self._build_filter_clause(condition)
self._build_filter_clause(condition, table_model)
for condition in value
if isinstance(condition, dict) and condition is not None
]
Expand All @@ -362,21 +405,23 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any:
f"Operator {key} must be followed by a meta key. "
)
elif isinstance(value, dict):
filter_by_metadata = self._create_filter_clause(key, value)
filter_by_metadata = self._create_filter_clause(
table_model, key, value
)

if filter_by_metadata is not None:
filter_clauses.append(filter_by_metadata)
else:
filter_by_metadata = (
sqlalchemy.func.json_extract(self._table_model.meta, f"$.{key}")
sqlalchemy.func.json_extract(table_model.meta, f"$.{key}")
== value
)
filter_clauses.append(filter_by_metadata)

filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
return filter_by

def _create_filter_clause(self, key, value):
def _create_filter_clause(self, table_model, key, value):
"""
Create a filter clause based on the provided key-value pair.
Expand All @@ -403,7 +448,7 @@ def _create_filter_clause(self, key, value):
"$ne",
)

json_key = sqlalchemy.func.json_extract(self._table_model.meta, f"$.{key}")
json_key = sqlalchemy.func.json_extract(table_model.meta, f"$.{key}")
value_case_insensitive = {k.lower(): v for k, v in value.items()}

if IN in map(str.lower, value):
Expand Down

0 comments on commit 9f824c2

Please sign in to comment.