Commit

Merge branch 'main' into feature/gte-classifier
kozistr authored Nov 26, 2024
2 parents 63c964b + 7c4f67e commit efdffb0
Showing 8 changed files with 1,081 additions and 734 deletions.
1,551 changes: 917 additions & 634 deletions Cargo.lock

Large diffs are not rendered by default.

37 changes: 33 additions & 4 deletions Dockerfile
@@ -28,9 +28,22 @@ ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
intel-oneapi-mkl-devel=2024.0.0-49656 \
build-essential \
&& rm -rf /var/lib/apt/lists/*

RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
gcc -shared -fPIC -o libfakeintel.so fakeintel.c

COPY --from=planner /usr/src/recipe.json recipe.json

RUN cargo chef cook --release --features ort --no-default-features --recipe-path recipe.json && sccache -s
RUN cargo chef cook --release --features ort --features candle --features mkl-dynamic --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
@@ -40,7 +53,7 @@ COPY Cargo.lock ./

FROM builder AS http-builder

RUN cargo build --release --bin text-embeddings-router -F ort -F http --no-default-features && sccache -s
RUN cargo build --release --bin text-embeddings-router -F ort -F candle -F mkl-dynamic -F http --no-default-features && sccache -s

FROM builder AS grpc-builder

@@ -52,19 +65,35 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \

COPY proto proto

RUN cargo build --release --bin text-embeddings-router -F grpc -F ort --no-default-features && sccache -s
RUN cargo build --release --bin text-embeddings-router -F grpc -F ort -F candle -F mkl-dynamic --no-default-features && sccache -s

FROM debian:bookworm-slim AS base

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80
PORT=80 \
MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
RAYON_NUM_THREADS=8 \
LD_PRELOAD=/usr/local/libfakeintel.so \
LD_LIBRARY_PATH=/usr/local/lib

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libomp-dev \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*

# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so

FROM base AS grpc

4 changes: 2 additions & 2 deletions backends/candle/src/lib.rs
@@ -25,7 +25,7 @@ use candle_nn::VarBuilder;
use nohash_hasher::BuildNoHashHasher;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::path::Path;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};
@@ -69,7 +69,7 @@ pub struct CandleBackend {

impl CandleBackend {
pub fn new(
model_path: PathBuf,
model_path: &Path,
dtype: String,
model_type: ModelType,
) -> Result<Self, BackendError> {
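A note on the signature change above (not part of the diff): taking `&Path` instead of `PathBuf` means the constructor only borrows the model directory, so callers keep ownership of their path. A minimal sketch of the calling convention, with a hypothetical `load_model` standing in for `CandleBackend::new`:

use std::path::{Path, PathBuf};

// Hypothetical stand-in for a constructor that now borrows its model path.
fn load_model(model_path: &Path) {
    println!("loading weights from {}", model_path.display());
}

fn main() {
    let model_dir: PathBuf = PathBuf::from("/data/my-model");
    load_model(&model_dir);                  // &PathBuf coerces to &Path via Deref
    load_model(Path::new("/data/my-model")); // no PathBuf allocation needed
    // model_dir is still owned by the caller after both calls.
    assert!(model_dir.ends_with("my-model"));
}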
2 changes: 2 additions & 0 deletions backends/core/src/lib.rs
@@ -88,4 +88,6 @@ pub enum BackendError {
Inference(String),
#[error("Backend is unhealthy")]
Unhealthy,
#[error("Weights not found: {0}")]
WeightsNotFound(String),
}
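For context (not from the diff): this new variant is what the download path in `backends/src/lib.rs` maps hf-hub failures into. A minimal sketch of that mapping, using a local stand-in for the enum and a hypothetical `fetch_weights` in place of the real async download helpers:

use std::path::PathBuf;

// Local stand-in for the enum above (the real one derives thiserror::Error).
#[derive(Debug)]
enum BackendError {
    WeightsNotFound(String),
}

// Hypothetical downloader; the real code awaits hf-hub's tokio API instead.
fn fetch_weights(repo_id: &str) -> Result<PathBuf, std::io::Error> {
    Err(std::io::Error::new(
        std::io::ErrorKind::NotFound,
        format!("model.onnx not found in {repo_id}"),
    ))
}

fn load(repo_id: &str) -> Result<PathBuf, BackendError> {
    // Same pattern as the diff: convert the download error into WeightsNotFound.
    fetch_weights(repo_id).map_err(|err| BackendError::WeightsNotFound(err.to_string()))
}

fn main() {
    println!("{:?}", load("org/model"));
}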
11 changes: 6 additions & 5 deletions backends/ort/src/lib.rs
@@ -1,9 +1,9 @@
use ndarray::{s, Axis};
use nohash_hasher::BuildNoHashHasher;
use ort::{GraphOptimizationLevel, Session};
use ort::session::{builder::GraphOptimizationLevel, Session};
use std::collections::HashMap;
use std::ops::{Div, Mul};
use std::path::PathBuf;
use std::path::Path;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
@@ -16,12 +16,12 @@ pub struct OrtBackend {

impl OrtBackend {
pub fn new(
model_path: PathBuf,
model_path: &Path,
dtype: String,
model_type: ModelType,
) -> Result<Self, BackendError> {
// Check dtype
if &dtype == "float32" {
if dtype == "float32" {
} else {
return Err(BackendError::Start(format!(
"DType {dtype} is not supported"
@@ -246,7 +246,8 @@ impl Backend for OrtBackend {
if has_raw_requests {
// Reshape outputs
let s = outputs.shape().to_vec();
let outputs = outputs.into_shape_with_order((s[0] * s[1], s[2])).e()?;
#[allow(deprecated)]
let outputs = outputs.into_shape((s[0] * s[1], s[2])).e()?;

// We need to remove the padding tokens only if batch_size > 1 and there are some
// members of the batch that require pooling
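An aside on the reshape in the last hunk (not from the diff): the model output, presumably shaped (batch, sequence, hidden), is flattened to (batch * sequence, hidden) so padding rows can be filtered out afterwards. A small sketch with ndarray, assuming a version where `into_shape` is still available (hence the `#[allow(deprecated)]`):

use ndarray::{Array2, Array3};

// Flatten (batch, seq_len, hidden) into one row per token.
#[allow(deprecated)]
fn flatten_tokens(outputs: Array3<f32>) -> Array2<f32> {
    let s = outputs.shape().to_vec();
    outputs
        .into_shape((s[0] * s[1], s[2]))
        .expect("same element count and contiguous layout, so the reshape succeeds")
}

fn main() {
    let outputs = Array3::<f32>::zeros((2, 4, 8)); // 2 sequences, 4 tokens, hidden size 8
    assert_eq!(flatten_tokens(outputs).dim(), (8, 8));
}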
145 changes: 89 additions & 56 deletions backends/src/lib.rs
@@ -37,8 +37,9 @@ pub struct Backend {
}

impl Backend {
pub fn new(
pub async fn new(
model_path: PathBuf,
api_repo: Option<ApiRepo>,
dtype: DType,
model_type: ModelType,
uds_path: String,
@@ -49,12 +50,14 @@ impl Backend {

let backend = init_backend(
model_path,
api_repo,
dtype,
model_type.clone(),
uds_path,
otlp_endpoint,
otlp_service_name,
)?;
)
.await?;
let padded_model = backend.is_padded();
let max_batch_size = backend.max_batch_size();

@@ -193,48 +196,102 @@ impl Backend {
}

#[allow(unused)]
fn init_backend(
async fn init_backend(
model_path: PathBuf,
api_repo: Option<ApiRepo>,
dtype: DType,
model_type: ModelType,
uds_path: String,
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
let mut backend_start_failed = false;

if cfg!(feature = "ort") {
#[cfg(feature = "ort")]
{
if let Some(api_repo) = api_repo.as_ref() {
let start = std::time::Instant::now();
download_onnx(api_repo)
.await
.map_err(|err| BackendError::WeightsNotFound(err.to_string()));
tracing::info!("Model ONNX weights downloaded in {:?}", start.elapsed());
}

let backend = OrtBackend::new(&model_path, dtype.to_string(), model_type.clone());
match backend {
Ok(b) => return Ok(Box::new(b)),
Err(err) => {
tracing::error!("Could not start ORT backend: {err}");
backend_start_failed = true;
}
}
}
}

if let Some(api_repo) = api_repo.as_ref() {
if cfg!(feature = "python") || cfg!(feature = "candle") {
let start = std::time::Instant::now();
if download_safetensors(api_repo).await.is_err() {
tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
tracing::info!("Downloading `pytorch_model.bin`");
api_repo
.get("pytorch_model.bin")
.await
.map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
}

tracing::info!("Model weights downloaded in {:?}", start.elapsed());
}
}

if cfg!(feature = "candle") {
#[cfg(feature = "candle")]
return Ok(Box::new(CandleBackend::new(
model_path,
dtype.to_string(),
model_type,
)?));
} else if cfg!(feature = "python") {
{
let backend = CandleBackend::new(&model_path, dtype.to_string(), model_type.clone());
match backend {
Ok(b) => return Ok(Box::new(b)),
Err(err) => {
tracing::error!("Could not start Candle backend: {err}");
backend_start_failed = true;
}
}
}
}

if cfg!(feature = "python") {
#[cfg(feature = "python")]
{
return Ok(Box::new(
std::thread::spawn(move || {
PythonBackend::new(
model_path.to_str().unwrap().to_string(),
dtype.to_string(),
model_type,
uds_path,
otlp_endpoint,
otlp_service_name,
)
})
.join()
.expect("Python Backend management thread failed")?,
));
let backend = std::thread::spawn(move || {
PythonBackend::new(
model_path.to_str().unwrap().to_string(),
dtype.to_string(),
model_type,
uds_path,
otlp_endpoint,
otlp_service_name,
)
})
.join()
.expect("Python Backend management thread failed");

match backend {
Ok(b) => return Ok(Box::new(b)),
Err(err) => {
tracing::error!("Could not start Python backend: {err}");
backend_start_failed = true;
}
}
}
} else if cfg!(feature = "ort") {
#[cfg(feature = "ort")]
return Ok(Box::new(OrtBackend::new(
model_path,
dtype.to_string(),
model_type,
)?));
}
Err(BackendError::NoBackend)

if backend_start_failed {
Err(BackendError::Start(
"Could not start a suitable backend".to_string(),
))
} else {
Err(BackendError::NoBackend)
}
}

#[derive(Debug)]
@@ -298,31 +355,6 @@ enum BackendCommand {
),
}

pub async fn download_weights(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
let model_files = if cfg!(feature = "python") || cfg!(feature = "candle") {
match download_safetensors(api).await {
Ok(p) => p,
Err(_) => {
tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
tracing::info!("Downloading `pytorch_model.bin`");
let p = api.get("pytorch_model.bin").await?;
vec![p]
}
}
} else if cfg!(feature = "ort") {
match download_onnx(api).await {
Ok(p) => p,
Err(err) => {
panic!("failed to download `model.onnx` or `model.onnx_data`. Check the onnx file exists in the repository. {err}");
}
}
} else {
unreachable!()
};

Ok(model_files)
}

async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
// Single file
tracing::info!("Downloading `model.safetensors`");
@@ -362,6 +394,7 @@ async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
Ok(safetensors_files)
}

#[cfg(feature = "ort")]
async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
let mut model_files: Vec<PathBuf> = Vec::new();

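To summarize the control-flow change above (a sketch, not the real code): `init_backend` is now async, may download ONNX or safetensors weights first when an `ApiRepo` is provided, and tries the compiled-in backends in order (ORT, then Candle, then Python) instead of returning the first one's error. It also distinguishes a backend that was compiled in but failed to start from the case where no backend was compiled in at all. A minimal, feature-flag-free sketch of that fallback pattern, with stand-in constructors:

// Stand-ins for the real constructors; here ORT fails and Candle succeeds.
fn try_ort() -> Result<&'static str, String> {
    Err("model.onnx not found".to_string())
}
fn try_candle() -> Result<&'static str, String> {
    Ok("candle")
}

#[derive(Debug)]
enum StartError {
    Start(String),
    NoBackend,
}

fn init_backend(ort_enabled: bool, candle_enabled: bool) -> Result<&'static str, StartError> {
    let mut backend_start_failed = false;

    if ort_enabled {
        match try_ort() {
            Ok(b) => return Ok(b),
            Err(err) => {
                eprintln!("Could not start ORT backend: {err}");
                backend_start_failed = true;
            }
        }
    }

    if candle_enabled {
        match try_candle() {
            Ok(b) => return Ok(b),
            Err(err) => {
                eprintln!("Could not start Candle backend: {err}");
                backend_start_failed = true;
            }
        }
    }

    if backend_start_failed {
        Err(StartError::Start("Could not start a suitable backend".to_string()))
    } else {
        Err(StartError::NoBackend)
    }
}

fn main() {
    println!("{:?}", init_backend(true, true)); // prints Ok("candle")
}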
27 changes: 21 additions & 6 deletions core/src/download.rs
@@ -1,6 +1,5 @@
use hf_hub::api::tokio::{ApiError, ApiRepo};
use std::path::PathBuf;
use text_embeddings_backend::download_weights;
use tracing::instrument;

// Old classes used other config names than 'sentence_bert_config.json'
@@ -15,20 +14,36 @@ pub const ST_CONFIG_NAMES: [&str; 7] = [
];

#[instrument(skip_all)]
pub async fn download_artifacts(api: &ApiRepo) -> Result<PathBuf, ApiError> {
pub async fn download_artifacts(api: &ApiRepo, pool_config: bool) -> Result<PathBuf, ApiError> {
let start = std::time::Instant::now();

tracing::info!("Starting download");

// Optionally download the pooling config.
if pool_config {
// If a pooling config exists, download it
let _ = download_pool_config(api).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});
}

// Download legacy sentence transformers config
// We don't warn on failure as it is a legacy file
let _ = download_st_config(api).await;
// Download new sentence transformers config
let _ = download_new_st_config(api).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});

tracing::info!("Downloading `config.json`");
api.get("config.json").await?;

tracing::info!("Downloading `tokenizer.json`");
api.get("tokenizer.json").await?;

let model_files = download_weights(api).await?;
let model_root = model_files[0].parent().unwrap().to_path_buf();
let tokenizer_path = api.get("tokenizer.json").await?;

let model_root = tokenizer_path.parent().unwrap().to_path_buf();
tracing::info!("Model artifacts downloaded in {:?}", start.elapsed());
Ok(model_root)
}
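For context (not part of the diff): weight downloads now live in the backend crate, so `download_artifacts` only fetches the configs and the tokenizer, derives the model root from the tokenizer path, and takes a flag for the optional pooling config. A hypothetical call site, assuming this module's `download_artifacts` and an hf-hub tokio `ApiRepo` named `api` are in scope:

use hf_hub::api::tokio::{ApiError, ApiRepo};
use std::path::PathBuf;

// Hypothetical wrapper around the function changed above.
async fn prepare(api: &ApiRepo) -> Result<PathBuf, ApiError> {
    // Pass `true` to also fetch the pooling config, `false` when it is not needed.
    let model_root = download_artifacts(api, true).await?;
    Ok(model_root)
}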