fix: GTE
kozistr committed Nov 30, 2024
1 parent e0e3bc6 commit ebe6ac9
Showing 2 changed files with 55 additions and 62 deletions.
backends/candle/src/models/flash_gte.rs: 4 additions & 2 deletions
@@ -1,5 +1,5 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, LayerNorm, Linear};
+use crate::layers::{LayerNorm, Linear};
 use crate::models::{
     GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, GTEMLP,
 };
@@ -132,7 +132,9 @@ impl GTELayer {
         max_s: usize,
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attn_output = self.attention.forward(&hidden_states, cos, sin)?;
+        let attn_output = self
+            .attention
+            .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
         let normed_attn_res_output = self
             .attention_layer_norm
             .forward(&attn_output, Some(hidden_states))?;
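The call-site fix above threads `cu_seqlens` and `max_s` from GTELayer down into the attention module, which is the usual varlen flash-attention convention: every sequence in the batch is packed along a single token dimension and `cu_seqlens` holds the cumulative start offsets. The sketch below is not part of this commit and the helper name `build_cu_seqlens` is invented; it only illustrates, under that assumption, how such offsets and `max_s` are typically derived.

// Illustrative only (hypothetical helper, not from this repository):
// derive cumulative sequence offsets and the longest sequence length
// for a packed, variable-length batch.
fn build_cu_seqlens(seq_lens: &[u32]) -> (Vec<u32>, u32) {
    let mut cu_seqlens = Vec::with_capacity(seq_lens.len() + 1);
    let mut offset = 0u32;
    cu_seqlens.push(offset);
    for &len in seq_lens {
        offset += len;
        cu_seqlens.push(offset);
    }
    let max_s = seq_lens.iter().copied().max().unwrap_or(0);
    (cu_seqlens, max_s)
}

fn main() {
    // Three sequences of lengths 5, 3 and 7 packed back-to-back.
    let (cu_seqlens, max_s) = build_cu_seqlens(&[5, 3, 7]);
    assert_eq!(cu_seqlens, vec![0, 5, 8, 15]);
    assert_eq!(max_s, 7);
}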
backends/candle/src/models/gte.rs: 51 additions & 60 deletions
@@ -104,67 +104,58 @@ impl GTEAttention {
         let k = apply_rotary(&k, cos, sin, self.attention_head_size)?;

         #[allow(unused_variables)]
-        let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
-            (device, get_cublas_lt_wrapper())
-        {
-            #[cfg(feature = "cuda")]
-            {
-                // cuBLASLt batch matmul implementation requires inputs to be dims3
-                let (batch_size, _, seq_len, _) = k.shape().dims4()?;
-                let k = k.flatten(0, 1)?;
-                let q = q.flatten(0, 1)?;
-                let v = v.flatten(0, 1)?;
-                let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;
-
-                // If attention_bias is set, we fuse the add by giving it as the output matrix
-                // and setting beta to 1.0
-                let beta = match attention_bias.is_some() {
-                    true => Some(1.0),
-                    false => None,
-                };
-
-                // Batch matrix multiplication
-                // Fuse softmax scale and attention_bias add
-                let attention_scores = cublaslt.batch_matmul(
-                    &k,
-                    &q,
-                    attention_bias.as_ref(),
-                    Some(self.softmax_scale as f32),
-                    beta,
-                    None,
-                    None,
-                )?;
-                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
-
-                let context_layer = cublaslt.batch_matmul(
-                    &v.t()?.contiguous()?,
-                    &attention_probs,
-                    // We save one allocation
-                    Some(&q),
-                    None,
-                    None,
-                    None,
-                    None,
-                )?;
-
-                // Reshape to dims4
-                context_layer.reshape((
-                    batch_size,
-                    self.num_attention_heads,
-                    seq_len,
-                    self.attention_head_size,
-                ))
-            }
-            #[cfg(not(feature = "cuda"))]
-            {
-                candle::bail!("`cuda` feature is not enabled")
-            }
-        } else {
-            let attention_scores = q.matmul(&k.t()?)?;
-            let attention_scores = (attention_scores * self.softmax_scale as f64)?;
-            let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
-            attention_probs.matmul(&v.contiguous()?)
-        }?;
+        let context_layer =
+            if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
+                #[cfg(feature = "cuda")]
+                {
+                    // cuBLASLt batch matmul implementation requires inputs to be dims3
+                    let (batch_size, _, seq_len, _) = k.shape().dims4()?;
+                    let k = k.flatten(0, 1)?;
+                    let q = q.flatten(0, 1)?;
+                    let v = v.flatten(0, 1)?;
+
+                    // Batch matrix multiplication
+                    // Fuse softmax scale and attention_bias add
+                    let attention_scores = cublaslt.batch_matmul(
+                        &k,
+                        &q,
+                        attention_bias.as_ref(),
+                        Some(self.softmax_scale as f32),
+                        None,
+                        None,
+                        None,
+                    )?;
+                    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                    let context_layer = cublaslt.batch_matmul(
+                        &v.t()?.contiguous()?,
+                        &attention_probs,
+                        // We save one allocation
+                        Some(&q),
+                        None,
+                        None,
+                        None,
+                        None,
+                    )?;
+
+                    // Reshape to dims4
+                    context_layer.reshape((
+                        batch_size,
+                        self.num_attention_heads,
+                        seq_len,
+                        self.attention_head_size,
+                    ))
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    candle::bail!("`cuda` feature is not enabled")
+                }
+            } else {
+                let attention_scores = q.matmul(&k.t()?)?;
+                let attention_scores = (attention_scores * self.softmax_scale as f64)?;
+                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+                attention_probs.matmul(&v.contiguous()?)
+            }?;

         let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

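A note on the gte.rs hunk for readers less familiar with its two branches: the cuBLASLt wrapper's batched GEMM accepts a scaling factor and an optional accumulation matrix (roughly D = alpha * A * B + beta * C), which is why the softmax scale rides along as alpha and, before this change, the attention bias was supplied as C with beta set to 1.0. On non-CUDA devices the same computation falls back to plain scaled dot-product attention. The standalone sketch below is not taken from the repository; the shapes and scale are invented for illustration and it only mirrors the fallback branch using public candle APIs.

// Standalone illustration of the fallback (non-cuBLASLt) attention path.
// Shapes are hypothetical; this is a sketch, not code from the commit.
use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (batch_size, num_heads, seq_len, head_size) = (2usize, 8, 16, 64);

    // Random Q, K, V in (batch, heads, seq_len, head_size) layout.
    let q = Tensor::randn(0f32, 1f32, (batch_size, num_heads, seq_len, head_size), &device)?;
    let k = Tensor::randn(0f32, 1f32, (batch_size, num_heads, seq_len, head_size), &device)?;
    let v = Tensor::randn(0f32, 1f32, (batch_size, num_heads, seq_len, head_size), &device)?;

    // Same steps as the else-branch: Q*K^T, scale, softmax, then multiply by V.
    let softmax_scale = 1.0 / (head_size as f64).sqrt();
    let attention_scores = q.matmul(&k.t()?)?;
    let attention_scores = (attention_scores * softmax_scale)?;
    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
    let context_layer = attention_probs.matmul(&v.contiguous()?)?;

    // Merge the heads back, as the surrounding forward() does.
    let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
    assert_eq!(context_layer.dims(), &[batch_size, seq_len, num_heads * head_size]);
    Ok(())
}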
