From 59b6909a0beffd7e94975fa3cbd584616412ba29 Mon Sep 17 00:00:00 2001 From: Timothy Olaleke Date: Sun, 19 Jan 2025 15:01:40 +0000 Subject: [PATCH] feat: weaviate multitenancy support --- .../retrieval_models_clients/WeaviateRM.md | 14 +++++++++++ dspy/retrieve/weaviate_rm.py | 25 +++++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/docs/docs/deep-dive/retrieval_models_clients/WeaviateRM.md b/docs/docs/deep-dive/retrieval_models_clients/WeaviateRM.md index ef9049cee..5b7c4c958 100644 --- a/docs/docs/deep-dive/retrieval_models_clients/WeaviateRM.md +++ b/docs/docs/deep-dive/retrieval_models_clients/WeaviateRM.md @@ -19,6 +19,20 @@ WeaviateRM( ) ``` +## Using Multitenancy +If your Weaviate instance is tenant-aware, you can provide a `tenant_id` either in the `WeaviateRM` constructor or as a keyword argument when calling the retriever: + +```python +retriever_model = WeaviateRM( + weaviate_collection_name="", + weaviate_client=weaviate_client, + tenant_id="tenant123" +) + +results = retriever_model("Your query here", tenant_id="tenantXYZ") +``` +When `tenant_id` is specified, all retrieval requests are scoped to that tenant. A `tenant_id` passed at call time takes precedence over the one given to the constructor. 
+ ## Under the Hood `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None, **kwargs) -> dspy.Prediction` diff --git a/dspy/retrieve/weaviate_rm.py b/dspy/retrieve/weaviate_rm.py index c9e1f6533..a471b6831 100644 --- a/dspy/retrieve/weaviate_rm.py +++ b/dspy/retrieve/weaviate_rm.py @@ -51,11 +51,13 @@ def __init__( weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], weaviate_collection_text_key: Optional[str] = "content", k: int = 3, + tenant_id: Optional[str] = None, ): self._weaviate_collection_name = weaviate_collection_name self._weaviate_client = weaviate_client self._weaviate_collection = self._weaviate_client.collections.get(self._weaviate_collection_name) self._weaviate_collection_text_key = weaviate_collection_text_key + self._tenant_id = tenant_id # Check the type of weaviate_client (this is added to support v3 and v4) if hasattr(weaviate_client, "collections"): @@ -82,26 +84,23 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [q for q in queries if q] passages, parsed_results = [], [] + tenant = kwargs.pop("tenant_id", self._tenant_id) for query in queries: if self._client_type == "WeaviateClient": - results = self._weaviate_collection.query.hybrid( - query=query, - limit=k, - **kwargs, - ) + if tenant: + results = self._weaviate_collection.query.tenant(tenant).hybrid(query=query, limit=k, **kwargs) + else: + results = self._weaviate_collection.query.hybrid(query=query, limit=k, **kwargs) parsed_results = [result.properties[self._weaviate_collection_text_key] for result in results.objects] elif self._client_type == "Client": - results = ( - self._weaviate_client.query.get( - self._weaviate_collection_name, - [self._weaviate_collection_text_key], + q = self._weaviate_client.query.get( + self._weaviate_collection_name, [self._weaviate_collection_text_key] ) - .with_hybrid(query=query) - 
.with_limit(k) - .do() - ) + if tenant: + q = q.with_tenant(tenant) + results = q.with_hybrid(query=query).with_limit(k).do() results = results["data"]["Get"][self._weaviate_collection_name] parsed_results = [result[self._weaviate_collection_text_key] for result in results]