Commit 005a693

fix: gte
kozistr committed Nov 30, 2024
1 parent d5430f5 commit 005a693
Showing 3 changed files with 18 additions and 20 deletions.
35 changes: 16 additions & 19 deletions backends/candle/src/models/gte.rs
@@ -1,6 +1,6 @@
 use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
-use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType};
-use candle::{Device, IndexOp, Result, Tensor, D};
+use crate::models::{apply_rotary, inv_freqs, Model, PositionEmbeddingType};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -108,12 +108,6 @@ impl GTEAttention {
         if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
             #[cfg(feature = "cuda")]
             {
-                // cuBLASLt batch matmul implementation requires inputs to be dims3
-                let (batch_size, _, seq_len, _) = k.shape().dims4()?;
-                let k = k.flatten(0, 1)?;
-                let q = q.flatten(0, 1)?;
-                let v = v.flatten(0, 1)?;
-
                 // Batch matrix multiplication
                 // Fuse softmax scale and attention_bias add
                 let attention_scores = cublaslt.batch_matmul(
@@ -127,7 +121,7 @@ impl GTEAttention {
                 )?;
                 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;

-                let context_layer = cublaslt.batch_matmul(
+                cublaslt.batch_matmul(
                     &v.t()?.contiguous()?,
                     &attention_probs,
                     // We save one allocation
@@ -136,15 +130,7 @@ impl GTEAttention {
                     None,
                     None,
                     None,
-                )?;
-
-                // Reshape to dims4
-                context_layer.reshape((
-                    batch_size,
-                    self.num_attention_heads,
-                    seq_len,
-                    self.attention_head_size,
-                ))
+                )
             }
             #[cfg(not(feature = "cuda"))]
             {
@@ -157,7 +143,7 @@ impl GTEAttention {
             attention_probs.matmul(&v.contiguous()?)
         }?;

-        let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
+        let context_layer = context_layer.flatten_from(D::Minus2)?;

         let hidden_states = self.o_proj.forward(&context_layer)?;

@@ -580,3 +566,14 @@ impl Model for GTEModel {
         }
     }
 }
+
+fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
+    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
+        .to_dtype(DType::F32)?
+        .reshape((length, 1))?;
+
+    let freqs = t.matmul(inv_freqs)?;
+    let cos = freqs.cos()?.to_dtype(dtype)?;
+    let sin = freqs.sin()?.to_dtype(dtype)?;
+    Ok((cos, sin))
+}
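
Note on the new cos_sin helper: it builds the standard rotary-embedding angle table by multiplying a (length, 1) column of positions against a row of inverse frequencies, then taking cos and sin in the target dtype. The sketch below shows one way such a cache could be constructed end to end. It is illustrative only: build_rope_cache, the rope theta of 10000.0, and the (1, head_dim / 2) inv_freqs shape are assumptions, not part of this commit, and `candle` here follows this repo's crate alias for candle_core.

// Illustrative sketch only: build_rope_cache and the 10000.0 theta are assumptions.
use candle::{DType, Device, Result, Tensor};

fn build_rope_cache(
    head_dim: usize,
    max_len: usize,
    dtype: DType,
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    // Inverse frequencies theta^(-2i/d) for i in 0..d/2, shaped (1, d/2) so a
    // (max_len, 1) column of positions can be matmul'd against it, as in cos_sin above.
    let inv_freqs: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / 10_000f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let inv_freqs = Tensor::from_vec(inv_freqs, (1, head_dim / 2), device)?;

    // Same recipe as the new helper: angles = positions x inv_freqs, then cos/sin.
    let t = Tensor::arange(0u32, max_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_len, 1))?;
    let freqs = t.matmul(&inv_freqs)?;
    Ok((freqs.cos()?.to_dtype(dtype)?, freqs.sin()?.to_dtype(dtype)?))
}

A caller would typically slice rows of the returned cos/sin tensors for the positions in the batch and pass them to apply_rotary on the query and key projections.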
2 changes: 1 addition & 1 deletion backends/candle/src/models/mod.rs
@@ -44,7 +44,7 @@
 pub use jina::JinaBertModel;
 pub use jina_code::JinaCodeBertModel;
 pub use mistral::MistralConfig;
-pub use nomic::{apply_rotary, cos_sin, inv_freqs, NomicBertModel, NomicConfig};
+pub use nomic::{apply_rotary, inv_freqs, NomicBertModel, NomicConfig};
 pub use qwen2::Qwen2Config;
 use text_embeddings_backend_core::Batch;

1 change: 1 addition & 0 deletions backends/candle/src/models/nomic.rs
@@ -708,6 +708,7 @@ impl Model for NomicBertModel {
     fn is_padded(&self) -> bool {
         false
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
