diff --git a/proto/tei.proto b/proto/tei.proto
index aac6c2ba..170fbaf4 100644
--- a/proto/tei.proto
+++ b/proto/tei.proto
@@ -29,6 +29,11 @@ service Rerank {
     rpc RerankStream (stream RerankStreamRequest) returns (RerankResponse);
 }
 
+service Similarity {
+    rpc Similarity (SimilarityRequest) returns (SimilarityResponse);
+    rpc SimilarityStream (stream SimilarityStreamRequest) returns (SimilarityResponse);
+}
+
 service Tokenize {
     rpc Tokenize (EncodeRequest) returns (EncodeResponse);
     rpc TokenizeStream (stream EncodeRequest) returns (stream EncodeResponse);
@@ -175,6 +180,27 @@ message RerankResponse {
     Metadata metadata = 2;
 }
 
+message SimilarityRequest {
+    string source_sentence = 1;
+    repeated string sentences = 2;
+    bool truncate = 3;
+    TruncationDirection truncation_direction = 4;
+    optional string prompt_name = 5;
+}
+
+message SimilarityStreamRequest {
+    string source_sentence = 1;
+    string sentence = 2;
+    bool truncate = 3;
+    TruncationDirection truncation_direction = 4;
+    optional string prompt_name = 5;
+}
+
+message SimilarityResponse {
+    repeated float distances = 1;
+    Metadata metadata = 2;
+}
+
 message EncodeRequest {
     string inputs = 1;
     bool add_special_tokens = 2;
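For reference, a minimal sketch of calling the new unary RPC from Rust with a tonic-generated client. The `tei::v1` module path, the `similarity_client` module, and the endpoint URL are assumptions based on tonic's default codegen for this proto and a local deployment; adjust them to your build.

```rust
use tei::v1::similarity_client::SimilarityClient;
use tei::v1::SimilarityRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Endpoint is an assumption; point this at your TEI gRPC server.
    let mut client = SimilarityClient::connect("http://localhost:8080").await?;

    let response = client
        .similarity(SimilarityRequest {
            source_sentence: "What is Deep Learning?".to_string(),
            sentences: vec![
                "Deep Learning is a sub-field of machine learning.".to_string(),
                "The weather is nice today.".to_string(),
            ],
            truncate: true,
            truncation_direction: 0, // proto enum default
            prompt_name: None,
        })
        .await?
        .into_inner();

    // One score per entry in `sentences`, in the same order.
    println!("{:?}", response.distances);
    Ok(())
}
```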
diff --git a/router/src/grpc/mod.rs b/router/src/grpc/mod.rs
index d71b0c07..c02c74fa 100644
--- a/router/src/grpc/mod.rs
+++ b/router/src/grpc/mod.rs
@@ -3,5 +3,6 @@ pub(crate) mod server;
 
 use pb::tei::v1::{
     embed_server::EmbedServer, info_server::InfoServer, predict_server::PredictServer,
-    rerank_server::RerankServer, tokenize_server::TokenizeServer, *,
+    rerank_server::RerankServer, similarity_server::SimilarityServer,
+    tokenize_server::TokenizeServer, *,
 };
diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index d3666214..b47467ec 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -1,16 +1,18 @@
 use crate::grpc::pb::tei::v1::{
     EmbedAllRequest, EmbedAllResponse, EmbedSparseRequest, EmbedSparseResponse, EncodeRequest,
-    EncodeResponse, PredictPairRequest, RerankStreamRequest, SimpleToken, SparseValue,
-    TokenEmbedding, TruncationDirection,
+    EncodeResponse, PredictPairRequest, RerankStreamRequest, SimilarityStreamRequest, SimpleToken,
+    SparseValue, TokenEmbedding, TruncationDirection,
 };
 use crate::grpc::{
     DecodeRequest, DecodeResponse, EmbedRequest, EmbedResponse, InfoRequest, InfoResponse,
     PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
+    SimilarityRequest, SimilarityResponse,
 };
 use crate::ResponseMetadata;
 use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType};
 use futures::future::join_all;
 use metrics_exporter_prometheus::PrometheusBuilder;
+use simsimd::SpatialSimilarity;
 use std::future::Future;
 use std::net::SocketAddr;
 use std::time::{Duration, Instant};
@@ -1331,6 +1333,425 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService {
     }
 }
 
+#[tonic::async_trait]
+impl grpc::similarity_server::Similarity for TextEmbeddingsService {
+    #[instrument(
+        skip_all,
+        fields(
+            compute_chars,
+            compute_tokens,
+            total_time,
+            tokenization_time,
+            queue_time,
+            inference_time,
+        )
+    )]
+    async fn similarity(
+        &self,
+        request: Request<SimilarityRequest>,
+    ) -> Result<Response<SimilarityResponse>, Status> {
+        let span = Span::current();
+        let start_time = Instant::now();
+
+        let request = request.into_inner();
+
+        if request.sentences.is_empty() {
+            let message = "`sentences` cannot be empty".to_string();
+            tracing::error!("{message}");
+            let err = ErrorResponse {
+                error: message,
+                error_type: ErrorType::Validation,
+            };
+            let counter = metrics::counter!("te_request_failure", "err" => "validation");
+            counter.increment(1);
+            Err(err)?;
+        }
+
+        match &self.info.model_type {
+            ModelType::Embedding(_) => Ok(()),
+            _ => {
+                let counter = metrics::counter!("te_request_failure", "err" => "model_type");
+                counter.increment(1);
+                let message = "model is not an embedding model".to_string();
+                tracing::error!("{message}");
+                Err(Status::new(Code::FailedPrecondition, message))
+            }
+        }?;
+
+        // Closure for similarity
+        let similarity_inner = move |text: String,
+                                     truncate: bool,
+                                     truncation_direction: tokenizers::TruncationDirection,
+                                     prompt_name: Option<String>,
+                                     infer: Infer| async move {
+            let permit = infer.acquire_permit().await;
+
+            let response = infer
+                .embed_pooled(
+                    text,
+                    truncate,
+                    truncation_direction,
+                    prompt_name,
+                    false,
+                    permit,
+                )
+                .await
+                .map_err(ErrorResponse::from)?;
+
+            let embedding = response.results;
+
+            Ok::<(usize, Duration, Duration, Duration, Vec<f32>), ErrorResponse>((
+                response.metadata.prompt_tokens,
+                response.metadata.tokenization,
+                response.metadata.queue,
+                response.metadata.inference,
+                embedding,
+            ))
+        };
+
+        let counter = metrics::counter!("te_request_count", "method" => "batch");
+        counter.increment(1);
+
+        let batch_size = request.sentences.len() + 1;
+        if batch_size > self.info.max_client_batch_size {
+            let message = format!(
+                "batch size {batch_size} > maximum allowed batch size {}",
+                self.info.max_client_batch_size
+            );
+            tracing::error!("{message}");
+            let err = ErrorResponse {
+                error: message,
+                error_type: ErrorType::Validation,
+            };
+            let counter = metrics::counter!("te_request_failure", "err" => "batch_size");
+            counter.increment(1);
+            Err(err)?;
+        }
+
+        let mut futures = Vec::with_capacity(batch_size);
+
+        let mut total_compute_chars = 0;
+        let truncation_direction = convert_truncation_direction(request.truncation_direction);
+
+        {
+            total_compute_chars += request.source_sentence.chars().count();
+            let local_infer = self.infer.clone();
+            futures.push(similarity_inner(
+                request.source_sentence.clone(),
+                request.truncate,
+                truncation_direction,
+                request.prompt_name.clone(),
+                local_infer,
+            ));
+
+            for text in &request.sentences {
+                total_compute_chars += text.chars().count();
+                let local_infer = self.infer.clone();
+                futures.push(similarity_inner(
+                    text.clone(),
+                    request.truncate,
+                    truncation_direction,
+                    request.prompt_name.clone(),
+                    local_infer,
+                ));
+            }
+        }
+
+        let results = join_all(futures)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<(usize, Duration, Duration, Duration, Vec<f32>)>, ErrorResponse>>(
+            )?;
+
+        let mut embeddings = Vec::with_capacity(results.len());
+        let mut total_tokenization_time = 0;
+        let mut total_queue_time = 0;
+        let mut total_inference_time = 0;
+        let mut total_compute_tokens = 0;
+
+        for r in results.into_iter() {
+            total_compute_tokens += r.0;
+            total_tokenization_time += r.1.as_nanos() as u64;
+            total_queue_time += r.2.as_nanos() as u64;
+            total_inference_time += r.3.as_nanos() as u64;
+            embeddings.push(r.4);
+        }
+
+        let distances = (1..batch_size)
+            .map(|i| 1.0 - f32::cosine(&embeddings[0], &embeddings[i]).unwrap() as f32)
+            .collect();
+
+        let batch_size = batch_size as u64;
+
+        let counter = metrics::counter!("te_request_success", "method" => "batch");
+        counter.increment(1);
+
+        let response_metadata = ResponseMetadata::new(
+            total_compute_chars,
+            total_compute_tokens,
+            start_time,
+            Duration::from_nanos(total_tokenization_time / batch_size),
+            Duration::from_nanos(total_queue_time / batch_size),
+            Duration::from_nanos(total_inference_time / batch_size),
+        );
+        response_metadata.record_span(&span);
+        response_metadata.record_metrics();
+
+        let message = SimilarityResponse {
+            distances,
+            metadata: Some(grpc::Metadata::from(&response_metadata)),
+        };
+
+        let headers = HeaderMap::from(response_metadata);
+
+        tracing::info!("Success");
+
+        Ok(Response::from_parts(
+            MetadataMap::from_headers(headers),
+            message,
+            Extensions::default(),
+        ))
+    }
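The reported score comes from simsimd. Below is a minimal sketch of the conversion the handler performs, under the assumption that `SpatialSimilarity::cosine` returns the cosine distance as an `f64` (`None` on dimension mismatch); the toy vectors are illustrative only, real inputs come from `embed_pooled`.

```rust
use simsimd::SpatialSimilarity;

fn main() {
    // Two toy embeddings with matching dimensions.
    let source = vec![0.6f32, 0.8, 0.0];
    let candidate = vec![0.8f32, 0.6, 0.0];

    // simsimd's `cosine` computes the cosine distance (1 - similarity).
    let distance = f32::cosine(&source, &candidate).expect("same dimensions");

    // The handler reports `1.0 - distance` in the `distances` field of
    // SimilarityResponse, i.e. the cosine similarity (0.96 here).
    println!("reported value: {}", 1.0 - distance as f32);
}
```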
+
+    #[instrument(
+        skip_all,
+        fields(
+            compute_chars,
+            compute_tokens,
+            total_time,
+            tokenization_time,
+            queue_time,
+            inference_time,
+        )
+    )]
+    async fn similarity_stream(
+        &self,
+        request: Request<Streaming<SimilarityStreamRequest>>,
+    ) -> Result<Response<SimilarityResponse>, Status> {
+        let span = Span::current();
+        let start_time = Instant::now();
+
+        // Check model type
+        match &self.info.model_type {
+            ModelType::Embedding(_) => Ok(()),
+            _ => {
+                let counter = metrics::counter!("te_request_failure", "err" => "model_type");
+                counter.increment(1);
+                let message = "model is not an embedding model".to_string();
+                tracing::error!("{message}");
+                Err(Status::new(Code::FailedPrecondition, message))
+            }
+        }?;
+
+        // Closure for similarity
+        let similarity_inner = move |source_sentence: String,
+                                     sentence: String,
+                                     truncate: bool,
+                                     truncation_direction: tokenizers::TruncationDirection,
+                                     prompt_name: Option<String>,
+                                     infer: Infer| async move {
+            let source_response;
+
+            {
+                let permit = infer.acquire_permit().await;
+
+                source_response = infer
+                    .embed_pooled(
+                        source_sentence,
+                        truncate,
+                        truncation_direction,
+                        prompt_name.clone(),
+                        false,
+                        permit,
+                    )
+                    .await
+                    .map_err(ErrorResponse::from)?;
+            }
+
+            let response;
+
+            {
+                let permit = infer.acquire_permit().await;
+
+                response = infer
+                    .embed_pooled(
+                        sentence,
+                        truncate,
+                        truncation_direction,
+                        prompt_name,
+                        false,
+                        permit,
+                    )
+                    .await
+                    .map_err(ErrorResponse::from)?;
+            }
+
+            let distance =
+                1.0 - f32::cosine(&source_response.results, &response.results).unwrap() as f32;
+
+            Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>((
+                source_response.metadata.prompt_tokens + response.metadata.prompt_tokens,
+                source_response.metadata.tokenization + response.metadata.tokenization,
+                source_response.metadata.queue + response.metadata.queue,
+                source_response.metadata.inference + response.metadata.inference,
+                distance,
+            ))
+        };
+
+        let counter = metrics::counter!("te_request_count", "method" => "batch");
+        counter.increment(1);
+
+        let mut request_stream = request.into_inner();
+
+        // Create bounded channel to have an upper bound of spawned tasks
+        // We will have at most `max_parallel_stream_requests` messages from this stream in the queue
+        let (similarity_sender, mut similarity_receiver) =
+            mpsc::channel::<(
+                (
+                    String,
+                    String,
+                    bool,
+                    tokenizers::TruncationDirection,
+                    Option<String>,
+                ),
+                oneshot::Sender<Result<(usize, Duration, Duration, Duration, f32), ErrorResponse>>,
+            )>(self.max_parallel_stream_requests);
+
+        // Required for the async move below
+        let local_infer = self.infer.clone();
+
+        // Background task that uses the bounded channel
+        tokio::spawn(async move {
+            while let Some((
+                (source_sentence, sentence, truncate, truncation_direction, prompt_name),
+                mut sender,
+            )) = similarity_receiver.recv().await
+            {
+                // Required for the async move below
+                let task_infer = local_infer.clone();
+
+                // Create async task for this specific input
+                tokio::spawn(async move {
+                    // Select on closed to cancel work if the stream was closed
+                    tokio::select! {
+                        result = similarity_inner(source_sentence, sentence, truncate, truncation_direction, prompt_name, task_infer) => {
+                            let _ = sender.send(result);
+                        }
+                        _ = sender.closed() => {}
+                    }
+                });
+            }
+        });
+
+        let mut index = 0;
+        let mut total_compute_chars = 0;
+
+        // Intermediate channels
+        // Required to keep the order of the requests
+        let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel();
+
+        while let Some(request) = request_stream.next().await {
+            let request = request?;
+
+            // Create return channel
+            let (result_sender, result_receiver) = oneshot::channel();
+            // Push to intermediate channel and preserve ordering
+            intermediate_sender
+                .send(result_receiver)
+                .expect("`intermediate_receiver` was dropped. This is a bug.");
+
+            total_compute_chars += request.source_sentence.chars().count();
+            total_compute_chars += request.sentence.chars().count();
+
+            let truncation_direction = convert_truncation_direction(request.truncation_direction);
+
+            similarity_sender
+                .send((
+                    (
+                        request.source_sentence,
+                        request.sentence,
+                        request.truncate,
+                        truncation_direction,
+                        request.prompt_name,
+                    ),
+                    result_sender,
+                ))
+                .await
+                .expect("`similarity_receiver` was dropped. This is a bug.");
+
+            index += 1;
+        }
+
+        // Drop the sender to signal to the underlying task that we are done
+        drop(similarity_sender);
+
+        let batch_size = index;
+
+        let mut distances: Vec<f32> = Vec::with_capacity(batch_size);
+        let mut total_tokenization_time = 0;
+        let mut total_queue_time = 0;
+        let mut total_inference_time = 0;
+        let mut total_compute_tokens = 0;
+
+        // Iterate on result stream
+        while let Some(result_receiver) = intermediate_receiver.recv().await {
+            let r = result_receiver
+                .await
+                .expect("`result_sender` was dropped. This is a bug.")?;
+
+            total_compute_tokens += r.0;
+            total_tokenization_time += r.1.as_nanos() as u64;
+            total_queue_time += r.2.as_nanos() as u64;
+            total_inference_time += r.3.as_nanos() as u64;
+
+            distances.push(r.4)
+        }
+
+        if distances.len() < batch_size {
+            let message = "similarity results are missing values".to_string();
+            tracing::error!("{message}");
+            let err = ErrorResponse {
+                error: message,
+                error_type: ErrorType::Backend,
+            };
+            let counter = metrics::counter!("te_request_failure", "err" => "missing_values");
+            counter.increment(1);
+            Err(err)?;
+        }
+
+        let batch_size = batch_size as u64;
+
+        let counter = metrics::counter!("te_request_success", "method" => "batch");
+        counter.increment(1);
+
+        let response_metadata = ResponseMetadata::new(
+            total_compute_chars,
+            total_compute_tokens,
+            start_time,
+            Duration::from_nanos(total_tokenization_time / batch_size),
+            Duration::from_nanos(total_queue_time / batch_size),
+            Duration::from_nanos(total_inference_time / batch_size),
+        );
+        response_metadata.record_span(&span);
+        response_metadata.record_metrics();
+
+        let message = SimilarityResponse {
+            distances,
+            metadata: Some(grpc::Metadata::from(&response_metadata)),
+        };
+
+        let headers = HeaderMap::from(response_metadata);
+
+        tracing::info!("Success");
+
+        Ok(Response::from_parts(
+            MetadataMap::from_headers(headers),
+            message,
+            Extensions::default(),
+        ))
+    }
+}
+
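The streaming variant accepts one (source, candidate) pair per message and answers with a single response containing one score per pair, in send order (the server preserves ordering via its intermediate oneshot channels). A hedged client sketch, with the same caveats on module paths and endpoint as above; `tokio_stream::iter` satisfies tonic's `IntoStreamingRequest`:

```rust
use tei::v1::similarity_client::SimilarityClient;
use tei::v1::SimilarityStreamRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut client = SimilarityClient::connect("http://localhost:8080").await?;

    // One message per pair; distances[i] matches pairs[i].
    let pairs = vec![
        SimilarityStreamRequest {
            source_sentence: "What is Deep Learning?".to_string(),
            sentence: "Deep Learning is a sub-field of machine learning.".to_string(),
            truncate: true,
            truncation_direction: 0, // proto enum default
            prompt_name: None,
        },
        SimilarityStreamRequest {
            source_sentence: "What is Deep Learning?".to_string(),
            sentence: "The weather is nice today.".to_string(),
            truncate: true,
            truncation_direction: 0,
            prompt_name: None,
        },
    ];

    let response = client
        .similarity_stream(tokio_stream::iter(pairs))
        .await?
        .into_inner();

    println!("{:?}", response.distances);
    Ok(())
}
```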
 #[tonic::async_trait]
 impl grpc::tokenize_server::Tokenize for TextEmbeddingsService {
     async fn tokenize(
@@ -1406,6 +1827,9 @@ pub async fn run(
     health_reporter
         .set_not_serving::<grpc::RerankServer<TextEmbeddingsService>>()
         .await;
+    health_reporter
+        .set_not_serving::<grpc::SimilarityServer<TextEmbeddingsService>>()
+        .await;
     health_reporter
         .set_not_serving::<grpc::TokenizeServer<TextEmbeddingsService>>()
         .await;
@@ -1450,7 +1874,13 @@ pub async fn run(
                             <grpc::EmbedServer<TextEmbeddingsService> as NamedService>::NAME,
                             status,
                         )
-                        .await
+                        .await;
+                    health_reporter
+                        .set_service_status(
+                            <grpc::SimilarityServer<TextEmbeddingsService> as NamedService>::NAME,
+                            status,
+                        )
+                        .await;
                 }
                 ModelType::Reranker(_) => {
                     // Reranker has both a predict and rerank service
@@ -1505,7 +1935,8 @@ pub async fn run(
             ))
             .add_service(grpc::EmbedServer::with_interceptor(service.clone(), auth))
             .add_service(grpc::PredictServer::with_interceptor(service.clone(), auth))
-            .add_service(grpc::RerankServer::with_interceptor(service, auth))
+            .add_service(grpc::RerankServer::with_interceptor(service.clone(), auth))
+            .add_service(grpc::SimilarityServer::with_interceptor(service, auth))
             .serve_with_shutdown(addr, shutdown::shutdown_signal())
     } else {
         Server::builder()
@@ -1515,7 +1946,8 @@ pub async fn run(
             .add_service(grpc::TokenizeServer::new(service.clone()))
             .add_service(grpc::EmbedServer::new(service.clone()))
             .add_service(grpc::PredictServer::new(service.clone()))
-            .add_service(grpc::RerankServer::new(service))
+            .add_service(grpc::RerankServer::new(service.clone()))
+            .add_service(grpc::SimilarityServer::new(service))
             .serve_with_shutdown(addr, shutdown::shutdown_signal())
     };
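Since the new service is wired into the tonic-health reporter above, readiness can be probed over the standard gRPC health protocol. A minimal sketch using tonic-health's generated client; the service name is an assumption derived from `SimilarityServer`'s `NamedService` impl for the `tei.v1` package, and the endpoint is again illustrative:

```rust
use tonic_health::pb::health_client::HealthClient;
use tonic_health::pb::HealthCheckRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut client = HealthClient::connect("http://localhost:8080").await?;

    let response = client
        .check(HealthCheckRequest {
            // Assumed NamedService::NAME for SimilarityServer in package tei.v1.
            service: "tei.v1.Similarity".to_string(),
        })
        .await?
        .into_inner();

    // SERVING only once the background task flips the status for
    // ModelType::Embedding models; NOT_SERVING otherwise.
    println!("{:?}", response.status());
    Ok(())
}
```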