diff --git a/backends/candle/src/layers/layer_norm.rs b/backends/candle/src/layers/layer_norm.rs
index 0c360572..c67b12dd 100644
--- a/backends/candle/src/layers/layer_norm.rs
+++ b/backends/candle/src/layers/layer_norm.rs
@@ -4,7 +4,7 @@ use candle_nn::VarBuilder;
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
     epsilon: f32,
     span: tracing::Span,
 }
@@ -17,7 +17,8 @@ impl LayerNorm {
                 .or_else(|_| vb.get(hidden_size, "gamma"))?,
             bias: vb
                 .get(hidden_size, "bias")
-                .or_else(|_| vb.get(hidden_size, "beta"))?,
+                .or_else(|_| vb.get(hidden_size, "beta"))
+                .ok(),
             epsilon,
             span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
         })
@@ -49,7 +50,12 @@ impl LayerNorm {
                 let hidden_states = hidden_states_normed
                     .to_dtype(hidden_states_dtype)?
                     .broadcast_mul(&self.weight)?;
-                hidden_states.broadcast_add(&self.bias)
+
+                if let Some(bias) = &self.bias {
+                    hidden_states.broadcast_add(bias)
+                } else {
+                    Ok(hidden_states)
+                }
             }
             Device::Cuda(_) => {
                 #[cfg(feature = "cuda")]
@@ -66,12 +72,12 @@ impl LayerNorm {
                         &hidden_states,
                         &residual,
                         &self.weight,
-                        Some(&self.bias),
+                        &self.bias,
                         self.epsilon,
                     )?;
                     Ok(result)
                 } else {
-                    layer_norm(&hidden_states, &self.weight, Some(&self.bias), self.epsilon)
+                    layer_norm(&hidden_states, &self.weight, &self.bias, self.epsilon)
                 }?;
                 result.reshape(original_shape)
             }
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 8bf68f39..d3c8f831 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -12,8 +12,8 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
-    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
-    Qwen2Config,
+    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
+    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -63,6 +63,8 @@ enum Config {
     Qwen2(Qwen2Config),
     #[serde(rename = "mpnet")]
     MPNet(MPNetConfig),
+    #[serde(rename(deserialize = "modernbert"))]
+    ModernBert(ModernBertConfig),
 }
 
 pub struct CandleBackend {
@@ -233,6 +235,12 @@ impl CandleBackend {
                 tracing::info!("Starting MPNet model on {:?}", device);
                 Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
             }
+            (Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting ModernBert model on {:?}", device);
+                Ok(Box::new(
+                    ModernBertModel::load(vb, &config, model_type).s()?,
+                ))
+            }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -375,6 +383,25 @@ impl CandleBackend {
                     FlashQwen2Model::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::ModernBert(config), Device::Cuda(_)) => {
+                if cfg!(feature = "flash-attn")
+                    && dtype == DType::F16
+                    && &std::env::var("USE_FLASH_ATTENTION")
+                        .unwrap_or("True".to_string())
+                        .to_lowercase()
+                        == "true"
+                {
+                    return Err(BackendError::Start(
+                        "ModernBert does not support flash attention".to_string(),
+                    ));
+                }
+
+                tracing::info!("Starting ModernBert model on {:?}", device);
+                Ok(Box::new(
+                    ModernBertModel::load(vb, &config, model_type).s()?,
+                ))
+            }
         };
 
         Ok(Self {
diff --git a/backends/candle/src/models/flash_modernbert.rs b/backends/candle/src/models/flash_modernbert.rs
new file mode 100644
index 00000000..9f15a3ed
--- /dev/null
+++ b/backends/candle/src/models/flash_modernbert.rs
@@ -0,0 +1,496 @@
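+//! ModernBERT encoder backed by the flash-attention varlen kernels (CUDA, F16 only).
+//! Layers alternate between global attention and local sliding-window attention every
+//! `global_attn_every_n_layers` layers; the local window is handed to `flash_attn_varlen`
+//! as a `(window_left, window_right)` pair.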
+use crate::flash_attn::flash_attn_varlen;
+use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::models::modernbert::{
+    ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
+    ModernBertMLP,
+};
+use crate::models::Model;
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
+use candle_nn::VarBuilder;
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
+
+struct ModernBertAttention {
+    wqkv: Linear,
+    wo: Linear,
+
+    local_attention: (i64, i64),
+    cos: Tensor,
+    sin: Tensor,
+
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    softmax_scale: f64,
+
+    span: tracing::Span,
+}
+
+impl ModernBertAttention {
+    pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let hidden_size = config.hidden_size;
+
+        let wqkv_weight = vb
+            .pp("Wqkv")
+            .get((hidden_size * 3, hidden_size), "weight")?;
+        let wqkv_bias = if config.attention_bias {
+            vb.pp("Wqkv").get(hidden_size * 3, "bias").ok()
+        } else {
+            None
+        };
+        let wqkv = Linear::new(wqkv_weight, wqkv_bias, None);
+
+        let wo_weight = vb.pp("Wo").get((hidden_size, hidden_size), "weight")?;
+        let wo_bias = if config.attention_bias {
+            vb.pp("Wo").get(hidden_size, "bias").ok()
+        } else {
+            None
+        };
+        let wo = Linear::new(wo_weight, wo_bias, None);
+
+        // Layers alternate between global and local (sliding-window) attention.
+        let use_local_attention = index % config.global_attn_every_n_layers != 0;
+        let local_attention = if use_local_attention {
+            (
+                (config.local_attention / 2) as i64,
+                (config.local_attention / 2) as i64,
+            )
+        } else {
+            (-1, -1)
+        };
+
+        let rope_theta = if use_local_attention {
+            config.local_rope_theta
+        } else {
+            config.global_rope_theta
+        };
+        let inv_freqs = get_inv_freqs(attention_head_size, rope_theta as f32, vb.device(), None)?;
+        let (cos, sin) = get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
+
+        let softmax_scale = 1. / (attention_head_size as f64).sqrt();
+
+        Ok(Self {
+            wqkv,
+            wo,
+            local_attention,
+            cos,
+            sin,
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+            softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+        })
+    }
+
+    pub fn forward(
+        &self,
+        hidden_states: &Tensor,
+        cu_seqlens: &Tensor,
+        max_s: usize,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let qkv = self.wqkv.forward(hidden_states)?;
+
+        let mut new_qkv_shape = qkv.dims().to_vec();
+        new_qkv_shape.pop();
+        new_qkv_shape.push(self.num_attention_heads * 3);
+        new_qkv_shape.push(self.attention_head_size);
+        let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
+
+        let qkv = qkv.chunk(3, 1)?;
+        let query_layer = &qkv[0].contiguous()?;
+        let key_layer = &qkv[1].contiguous()?;
+        let value_layer = &qkv[2];
+
+        let query_layer =
+            apply_rotary(query_layer, &self.cos, &self.sin, self.attention_head_size)?;
+        let key_layer = apply_rotary(key_layer, &self.cos, &self.sin, self.attention_head_size)?;
+
+        let attention = flash_attn_varlen(
+            &query_layer,
+            &key_layer,
+            value_layer,
+            None,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            self.softmax_scale,
+            false,
+            self.local_attention.0,
+            self.local_attention.1,
+        )?;
+        let attention = attention.flatten_from(candle::D::Minus2)?;
+
+        let hidden_states = self.wo.forward(&attention)?;
+
+        Ok(hidden_states)
+    }
+}
+
+struct ModernBertEncoderLayer {
+    attn_norm: Option<LayerNorm>,
+    attn: ModernBertAttention,
+    mlp_norm: LayerNorm,
+    mlp: ModernBertMLP,
+
+    span: tracing::Span,
+}
+
+impl ModernBertEncoderLayer {
+    pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
+        let attn_norm = if index > 0 {
+            Some(LayerNorm::load(
+                vb.pp("attn_norm"),
+                config.hidden_size,
+                config.norm_eps as f32,
+            )?)
+        } else {
+            None
+        };
+
+        let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;
+
+        let mlp_norm = LayerNorm::load(
+            vb.pp("mlp_norm"),
+            config.hidden_size,
+            config.norm_eps as f32,
+        )?;
+        let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
+
+        let span = tracing::span!(tracing::Level::TRACE, "layer");
+
+        Ok(ModernBertEncoderLayer {
+            attn_norm,
+            attn,
+            mlp_norm,
+            mlp,
+            span,
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.clone();
+
+        if let Some(attn_norm) = &self.attn_norm {
+            hidden_states = attn_norm.forward(&hidden_states, None)?;
+        }
+
+        let hidden_states = self.attn.forward(&hidden_states, cu_seqlens, max_s)?;
+        let mlp_output = self
+            .mlp
+            .forward(&self.mlp_norm.forward(&hidden_states, None)?)?;
+
+        hidden_states.broadcast_add(&mlp_output)
+    }
+}
+
+struct ModernBertEncoder {
+    layers: Vec<ModernBertEncoderLayer>,
+    span: tracing::Span,
+}
+
+impl ModernBertEncoder {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|index| ModernBertEncoderLayer::load(vb.pp(format!("{index}")), index, config))
+            .collect::<Result<Vec<_>>>()?;
+
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+
+        Ok(ModernBertEncoder { layers, span })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.clone();
+
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states, cu_seqlens, max_s)?;
+        }
+
+        Ok(hidden_states)
+    }
+}
+
+pub struct FlashModernBertModel {
+    embeddings: ModernBertEmbeddings,
+    encoder: ModernBertEncoder,
+    final_norm: LayerNorm,
+    pool: Pool,
+    classifier: Option<Box<dyn ClassificationHead + Send>>,
+
+    local_attention: usize,
+
+    device: Device,
+    dtype: DType,
+
+    span: tracing::Span,
+}
+
+impl FlashModernBertModel {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result<Self> {
+        match vb.device() {
+            Device::Cuda(_) => {}
+            _ => candle::bail!("FlashModernBert requires Cuda"),
+        }
+
+        if vb.dtype() != DType::F16 {
+            candle::bail!("FlashModernBert requires DType::F16")
+        }
+
+        let (pool, classifier) = match model_type {
+            ModelType::Classifier => {
+                let pool = Pool::Cls;
+
+                let classifier: Box<dyn ClassificationHead + Send> =
+                    Box::new(ModernBertClassificationHead::load(vb.clone(), config)?);
+
+                (pool, Some(classifier))
+            }
+            ModelType::Embedding(pool) => {
+                if pool == Pool::Splade {
+                    candle::bail!("`splade` is not supported for ModernBert")
+                }
+
+                if pool == Pool::LastToken {
+                    candle::bail!("`last_token` is not supported for ModernBert");
+                }
+
+                (pool, None)
+            }
+        };
+
+        let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
+        let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
+        let final_norm = LayerNorm::load(
+            vb.pp("model.final_norm"),
+            config.hidden_size,
+            config.norm_eps as f32,
+        )?;
+
+        Ok(Self {
+            embeddings,
+            encoder,
+            final_norm,
+            pool,
+            classifier,
+            local_attention: config.local_attention,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    fn get_global_attention_mask(
+        &self,
+        attention_mask: Option<&Tensor>,
+        input_shape: &Shape,
+    ) -> Result<Tensor> {
+        let extended_attention_mask = if let Some(attention_mask) = attention_mask {
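+            // The padding mask comes in as (batch, seq_len, 1); drop the trailing dim before broadcasting.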
+            attention_mask.squeeze(2)?
+        } else {
+            Tensor::ones(input_shape, DType::F32, &self.device)?
+        }
+        .unsqueeze(1)?
+        .unsqueeze(1)?
+        .to_dtype(self.dtype)?;
+
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN as f64,
+            _ => -65504.0_f64, // f16 minimum value
+        };
+
+        let extended_attention_mask = ((1.0 - extended_attention_mask)? * min_value)?;
+
+        let (bs, seq_len) = input_shape.dims2()?;
+        let extended_attention_mask =
+            extended_attention_mask.broadcast_as((bs, 1, seq_len, seq_len))?;
+
+        Ok(extended_attention_mask)
+    }
+
+    fn get_silding_window_mask(
+        &self,
+        attention_mask: &Tensor,
+        local_attention: usize,
+    ) -> Result<Tensor> {
+        let mask_shape = attention_mask.shape();
+        let (_, _, seq_len, _) = mask_shape.dims4()?;
+
+        let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
+        let distance = (&rows - &rows.t()?)?.abs()?;
+
+        let window_size = local_attention / 2;
+        let window_mask = distance
+            .le(window_size as i64)?
+            .unsqueeze(0)?
+            .unsqueeze(0)?;
+
+        let dtype = attention_mask.dtype();
+        let min_value = match dtype {
+            DType::F32 => f32::MIN as f64,
+            _ => -65504.0, // f16 minimum value
+        };
+
+        let inverted_window_mask = window_mask.eq(0_i64)?;
+        let min_value_tensor = Tensor::full(min_value, mask_shape, attention_mask.device())?;
+        let sliding_window_mask =
+            attention_mask.where_cond(&inverted_window_mask, &min_value_tensor)?;
+
+        Ok(sliding_window_mask)
+    }
+
+    pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        let _enter = self.span.enter();
+
+        let batch_size = batch.len();
+        let shape = batch.input_ids.len();
+        let max_s = batch.max_length as usize;
+
+        let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
+        let cu_seqlens = Tensor::from_vec(
+            batch.cumulative_seq_lengths.clone(),
+            batch_size + 1,
+            &self.device,
+        )?;
+
+        let hidden_states = self.embeddings.forward(&input_ids)?;
+        let hidden_states = self.encoder.forward(&hidden_states, &cu_seqlens, max_s)?;
+        let outputs = self.final_norm.forward(&hidden_states, None)?;
+
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        let pooled_embeddings = if has_pooling_requests {
+            match self.pool {
+                Pool::Cls | Pool::LastToken => {
+                    if batch_size > 1 {
+                        let mut indices = match self.pool {
+                            Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
+                            Pool::LastToken => {
+                                let end = cu_seqlens.narrow(0, 1, batch_size)?;
+                                (&end - &end.ones_like()?)?
+                            }
+                            _ => unreachable!(),
+                        };
+
+                        if has_raw_requests {
+                            let pooled_indices = Tensor::from_vec(
+                                batch.pooled_indices.clone(),
+                                batch.pooled_indices.len(),
+                                &self.device,
+                            )?;
+
+                            indices = indices.index_select(&pooled_indices, 0)?
+                        }
+
+                        Some(outputs.index_select(&indices, 0)?)
+                    } else {
+                        Some(
+                            match self.pool {
+                                Pool::Cls => outputs.i(0)?,
+                                Pool::LastToken => {
+                                    outputs.i(batch.cumulative_seq_lengths[1] as usize - 1)?
+                                }
+                                _ => unreachable!(),
+                            }
+                            .unsqueeze(0)?,
+                        )
+                    }
+                }
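+                // Mean pooling: average each sequence's token states using the cu_seqlens offsets.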
+                Pool::Mean => {
+                    if batch_size > 1 {
+                        let results: Result<Vec<Tensor>> = batch
+                            .pooled_indices
+                            .into_iter()
+                            .map(|i| {
+                                let i = i as usize;
+                                let start = batch.cumulative_seq_lengths[i];
+                                let len = batch.cumulative_seq_lengths[i + 1] - start;
+
+                                // Mean
+                                let embeddings = outputs.narrow(0, start as usize, len as usize)?;
+                                embeddings.sum_keepdim(0)? / (len as f64)
+                            })
+                            .collect();
+
+                        Some(Tensor::cat(&results?, 0)?)
+                    } else {
+                        Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?)
+                    }
+                }
+                // `splade` is rejected at load time, so this arm can never be reached.
+                Pool::Splade => unreachable!(),
+            }
+        } else {
+            None
+        };
+
+        let raw_embeddings = if has_raw_requests {
+            if batch_size > 1 && has_pooling_requests {
+                let mut final_indices: Vec<u32> = Vec::with_capacity(shape);
+                for i in batch.raw_indices.into_iter() {
+                    let i = i as usize;
+                    let start = batch.cumulative_seq_lengths[i];
+                    let end = batch.cumulative_seq_lengths[i + 1];
+
+                    for j in start..end {
+                        final_indices.push(j);
+                    }
+                }
+
+                let final_indices_length = final_indices.len();
+                let final_indices =
+                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;
+
+                Some(outputs.index_select(&final_indices, 0)?)
+            } else {
+                Some(outputs)
+            }
+        } else {
+            None
+        };
+
+        Ok((pooled_embeddings, raw_embeddings))
+    }
+}
+
+impl Model for FlashModernBertModel {
+    fn is_padded(&self) -> bool {
+        false
+    }
+
+    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        self.forward(batch)
+    }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
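+                // ModelType::Classifier always uses CLS pooling, so pooled embeddings must be present.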
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
+}
diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs
index 9c67ae8b..b4e3b090 100644
--- a/backends/candle/src/models/mod.rs
+++ b/backends/candle/src/models/mod.rs
@@ -9,6 +9,7 @@ mod distilbert;
 mod jina;
 mod jina_code;
 mod mistral;
+mod modernbert;
 mod nomic;
 
 #[cfg(feature = "cuda")]
@@ -28,11 +29,16 @@ mod flash_distilbert;
 #[cfg(feature = "cuda")]
 mod flash_gte;
+
 #[cfg(feature = "cuda")]
 mod flash_mistral;
 
 #[cfg(feature = "cuda")]
 mod flash_qwen2;
+
+#[cfg(feature = "cuda")]
+mod flash_modernbert;
+
 mod gte;
 mod mpnet;
 mod qwen2;
@@ -45,6 +51,7 @@ pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};
 pub use jina::JinaBertModel;
 pub use jina_code::JinaCodeBertModel;
 pub use mistral::MistralConfig;
+pub use modernbert::{ModernBertConfig, ModernBertModel};
 pub use mpnet::{MPNetConfig, MPNetModel};
 pub use nomic::{NomicBertModel, NomicConfig};
 pub use qwen2::Qwen2Config;
@@ -74,6 +81,9 @@ pub use flash_gte::FlashGTEModel;
 #[cfg(feature = "cuda")]
 pub use flash_qwen2::FlashQwen2Model;
 
+#[cfg(feature = "cuda")]
+pub use flash_modernbert::FlashModernBertModel;
+
 pub(crate) trait Model {
     fn is_padded(&self) -> bool;
 
diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs
new file mode 100644
index 00000000..9e1a1f74
--- /dev/null
+++ b/backends/candle/src/models/modernbert.rs
@@ -0,0 +1,804 @@
+use std::collections::HashMap;
+
+use crate::layers::{
+    apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear,
+};
+use crate::models::Model;
+use candle::{DType, Device, IndexOp, Module, Result, Shape, Tensor, D};
+use candle_nn::{Embedding, VarBuilder};
+use serde::Deserialize;
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
+
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/configuration_modernbert.py
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct ModernBertConfig {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub hidden_activation: HiddenAct,
+    pub max_position_embeddings: usize,
+    pub initializer_range: f64,
+    pub initializer_cutoff_factor: f64,
+    pub norm_eps: f64,
+    pub norm_bias: bool,
+    pub pad_token_id: usize,
+    pub eos_token_id: usize,
+    pub bos_token_id: usize,
+    pub cls_token_id: usize,
+    pub sep_token_id: usize,
+    pub global_rope_theta: f64,
+    pub attention_bias: bool,
+    pub attention_dropout: f64,
+    pub global_attn_every_n_layers: usize,
+    pub local_attention: usize,
+    pub local_rope_theta: f64,
+    pub embedding_dropout: Option<f64>,
+    pub mlp_bias: Option<bool>,
+    pub mlp_dropout: Option<f64>,
+    pub decoder_bias: Option<bool>,
+    pub classifier_pooling: Option<String>,
+    pub classifier_dropout: Option<f64>,
+    pub classifier_bias: Option<bool>,
+    pub classifier_activation: HiddenAct,
+    pub deterministic_flash_attn: Option<bool>,
+    pub sparse_prediction: Option<bool>,
+    pub sparse_pred_ignore_index: Option<i64>,
+    pub reference_compile: Option<bool>,
+    pub num_labels: Option<usize>,
+}
+
+#[derive(Debug)]
+pub struct ModernBertEmbeddings {
+    tok_embeddings: Embedding,
+    norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl ModernBertEmbeddings {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
+        Ok(Self {
+            tok_embeddings: Embedding::new(
+                vb.pp("tok_embeddings")
+                    .get((config.vocab_size, config.hidden_size), "weight")?,
+                config.hidden_size,
+            ),
+            norm: LayerNorm::load(vb.pp("norm"), config.hidden_size, config.norm_eps as f32)?,
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        self.norm
+            .forward(&self.tok_embeddings.forward(input_ids)?, None)
+    }
+}
+
+pub struct ModernBertMLP {
+    wi: Linear,
+    wo: Linear,
+    activation: Option<HiddenAct>,
+    intermediate_size: usize,
+    span: tracing::Span,
+}
+
+impl ModernBertMLP {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
+        let wi_weight = vb
+            .pp("Wi")
+            .get((config.intermediate_size * 2, config.hidden_size), "weight")?;
+        let wi_bias = vb.pp("Wi").get(config.intermediate_size * 2, "bias").ok();
+        let wi = Linear::new(wi_weight, wi_bias, None);
+
+        let wo_weight = vb
+            .pp("Wo")
+            .get((config.hidden_size, config.intermediate_size), "weight")?;
+        let wo_bias = vb.pp("Wo").get(config.hidden_size, "bias").ok();
+
+        let wo = Linear::new(wo_weight, wo_bias, None);
+
+        let activation = Some(config.hidden_activation.clone());
+
+        Ok(Self {
+            wi,
+            wo,
+            activation,
+            intermediate_size: config.intermediate_size,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = self.wi.forward(hidden_states)?;
+
+        // The Wi projection packs the gated-MLP input and gate side by side.
+        let input = hidden_states.narrow(D::Minus1, 0, self.intermediate_size)?;
+        let gate =
+            hidden_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
+
+        let input = if let Some(activation) = &self.activation {
+            match activation {
+                HiddenAct::Gelu => input.gelu(),
+                HiddenAct::Relu => input.relu(),
+                HiddenAct::Swiglu => input.silu(),
+            }
+        } else {
+            Ok(input)
+        };
+
+        let hidden_states = self.wo.forward(&(input * gate)?)?;
+
+        Ok(hidden_states)
+    }
+}
+
+struct ModernBertAttention {
+    wqkv: Linear,
+    wo: Linear,
+
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    softmax_scale: f64,
+
+    span: tracing::Span,
+}
+
+impl ModernBertAttention {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let hidden_size = config.hidden_size;
+
+        let wqkv_weight = vb
+            .pp("Wqkv")
+            .get((hidden_size * 3, hidden_size), "weight")?;
+        let wqkv_bias = if config.attention_bias {
+            vb.pp("Wqkv").get(hidden_size * 3, "bias").ok()
+        } else {
+            None
+        };
+        let wqkv: Linear = Linear::new(wqkv_weight, wqkv_bias, None);
+
+        let wo_weight = vb.pp("Wo").get((hidden_size, hidden_size), "weight")?;
+        let wo_bias = if config.attention_bias {
+            vb.pp("Wo").get(hidden_size, "bias").ok()
+        } else {
+            None
+        };
+        let wo = Linear::new(wo_weight, wo_bias, None);
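+        // Standard scaled dot-product attention: scores are scaled by 1 / sqrt(head_dim).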
+        let softmax_scale = 1. / (attention_head_size as f64).sqrt();
+
+        Ok(Self {
+            wqkv,
+            wo,
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+            softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+        })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_mask: &Tensor,
+        cos: &Tensor,
+        sin: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let device = hidden_states.device();
+
+        let qkv = self.wqkv.forward(hidden_states)?;
+
+        let mut new_qkv_shape = qkv.dims().to_vec();
+        new_qkv_shape.pop();
+        new_qkv_shape.push(self.num_attention_heads * 3);
+        new_qkv_shape.push(self.attention_head_size);
+        let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+
+        let qkv = qkv.chunk(3, 1)?;
+        let query_layer = &qkv[0].contiguous()?;
+        let key_layer = &qkv[1].contiguous()?;
+        let value_layer = &qkv[2];
+
+        let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
+        let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;
+
+        // CUDA path uses cuBLASLt batched matmuls; CPU/Metal falls back to plain matmuls.
+        #[allow(unused_variables)]
+        let context_layer =
+            if let (Device::Cuda(_), Some(cublaslt)) = (device, get_cublas_lt_wrapper()) {
+                #[cfg(feature = "cuda")]
+                {
+                    let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
+                    let key_layer = key_layer.flatten(0, 1)?;
+                    let query_layer = query_layer.flatten(0, 1)?;
+                    let value_layer = value_layer.flatten(0, 1)?;
+                    let attention_mask = attention_mask.flatten(0, 1)?;
+
+                    let attention_scores = cublaslt.batch_matmul(
+                        &key_layer,
+                        &query_layer,
+                        Some(attention_mask.as_ref()),
+                        Some(self.softmax_scale as f32),
+                        None,
+                        None,
+                        None,
+                    )?;
+                    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                    let context_layer = cublaslt.batch_matmul(
+                        &value_layer.t()?.contiguous()?,
+                        &attention_probs,
+                        Some(&query_layer),
+                        None,
+                        None,
+                        None,
+                        None,
+                    )?;
+
+                    context_layer.reshape((
+                        batch_size,
+                        self.num_attention_heads,
+                        seq_len,
+                        self.attention_head_size,
+                    ))
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    candle::bail!("`cuda` feature is not enabled")
+                }
+            } else {
+                let attn_weights = query_layer.matmul(&key_layer.t()?)?;
+                let attn_weights = (attn_weights * self.softmax_scale)?;
+                let attn_weights = attn_weights.add(attention_mask)?;
+
+                let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+                attn_weights.matmul(&value_layer.contiguous()?)
+            }?;
+
+        let hidden_states = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
+
+        self.wo.forward(&hidden_states)
+    }
+}
+
+struct ModernBertEncoderLayer {
+    attn_norm: Option<LayerNorm>,
+    attn: ModernBertAttention,
+    mlp_norm: LayerNorm,
+    mlp: ModernBertMLP,
+
+    span: tracing::Span,
+}
+
+impl ModernBertEncoderLayer {
+    pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
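+        // Layer 0 skips the attention pre-norm: the embedding output is already layer-normalized.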
+        let attn_norm = if index > 0 {
+            Some(LayerNorm::load(
+                vb.pp("attn_norm"),
+                config.hidden_size,
+                config.norm_eps as f32,
+            )?)
+        } else {
+            None
+        };
+
+        let attn = ModernBertAttention::load(vb.pp("attn"), config)?;
+
+        let mlp_norm = LayerNorm::load(
+            vb.pp("mlp_norm"),
+            config.hidden_size,
+            config.norm_eps as f32,
+        )?;
+        let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
+
+        let span = tracing::span!(tracing::Level::TRACE, "layer");
+
+        Ok(ModernBertEncoderLayer {
+            attn_norm,
+            attn,
+            mlp_norm,
+            mlp,
+            span,
+        })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_mask: &Tensor,
+        cos: &Tensor,
+        sin: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.clone();
+
+        if let Some(attn_norm) = &self.attn_norm {
+            hidden_states = attn_norm.forward(&hidden_states, None)?;
+        }
+
+        let hidden_states = self
+            .attn
+            .forward(&hidden_states, attention_mask, cos, sin)?;
+        let mlp_output = self
+            .mlp
+            .forward(&self.mlp_norm.forward(&hidden_states, None)?)?;
+
+        hidden_states.broadcast_add(&mlp_output)
+    }
+}
+
+struct ModernBertEncoder {
+    layers: Vec<ModernBertEncoderLayer>,
+
+    global_attn_every_n_layers: usize,
+
+    span: tracing::Span,
+}
+
+impl ModernBertEncoder {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|index| ModernBertEncoderLayer::load(vb.pp(format!("{index}")), index, config))
+            .collect::<Result<Vec<_>>>()?;
+
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+
+        Ok(ModernBertEncoder {
+            layers,
+            global_attn_every_n_layers: config.global_attn_every_n_layers,
+            span,
+        })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_mask: &Tensor,
+        silding_attention_mask: &Tensor,
+        rotary_cache: &HashMap<bool, (Tensor, Tensor)>,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.clone();
+
+        for (index, layer) in self.layers.iter().enumerate() {
+            // Every `global_attn_every_n_layers`-th layer attends globally; the rest use the
+            // sliding-window mask and the local RoPE cache.
+            let use_local_attention = index % self.global_attn_every_n_layers != 0;
+            let (cos, sin) = &rotary_cache[&use_local_attention];
+            let attention_mask = if use_local_attention {
+                silding_attention_mask
+            } else {
+                attention_mask
+            };
+
+            hidden_states = layer.forward(&hidden_states, attention_mask, cos, sin)?;
+        }
+
+        Ok(hidden_states)
+    }
+}
+
+pub trait ClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
+}
+
+pub struct ModernBertClassificationHead {
+    dense: Linear,
+    norm: LayerNorm,
+    classifier: Linear,
+    span: tracing::Span,
+}
+
+impl ModernBertClassificationHead {
+    pub(crate) fn load(vb: VarBuilder, config: &ModernBertConfig) -> Result<Self> {
+        let dense_weight = vb
+            .pp("dense")
+            .get((config.hidden_size, config.hidden_size), "weight")?;
+        let dense_bias = vb.pp("dense").get(config.hidden_size, "bias").ok();
+        let dense = Linear::new(
+            dense_weight,
+            dense_bias,
+            Some(config.classifier_activation.clone()),
+        );
+
+        let norm = LayerNorm::load(vb.pp("norm"), config.hidden_size, config.norm_eps as f32)?;
+
+        let classifier_weight = vb.pp("classifier").get(
+            (config.num_labels.unwrap_or(1), config.hidden_size),
+            "weight",
+        )?;
+        let classifier_bias = vb
+            .pp("classifier")
+            .get(config.num_labels.unwrap_or(1), "bias")
+            .ok();
+        let classifier = Linear::new(classifier_weight, classifier_bias, None);
+
+        Ok(Self {
+            dense,
+            norm,
+            classifier,
+            span: tracing::span!(tracing::Level::TRACE, "classifier"),
+        })
+    }
+}
+
+impl ClassificationHead for ModernBertClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = hidden_states.unsqueeze(1)?;
+
+        let hidden_states = self.dense.forward(&hidden_states)?;
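+        // Head: dense projection (with activation) -> layer norm -> linear classifier.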
+        let hidden_states = self.norm.forward(&hidden_states, None)?;
+        let hidden_states = self.classifier.forward(&hidden_states)?;
+
+        let hidden_states = hidden_states.squeeze(1)?;
+
+        Ok(hidden_states)
+    }
+}
+
+pub struct ModernBertModel {
+    embeddings: ModernBertEmbeddings,
+    encoder: ModernBertEncoder,
+    final_norm: LayerNorm,
+    pool: Pool,
+    classifier: Option<Box<dyn ClassificationHead + Send>>,
+
+    local_attention: usize,
+    rotary_dim: usize,
+    rotary_cache: HashMap<bool, (Tensor, Tensor)>,
+    pad_token_id: u32,
+    num_attention_heads: usize,
+
+    device: Device,
+    dtype: DType,
+
+    span: tracing::Span,
+}
+
+impl ModernBertModel {
+    pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result<Self> {
+        let (pool, classifier) = match model_type {
+            ModelType::Classifier => {
+                let pool = Pool::Cls;
+
+                let classifier: Box<dyn ClassificationHead + Send> =
+                    Box::new(ModernBertClassificationHead::load(vb.clone(), config)?);
+
+                (pool, Some(classifier))
+            }
+            ModelType::Embedding(pool) => {
+                if pool == Pool::Splade {
+                    candle::bail!("`splade` is not supported for ModernBert")
+                }
+
+                if pool == Pool::LastToken {
+                    candle::bail!("`last_token` is not supported for ModernBert");
+                }
+
+                (pool, None)
+            }
+        };
+
+        let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
+        let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
+        let final_norm = LayerNorm::load(
+            vb.pp("model.final_norm"),
+            config.hidden_size,
+            config.norm_eps as f32,
+        )?;
+
+        let rotary_dim = config.hidden_size / config.num_attention_heads;
+        let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
+
+        // Local-attention layers use a different RoPE base (and cache) than global layers.
+        for use_local_attention in [true, false] {
+            let rope_theta = if use_local_attention {
+                config.local_rope_theta
+            } else {
+                config.global_rope_theta
+            };
+
+            let max_position_embeddings = if use_local_attention {
+                config.max_position_embeddings
+            } else {
+                config.local_attention
+            };
+
+            let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;
+
+            let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
+
+            rotary_cache.insert(use_local_attention, (cos, sin));
+        }
+
+        Ok(Self {
+            embeddings,
+            encoder,
+            final_norm,
+            pool,
+            classifier,
+            local_attention: config.local_attention,
+            rotary_dim,
+            rotary_cache,
+            pad_token_id: config.pad_token_id as u32,
+            num_attention_heads: config.num_attention_heads,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    fn get_global_attention_mask(
+        &self,
+        attention_mask: Option<&Tensor>,
+        input_shape: &Shape,
+        num_attention_heads: usize,
+    ) -> Result<Tensor> {
+        let extended_attention_mask = if let Some(attention_mask) = attention_mask {
+            attention_mask.squeeze(2)?
+        } else {
+            Tensor::ones(input_shape, DType::F32, &self.device)?
+        }
+        .unsqueeze(1)?
+        .unsqueeze(1)?
+        .to_dtype(self.dtype)?;
+
+        let (bs, seq_len) = input_shape.dims2()?;
+        let extended_attention_mask =
+            extended_attention_mask.broadcast_as((bs, num_attention_heads, seq_len, seq_len))?;
+
+        Ok(extended_attention_mask)
+    }
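+    // Local-attention layers may only attend to positions within `local_attention / 2` tokens.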
+    fn get_silding_window_mask(
+        &self,
+        attention_mask: &Tensor,
+        local_attention: usize,
+    ) -> Result<Tensor> {
+        let attention_mask = attention_mask.to_dtype(DType::U8)?;
+        let mask_shape = attention_mask.shape();
+        let (_, _, seq_len, _) = mask_shape.dims4()?;
+
+        let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
+        let rows = rows.broadcast_as((seq_len, seq_len))?;
+
+        let distance = (&rows - &rows.t()?)?.abs()?;
+
+        let window_size = local_attention / 2;
+        let window_mask = distance
+            .le(window_size as i64)?
+            .unsqueeze(0)?
+            .unsqueeze(0)?
+            .broadcast_as(mask_shape)?;
+
+        let zero_tensor = Tensor::zeros_like(&attention_mask)?;
+        let sliding_window_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;
+
+        Ok(sliding_window_mask)
+    }
+
+    fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        let _enter = self.span.enter();
+
+        let batch_size = batch.len();
+        let max_length = batch.max_length as usize;
+
+        let shape = (batch_size, max_length);
+
+        let (input_ids, input_lengths, position_ids, attention_mask) = if batch_size > 1 {
+            let elems = batch_size * max_length;
+
+            let mut input_ids = Vec::with_capacity(elems);
+            let mut position_ids = Vec::with_capacity(elems);
+            let mut attention_mask = Vec::with_capacity(elems);
+            let mut input_lengths = Vec::with_capacity(batch_size);
+
+            let mut masking = false;
+
+            for i in 0..batch_size {
+                let start = batch.cumulative_seq_lengths[i] as usize;
+                let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                let seq_length = (end - start) as u32;
+                input_lengths.push(seq_length as f32);
+
+                for j in start..end {
+                    input_ids.push(batch.input_ids[j]);
+                    position_ids.push(batch.position_ids[j]);
+                    attention_mask.push(1.0_f32);
+                }
+
+                let padding = batch.max_length - seq_length;
+                if padding > 0 {
+                    masking = true;
+                    for _ in 0..padding {
+                        input_ids.push(self.pad_token_id);
+                        position_ids.push(0);
+                        attention_mask.push(0.0_f32);
+                    }
+                }
+            }
+
+            let attention_mask = match masking {
+                true => {
+                    let attention_mask = Tensor::from_vec(
+                        attention_mask,
+                        (batch_size, max_length, 1),
+                        &self.device,
+                    )?
+                    .to_dtype(self.dtype)?;
+
+                    Some(attention_mask)
+                }
+                false => None,
+            };
+
+            (input_ids, input_lengths, position_ids, attention_mask)
+        } else {
+            (
+                batch.input_ids,
+                vec![batch.max_length as f32],
+                batch.position_ids,
+                None,
+            )
+        };
+
+        let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+        let position_ids = Tensor::from_vec(position_ids, batch_size * max_length, &self.device)?;
+        let mut input_lengths =
+            Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
+
+        let global_attention_mask = self
+            .get_global_attention_mask(
+                attention_mask.as_ref(),
+                input_ids.shape(),
+                self.num_attention_heads,
+            )?
+            .to_dtype(self.dtype)?;
+        let silding_attention_mask = self
+            .get_silding_window_mask(&global_attention_mask, self.local_attention)?
+            .to_dtype(self.dtype)?;
+
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN as f64,
+            _ => -65504.0, // f16 minimum value
+        };
+
+        // Convert the {0,1} keep-masks into additive masks: kept positions -> 0, masked -> min_value.
+        let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
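+        // The sliding-window mask gets the same additive treatment for the local-attention layers.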
+        let silding_attention_mask = ((1.0 - silding_attention_mask)? * min_value)?;
+
+        let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
+        for use_local_attention in [true, false] {
+            let (cos, sin) = &self.rotary_cache[&use_local_attention];
+
+            let cos = cos.index_select(&position_ids, 0)?;
+            let sin = sin.index_select(&position_ids, 0)?;
+
+            let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
+            let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
+
+            rotary_cache.insert(use_local_attention, (cos, sin));
+        }
+
+        let hidden_states = self.embeddings.forward(&input_ids)?;
+        let hidden_states = self.encoder.forward(
+            &hidden_states,
+            &global_attention_mask,
+            &silding_attention_mask,
+            &rotary_cache,
+        )?;
+        let outputs = self.final_norm.forward(&hidden_states, None)?;
+
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        let pooled_embeddings = if has_pooling_requests {
+            let pooled_indices_length = batch.pooled_indices.len();
+            let mut outputs = outputs.clone();
+
+            // Only use pooled_indices if at least one member of the batch asks for raw embeddings
+            let pooled_indices = if has_raw_requests {
+                let pooled_indices =
+                    Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;
+
+                // Select values in the batch
+                outputs = outputs.index_select(&pooled_indices, 0)?;
+                Some(pooled_indices)
+            } else {
+                None
+            };
+
+            let pooled_embeddings = match self.pool {
+                // CLS pooling
+                Pool::Cls => outputs.i((.., 0))?,
+                // Last token and splade pooling are not supported for this model
+                Pool::LastToken | Pool::Splade => unreachable!(),
+                // Mean pooling
+                Pool::Mean => {
+                    if let Some(ref attention_mask) = attention_mask {
+                        let mut attention_mask = attention_mask.clone();
+
+                        if let Some(pooled_indices) = pooled_indices {
+                            // Select values in the batch
+                            attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
+                            input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
+                        };
+
+                        // Mask padded values
+                        outputs = outputs.broadcast_mul(&attention_mask)?;
+                    }
+
+                    (outputs.sum(1)?.broadcast_div(&input_lengths))?
+                }
+            };
+            Some(pooled_embeddings)
+        } else {
+            None
+        };
+
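+        // Raw (per-token) embeddings: flatten to (batch * seq, hidden) and drop padding tokens when needed.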
+        let raw_embeddings = if has_raw_requests {
+            // Reshape outputs
+            let (b, l, h) = outputs.shape().dims3()?;
+            let outputs = outputs.reshape((b * l, h))?;
+
+            // We need to remove the padding tokens only if batch_size > 1 and some members of
+            // the batch require pooling, or if batch_size > 1 and the members of the batch
+            // have different lengths
+            if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
+                let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);
+
+                for i in batch.raw_indices.into_iter() {
+                    let start = i * batch.max_length;
+                    let i = i as usize;
+                    let length =
+                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
+
+                    for j in start..start + length {
+                        // Add indices for the tokens of this specific member of the batch
+                        final_indices.push(j);
+                    }
+                }
+
+                let final_indices_length = final_indices.len();
+                let final_indices =
+                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;
+
+                // Select the tokens with final indices
+                Some(outputs.index_select(&final_indices, 0)?)
+            } else {
+                Some(outputs)
+            }
+        } else {
+            None
+        };
+
+        Ok((pooled_embeddings, raw_embeddings))
+    }
+}
+
+impl Model for ModernBertModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
+
+    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        self.forward(batch)
+    }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
+}