diff --git a/tidb_vector/integrations/vector_client.py b/tidb_vector/integrations/vector_client.py index a648fad..dd44f17 100644 --- a/tidb_vector/integrations/vector_client.py +++ b/tidb_vector/integrations/vector_client.py @@ -274,12 +274,16 @@ def query( k (int, optional): The number of results to return. Defaults to 5. filter (dict, optional): A filter to apply to the search results. Defaults to None. + post_filter_enabled (bool, optional): Whether to apply the post-filtering. + TiDB cannot utilize Vector Index when query contains a pre-filter. + post_filter_multiplier (int, optional): A multiplier to increase the initial + number of results fetched before applying the filter. Defaults to 1. **kwargs: Additional keyword arguments. Returns: A list of tuples containing relevant documents and their similarity scores. """ - relevant_docs = self._vector_search(query_vector, k, filter) + relevant_docs = self._vector_search(query_vector, k, filter, **kwargs) return [ QueryResult( @@ -296,26 +300,62 @@ def _vector_search( query_embedding: List[float], k: int = 5, filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Any]: """vector search from table.""" - filter_by = self._build_filter_clause(filter) + post_filter_enabled = kwargs.get("post_filter_enabled", False) + post_filter_multiplier = kwargs.get("post_filter_multiplier", 1) with Session(self._bind) as session: - results: List[Any] = ( - session.query( - self._table_model.id, - self._table_model.meta, - self._table_model.document, - self.distance_strategy(query_embedding).label("distance"), + if post_filter_enabled is False or not filter: + filter_by = self._build_filter_clause(filter) + results: List[Any] = ( + session.query( + self._table_model.id, + self._table_model.meta, + self._table_model.document, + self.distance_strategy(query_embedding).label("distance"), + ) + .filter(filter_by) + .order_by(sqlalchemy.asc("distance")) + .limit(k) + .all() + ) + else: + # Caused by the tidb vector search plan limited, this post_filter_multiplier is used to + # improved the search performance temporarily. + # Notice the return count may be less than k in this situation. + subquery = ( + session.query( + self._table_model.id, + self._table_model.meta, + self._table_model.document, + self.distance_strategy(query_embedding).label("distance"), + ) + .order_by(sqlalchemy.asc("distance")) + .limit(post_filter_multiplier * k * 10) + .subquery() + ) + filter_by = self._build_filter_clause(filter, subquery.c) + results: List[Any] = ( + session.query( + subquery.c.id, + subquery.c.meta, + subquery.c.document, + subquery.c.distance, + ) + .filter(filter_by) + .order_by(sqlalchemy.asc(subquery.c.distance)) + .limit(k) + .all() ) - .filter(filter_by) - .order_by(sqlalchemy.asc("distance")) - .limit(k) - .all() - ) return results - def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any: + def _build_filter_clause( + self, + filters: Optional[Dict[str, Any]] = None, + table_model: Optional[Any] = None, + ) -> Any: """ Builds the filter clause for querying based on the provided filters. @@ -326,6 +366,9 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any: Any: The filter clause to be used in the query on TiDB. """ + if table_model is None: + table_model = self._table_model + filter_by = sqlalchemy.true() if filters is not None: filter_clauses = [] @@ -333,7 +376,7 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any: for key, value in filters.items(): if key.lower() == "$and": and_clauses = [ - self._build_filter_clause(condition) + self._build_filter_clause(condition, table_model) for condition in value if isinstance(condition, dict) and condition is not None ] @@ -341,7 +384,7 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any: filter_clauses.append(filter_by_metadata) elif key.lower() == "$or": or_clauses = [ - self._build_filter_clause(condition) + self._build_filter_clause(condition, table_model) for condition in value if isinstance(condition, dict) and condition is not None ] @@ -362,13 +405,15 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any: f"Operator {key} must be followed by a meta key. " ) elif isinstance(value, dict): - filter_by_metadata = self._create_filter_clause(key, value) + filter_by_metadata = self._create_filter_clause( + table_model, key, value + ) if filter_by_metadata is not None: filter_clauses.append(filter_by_metadata) else: filter_by_metadata = ( - sqlalchemy.func.json_extract(self._table_model.meta, f"$.{key}") + sqlalchemy.func.json_extract(table_model.meta, f"$.{key}") == value ) filter_clauses.append(filter_by_metadata) @@ -376,7 +421,7 @@ def _build_filter_clause(self, filters: Optional[Dict[str, Any]] = None) -> Any: filter_by = sqlalchemy.and_(filter_by, *filter_clauses) return filter_by - def _create_filter_clause(self, key, value): + def _create_filter_clause(self, table_model, key, value): """ Create a filter clause based on the provided key-value pair. @@ -403,7 +448,7 @@ def _create_filter_clause(self, key, value): "$ne", ) - json_key = sqlalchemy.func.json_extract(self._table_model.meta, f"$.{key}") + json_key = sqlalchemy.func.json_extract(table_model.meta, f"$.{key}") value_case_insensitive = {k.lower(): v for k, v in value.items()} if IN in map(str.lower, value):