diff --git a/backends/candle/src/flash_attn.rs b/backends/candle/src/flash_attn.rs
index f2016928..8dbe58cf 100644
--- a/backends/candle/src/flash_attn.rs
+++ b/backends/candle/src/flash_attn.rs
@@ -32,6 +32,7 @@ pub(crate) fn flash_attn_varlen(
     softmax_scale: f32,
     causal: bool,
     window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
 ) -> Result<Tensor> {
     let runtime_compute_cap = get_runtime_compute_cap();

@@ -39,7 +40,7 @@ pub(crate) fn flash_attn_varlen(
         if alibi_slopes.is_some() {
             candle::bail!("Flash attention v1 does not support alibi");
         }
-        if window_size_left.is_some() {
+        if window_size_left.is_some() | window_size_right.is_some() {
             candle::bail!("Flash attention v1 does not support attention windowing");
         }

@@ -65,7 +66,13 @@ pub(crate) fn flash_attn_varlen(
     {
         use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};

-        let window_size_right = if causal { Some(0) } else { None };
+        let window_size_right = if causal {
+            Some(0)
+        } else if window_size_right.is_some() {
+            window_size_right
+        } else {
+            None
+        };

         let attention = if let Some(alibi_slopes) = alibi_slopes {
             flash_attn_varlen_alibi_windowed(
diff --git a/backends/candle/src/layers/layer_norm.rs b/backends/candle/src/layers/layer_norm.rs
index c67b12dd..932e1b95 100644
--- a/backends/candle/src/layers/layer_norm.rs
+++ b/backends/candle/src/layers/layer_norm.rs
@@ -72,12 +72,17 @@ impl LayerNorm {
                         &hidden_states,
                         &residual,
                         &self.weight,
-                        &self.bias,
+                        self.bias.as_ref(),
                         self.epsilon,
                     )?;
                     Ok(result)
                 } else {
-                    layer_norm(&hidden_states, &self.weight, &self.bias, self.epsilon)
+                    layer_norm(
+                        &hidden_states,
+                        &self.weight,
+                        self.bias.as_ref(),
+                        self.epsilon,
+                    )
                 }?;
                 result.reshape(original_shape)
             }
diff --git a/backends/candle/src/models/flash_distilbert.rs b/backends/candle/src/models/flash_distilbert.rs
index 7f060601..b107e1e3 100644
--- a/backends/candle/src/models/flash_distilbert.rs
+++ b/backends/candle/src/models/flash_distilbert.rs
@@ -85,6 +85,7 @@ impl DistilBertAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs
index b9bb4cdf..f4aac07e 100644
--- a/backends/candle/src/models/flash_gte.rs
+++ b/backends/candle/src/models/flash_gte.rs
@@ -87,6 +87,7 @@ impl GTEAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs
index 83c5b0a4..947fac6c 100644
--- a/backends/candle/src/models/flash_jina.rs
+++ b/backends/candle/src/models/flash_jina.rs
@@ -106,6 +106,7 @@ impl JinaAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs
index 56ab1976..745786dc 100644
--- a/backends/candle/src/models/flash_jina_code.rs
+++ b/backends/candle/src/models/flash_jina_code.rs
@@ -142,6 +142,7 @@ impl JinaCodeAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs
index 70538269..19955259 100644
--- a/backends/candle/src/models/flash_mistral.rs
+++ b/backends/candle/src/models/flash_mistral.rs
@@ -105,6 +105,7 @@ impl MistralAttention {
             self.softmax_scale,
             true,
             self.window_size_left,
+            None,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
diff --git a/backends/candle/src/models/flash_modernbert.rs b/backends/candle/src/models/flash_modernbert.rs
index 07fdc477..73aeeeb4 100644
--- a/backends/candle/src/models/flash_modernbert.rs
+++ b/backends/candle/src/models/flash_modernbert.rs
@@ -1,11 +1,13 @@
+use std::collections::HashMap;
+
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{LayerNorm, Linear};
+use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
 use crate::models::modernbert::{
     ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
     ModernBertMLP,
 };
 use crate::models::Model;
-use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
 use candle_nn::VarBuilder;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};

@@ -15,13 +17,15 @@ struct ModernBertAttention {
     num_attention_heads: usize,
     attention_head_size: usize,
-    softmax_scale: f64,
+    softmax_scale: f32,
+    local_attention: usize,
+    use_local_attention: bool,

     span: tracing::Span,
 }

 impl ModernBertAttention {
-    pub fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
+    pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let hidden_size = config.hidden_size;

@@ -43,7 +47,9 @@ impl ModernBertAttention {
         };
         let wo = Linear::new(wo_weight, wo_bias, None);

-        let softmax_scale = 1. / (attention_head_size as f64).sqrt();
+        let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;
+
+        let use_local_attention = index % config.global_attn_every_n_layers != 0;

         Ok(Self {
             wqkv,
@@ -51,6 +57,8 @@ impl ModernBertAttention {
             num_attention_heads: config.num_attention_heads,
             attention_head_size,
             softmax_scale,
+            local_attention: config.local_attention / 2 as usize,
+            use_local_attention,
             span: tracing::span!(tracing::Level::TRACE, "attention"),
         })
     }
@@ -81,6 +89,12 @@ impl ModernBertAttention {
         let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
         let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;

+        let attention_size = if self.use_local_attention {
+            Some(self.local_attention)
+        } else {
+            None
+        };
+
         let attention = flash_attn_varlen(
             &query_layer,
             &key_layer,
@@ -92,7 +106,8 @@ impl ModernBertAttention {
             max_s,
             self.softmax_scale,
             false,
-            self.local_attention,
+            attention_size,
+            attention_size,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
@@ -123,7 +138,7 @@ impl ModernBertEncoderLayer {
             None
         };

-        let attn = ModernBertAttention::load(vb.pp("attn"), config)?;
+        let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;

         let mlp_norm = LayerNorm::load(
             vb.pp("mlp_norm"),
@@ -238,7 +253,7 @@ pub struct FlashModernBertModel {
 }

 impl FlashModernBertModel {
-    pub fn load(vb: VarBuilder, config: &BertConfig, model_type: ModelType) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result<Self> {
         match vb.device() {
             Device::Cuda(_) => {}
             _ => candle::bail!("FlashModernBert requires Cuda"),
         }
@@ -388,8 +403,8 @@ impl FlashModernBertModel {
            let cos = cos.index_select(&position_ids, 0)?;
            let sin = sin.index_select(&position_ids, 0)?;

-            let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
-            let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
+            let cos = cos.reshape((batch_size, 1, batch.max_length, self.rotary_dim))?;
+            let sin = sin.reshape((batch_size, 1, batch.max_length, self.rotary_dim))?;

             rotary_cache.insert(use_local_attention, (cos, sin));
         }
@@ -465,28 +480,7 @@ impl FlashModernBertModel {
                     }
                 }
                 Pool::Splade => {
-                    let splade_head = self.splade.as_ref().unwrap();
-                    let relu_log = splade_head.forward(&outputs)?;
-
-                    if batch_size > 1 {
-                        let results: Result<Vec<Tensor>> = batch
-                            .pooled_indices
-                            .into_iter()
-                            .map(|i| {
-                                let i = i as usize;
-                                let start = batch.cumulative_seq_lengths[i];
-                                let len = batch.cumulative_seq_lengths[i + 1] - start;
-
-                                relu_log
-                                    .narrow(0, start as usize, len as usize)?
-                                    .max_keepdim(0)
-                            })
-                            .collect();
-
-                        Some(Tensor::cat(&results?, 0)?)
-                    } else {
-                        Some(relu_log.max_keepdim(0)?)
-                    }
+                    unreachable!();
                 }
             }
         } else {
diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs
index 057db768..5683a86c 100644
--- a/backends/candle/src/models/flash_nomic.rs
+++ b/backends/candle/src/models/flash_nomic.rs
@@ -83,6 +83,7 @@ impl NomicAttention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;

         let attention = attention.flatten_from(D::Minus2)?;
diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs
index c6662047..7053c4c8 100644
--- a/backends/candle/src/models/flash_qwen2.rs
+++ b/backends/candle/src/models/flash_qwen2.rs
@@ -113,6 +113,7 @@ impl Qwen2Attention {
             self.softmax_scale,
             false,
             None,
+            None,
         )?;

         let attention = attention.flatten_from(candle::D::Minus2)?;
diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs
index afdef81f..98fb185c 100644
--- a/backends/candle/src/models/modernbert.rs
+++ b/backends/candle/src/models/modernbert.rs
@@ -77,7 +77,7 @@ impl ModernBertEmbeddings {
     }
 }

-struct ModernBertMLP {
+pub struct ModernBertMLP {
     wi: Linear,
     wo: Linear,
     activation: Option<HiddenAct>,