diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs
index 2a04a256..016e5937 100644
--- a/backends/candle/src/models/gte.rs
+++ b/backends/candle/src/models/gte.rs
@@ -1,6 +1,6 @@
 use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
-use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType};
-use candle::{Device, IndexOp, Result, Tensor, D};
+use crate::models::{apply_rotary, inv_freqs, Model, PositionEmbeddingType};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -108,12 +108,6 @@ impl GTEAttention {
         if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
             #[cfg(feature = "cuda")]
             {
-                // cuBLASLt batch matmul implementation requires inputs to be dims3
-                let (batch_size, _, seq_len, _) = k.shape().dims4()?;
-                let k = k.flatten(0, 1)?;
-                let q = q.flatten(0, 1)?;
-                let v = v.flatten(0, 1)?;
-
                 // Batch matrix multiplication
                 // Fuse softmax scale and attention_bias add
                 let attention_scores = cublaslt.batch_matmul(
@@ -127,7 +121,7 @@ impl GTEAttention {
                 )?;
                 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
 
-                let context_layer = cublaslt.batch_matmul(
+                cublaslt.batch_matmul(
                     &v.t()?.contiguous()?,
                     &attention_probs,
                     // We save one allocation
@@ -136,15 +130,7 @@ impl GTEAttention {
                     None,
                     None,
                     None,
-                )?;
-
-                // Reshape to dims4
-                context_layer.reshape((
-                    batch_size,
-                    self.num_attention_heads,
-                    seq_len,
-                    self.attention_head_size,
-                ))
+                )
             }
             #[cfg(not(feature = "cuda"))]
             {
@@ -157,7 +143,7 @@ impl GTEAttention {
             attention_probs.matmul(&v.contiguous()?)
         }?;
 
-        let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
+        let context_layer = context_layer.flatten_from(D::Minus2)?;
 
         let hidden_states = self.o_proj.forward(&context_layer)?;
 
@@ -580,3 +566,14 @@ impl Model for GTEModel {
         }
     }
 }
+
+fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
+    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
+        .to_dtype(DType::F32)?
+        .reshape((length, 1))?;
+
+    let freqs = t.matmul(inv_freqs)?;
+    let cos = freqs.cos()?.to_dtype(dtype)?;
+    let sin = freqs.sin()?.to_dtype(dtype)?;
+    Ok((cos, sin))
+}
diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs
index 2c7b2322..ff10518f 100644
--- a/backends/candle/src/models/mod.rs
+++ b/backends/candle/src/models/mod.rs
@@ -44,7 +44,7 @@ pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, NTKScaling, RopeScalin
 pub use jina::JinaBertModel;
 pub use jina_code::JinaCodeBertModel;
 pub use mistral::MistralConfig;
-pub use nomic::{apply_rotary, cos_sin, inv_freqs, NomicBertModel, NomicConfig};
+pub use nomic::{apply_rotary, inv_freqs, NomicBertModel, NomicConfig};
 pub use qwen2::Qwen2Config;
 
 use text_embeddings_backend_core::Batch;
diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs
index 59c3b881..3fd3f645 100644
--- a/backends/candle/src/models/nomic.rs
+++ b/backends/candle/src/models/nomic.rs
@@ -708,6 +708,7 @@ impl Model for NomicBertModel {
     fn is_padded(&self) -> bool {
         false
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
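
The relocated cos_sin helper above simply precomputes rotary cos/sin tables: positions 0..length form a (length, 1) column in F32 that is matrix-multiplied against a (1, dim/2) row of inverse frequencies, and the result is cast to the model dtype. The following standalone sketch exercises it with candle on CPU; the head_dim, base, and max_len values and the inline inverse-frequency construction are illustrative assumptions standing in for the crate's own inv_freqs helper and the values read from the model config.

use candle::{DType, Device, Result, Tensor};

// Mirror of the helper moved into gte.rs by this patch: build the RoPE
// cos/sin tables for `length` positions from a (1, dim / 2) tensor of
// inverse frequencies.
fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
        .to_dtype(DType::F32)?
        .reshape((length, 1))?;

    let freqs = t.matmul(inv_freqs)?;
    let cos = freqs.cos()?.to_dtype(dtype)?;
    let sin = freqs.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical rotary dimension, base and table length, for demonstration only;
    // in the model these come from the GTE configuration.
    let (head_dim, base, max_len) = (64usize, 10_000f32, 8usize);

    // Standard RoPE inverse frequencies 1 / base^(i / head_dim) for even i,
    // shaped (1, head_dim / 2) so it can be matmul'ed against the (length, 1)
    // position column inside cos_sin.
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
        .collect();
    let inv_freqs = Tensor::from_vec(inv_freq, (1, head_dim / 2), &device)?;

    let (cos, sin) = cos_sin(max_len, &inv_freqs, DType::F32)?;
    println!("cos: {:?}, sin: {:?}", cos.shape(), sin.shape()); // both (8, 32)
    Ok(())
}

Computing the angle table in F32 and only casting the cos/sin values to the target dtype at the end, as the helper does, presumably keeps the angles accurate for long sequences even when the model itself runs in F16/BF16.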