-
Notifications
You must be signed in to change notification settings - Fork 161
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature: implemented rerank genai comp using predictionguard api
- Loading branch information
1 parent
7c0bf3d
commit a47976e
Showing
1 changed file
with
67 additions
and
0 deletions.
There are no files selected for viewing
67 changes: 67 additions & 0 deletions
67
comps/reranks/predictionguard/src/reranks_predictionguard.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (C) 2024 Prediction Guard, Inc. | ||
# SPDX-License-Identified: Apache-2.0 | ||
import logging | ||
import time | ||
|
||
from fastapi import FastAPI, HTTPException | ||
from fastapi.responses import StreamingResponse | ||
from predictionguard import PredictionGuard | ||
|
||
from comps import ( | ||
GeneratedDoc, | ||
LLMParamsDoc, | ||
ServiceType, | ||
opea_microservices, | ||
register_microservice, | ||
register_statistics, | ||
statistics_dict, SearchedDoc, RerankedDoc, TextDoc, | ||
) | ||
from comps.reranks.predictionguard.src.helpers import process_doc_list | ||
|
||
client = PredictionGuard() | ||
app = FastAPI() | ||
|
||
|
||
@register_microservice( | ||
name="opea_service@reranks_predictionguard", | ||
service_type=ServiceType.LLM, | ||
endpoint="/v1/reranking", | ||
host="0.0.0.0", | ||
port=9000, | ||
input_datatype=SearchedDoc, | ||
output_datatype=RerankedDoc, | ||
) | ||
@register_statistics(names=["opea_service@reranks_predictionguard"]) | ||
def reranks_generate(input: SearchedDoc) -> RerankedDoc: | ||
start = time.time() | ||
reranked_docs = [] | ||
|
||
if input.retrieved_docs: | ||
docs = process_doc_list(input.retrieved_docs) | ||
|
||
try: | ||
rerank_result = client.rerank.create( | ||
model="bge-reranker-v2-m3", | ||
query=input.initial_query, | ||
documents=docs, | ||
return_documents=True | ||
) | ||
|
||
# based on rerank_result, reorder the retrieved_docs to match the order of the retrieved_docs in the input | ||
reranked_docs = [TextDoc(id=input.retrieved_docs[doc["index"]].id, text=doc["text"]) for doc in rerank_result["results"]] | ||
|
||
|
||
|
||
except ValueError as e: | ||
logging.error(f"rerank failed with error: {e}. Inputs: query={input.initial_query}, documents={docs}") | ||
raise HTTPException(status_code=500, detail="An unexpected error occurred.") | ||
else: | ||
logging.info("reranking request input did not contain any documents") | ||
|
||
|
||
statistics_dict["opea_service@reranks_predictionguard"].append_latency(time.time() - start, None) | ||
return RerankedDoc(initial_query=input.initial_query, reranked_docs=reranked_docs) | ||
|
||
if __name__ == "__main__": | ||
opea_microservices["opea_service@reranks_predictionguard"].start() |