feature: flashmodernbert
kozistr committed Dec 25, 2024
1 parent eb5932c commit 3253fc7
Showing 11 changed files with 50 additions and 37 deletions.
11 changes: 9 additions & 2 deletions backends/candle/src/flash_attn.rs
@@ -32,14 +32,15 @@ pub(crate) fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor, candle::Error> {
let runtime_compute_cap = get_runtime_compute_cap();

if runtime_compute_cap == 75 {
if alibi_slopes.is_some() {
candle::bail!("Flash attention v1 does not support alibi");
}
if window_size_left.is_some() {
if window_size_left.is_some() | window_size_right.is_some() {
candle::bail!("Flash attention v1 does not support attention windowing");
}

@@ -65,7 +66,13 @@ pub(crate) fn flash_attn_varlen(
{
use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};

let window_size_right = if causal { Some(0) } else { None };
let window_size_right = if causal {
Some(0)
} else if window_size_right.is_some() {
window_size_right
} else {
None
};

let attention = if let Some(alibi_slopes) = alibi_slopes {
flash_attn_varlen_alibi_windowed(
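The new `window_size_right` argument lets a caller request a symmetric local window instead of the previous left-only behaviour; for causal models the right side is still forced to zero. A minimal sketch of the resolution rule above (the helper name is illustrative, not part of the crate):

```rust
/// Mirrors the window resolution in `flash_attn_varlen`: causal attention
/// forces the right window to 0, otherwise the caller's value is forwarded.
fn resolve_window_right(causal: bool, window_size_right: Option<usize>) -> Option<usize> {
    if causal {
        Some(0)
    } else {
        window_size_right
    }
}

fn main() {
    // Symmetric local window, as used by ModernBERT's local layers.
    assert_eq!(resolve_window_right(false, Some(64)), Some(64));
    // Causal models keep "no future tokens" regardless of the requested window.
    assert_eq!(resolve_window_right(true, Some(64)), Some(0));
    // Fully global attention: no windowing on either side.
    assert_eq!(resolve_window_right(false, None), None);
}
```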
9 changes: 7 additions & 2 deletions backends/candle/src/layers/layer_norm.rs
@@ -72,12 +72,17 @@ impl LayerNorm {
&hidden_states,
&residual,
&self.weight,
&self.bias,
self.bias.as_ref(),
self.epsilon,
)?;
Ok(result)
} else {
layer_norm(&hidden_states, &self.weight, &self.bias, self.epsilon)
layer_norm(
&hidden_states,
&self.weight,
self.bias.as_ref(),
self.epsilon,
)
}?;
result.reshape(original_shape)
}
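`LayerNorm.bias` is now optional (ModernBERT configures its norms without a bias), so both the fused path and the plain `layer_norm` call receive `self.bias.as_ref()`. A rough standalone sketch of layer normalization with an optional bias, using plain slices rather than the candle kernels used above:

```rust
/// Layer norm over one hidden vector with an optional bias term;
/// a simplified stand-in for the fused kernels used by the backend.
fn layer_norm(hidden: &[f32], weight: &[f32], bias: Option<&[f32]>, eps: f32) -> Vec<f32> {
    let n = hidden.len() as f32;
    let mean = hidden.iter().sum::<f32>() / n;
    let var = hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();

    hidden
        .iter()
        .zip(weight)
        .enumerate()
        .map(|(i, (x, w))| {
            let b = bias.map_or(0.0, |b| b[i]);
            (x - mean) * inv_std * w + b
        })
        .collect()
}

fn main() {
    let hidden = [1.0_f32, 2.0, 3.0, 4.0];
    let weight = [1.0_f32; 4];
    // The bias-free case used by ModernBERT: the bias term simply drops out.
    println!("{:?}", layer_norm(&hidden, &weight, None, 1e-5));
}
```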
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_distilbert.rs
@@ -85,6 +85,7 @@ impl DistilBertAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_gte.rs
@@ -87,6 +87,7 @@ impl GTEAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_jina.rs
@@ -106,6 +106,7 @@ impl JinaAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_jina_code.rs
@@ -142,6 +142,7 @@ impl JinaCodeAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_mistral.rs
@@ -105,6 +105,7 @@ impl MistralAttention {
self.softmax_scale,
true,
self.window_size_left,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

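Mistral keeps its causal sliding window: `window_size_left` stays as configured and the new right argument is `None`, which the causal branch in `flash_attn.rs` resolves to `Some(0)`. For reference, the usual flash-attention window semantics (an illustrative predicate, not code from this crate): a query at position `q` may attend to key position `k` when `k` lies within `[q - left, q + right]`, with `None` meaning unbounded on that side.

```rust
/// Whether query position `q` may attend to key position `k` under a
/// (left, right) attention window; `None` means unbounded on that side.
fn can_attend(q: i64, k: i64, left: Option<i64>, right: Option<i64>) -> bool {
    let left_ok = left.map_or(true, |w| k >= q - w);
    let right_ok = right.map_or(true, |w| k <= q + w);
    left_ok && right_ok
}

fn main() {
    // Causal sliding window (Mistral): left = 4096, right = 0.
    assert!(can_attend(5000, 2000, Some(4096), Some(0)));
    assert!(!can_attend(5000, 5001, Some(4096), Some(0))); // no future tokens
    // Symmetric local window (ModernBERT's local layers): left = right = 64.
    assert!(can_attend(100, 160, Some(64), Some(64)));
    assert!(!can_attend(100, 165, Some(64), Some(64)));
}
```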
58 changes: 26 additions & 32 deletions backends/candle/src/models/flash_modernbert.rs
@@ -1,11 +1,13 @@
use std::collections::HashMap;

use crate::flash_attn::flash_attn_varlen;
use crate::layers::{LayerNorm, Linear};
use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::models::modernbert::{
ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
ModernBertMLP,
};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

@@ -15,13 +17,15 @@ struct ModernBertAttention {

num_attention_heads: usize,
attention_head_size: usize,
softmax_scale: f64,
softmax_scale: f32,
local_attention: usize,
use_local_attention: bool,

span: tracing::Span,
}

impl ModernBertAttention {
pub fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let hidden_size = config.hidden_size;

@@ -43,14 +47,18 @@ impl ModernBertAttention {
};
let wo = Linear::new(wo_weight, wo_bias, None);

let softmax_scale = 1. / (attention_head_size as f64).sqrt();
let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;

let use_local_attention = index % config.global_attn_every_n_layers != 0;

Ok(Self {
wqkv,
wo,
num_attention_heads: config.num_attention_heads,
attention_head_size,
softmax_scale,
local_attention: config.local_attention / 2 as usize,
use_local_attention,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
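ModernBERT alternates attention types by depth: every `global_attn_every_n_layers`-th layer attends globally, while all other layers use a sliding window of `local_attention` tokens, stored above as a half-width so it can be passed directly as the left and right window sizes. A small sketch of that per-layer selection (the helper and the example config values are illustrative, loosely based on ModernBERT-base):

```rust
/// Per-layer window, mirroring `ModernBertAttention::load`: layers 0, n, 2n, ...
/// are global (no window), every other layer gets half the configured window
/// on each side.
fn layer_window(index: usize, global_attn_every_n_layers: usize, local_attention: usize) -> Option<usize> {
    let use_local_attention = index % global_attn_every_n_layers != 0;
    use_local_attention.then_some(local_attention / 2)
}

fn main() {
    // ModernBERT-base style settings: global attention every 3 layers, 128-token window.
    for index in 0..6 {
        println!("layer {index}: window = {:?}", layer_window(index, 3, 128));
    }
    // Layers 0 and 3 print None (global); the rest print Some(64) (local).
}
```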
@@ -81,6 +89,12 @@ impl ModernBertAttention {
let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;

let attention_size = if self.use_local_attention {
Some(self.local_attention)
} else {
None
};

let attention = flash_attn_varlen(
&query_layer,
&key_layer,
@@ -92,7 +106,8 @@ impl ModernBertAttention {
max_s,
self.softmax_scale,
false,
self.local_attention,
attention_size,
attention_size,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

@@ -123,7 +138,7 @@ impl ModernBertEncoderLayer {
None
};

let attn = ModernBertAttention::load(vb.pp("attn"), config)?;
let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;

let mlp_norm = LayerNorm::load(
vb.pp("mlp_norm"),
@@ -238,7 +253,7 @@ pub struct FlashModernBertModel {
}

impl FlashModernBertModel {
pub fn load(vb: VarBuilder, config: &BertConfig, model_type: ModelType) -> Result<Self> {
pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result<Self> {
match vb.device() {
Device::Cuda(_) => {}
_ => candle::bail!("FlashModernBert requires Cuda"),
@@ -388,8 +403,8 @@ impl FlashModernBertModel {
let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;

let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
let cos = cos.reshape((batch_size, 1, batch.max_length, self.rotary_dim))?;
let sin = sin.reshape((batch_size, 1, batch.max_length, self.rotary_dim))?;

rotary_cache.insert(use_local_attention, (cos, sin));
}
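The cos/sin tables are cached per attention type (`rotary_cache` is keyed by `use_local_attention`) because global and local layers use different RoPE bases in the ModernBERT config (`global_rope_theta` vs. `local_rope_theta`). A rough sketch of building such a cache with plain `f32` math, standing in for the `get_inv_freqs`/`get_cos_sin` helpers imported above; the theta values are assumptions taken from ModernBERT-base:

```rust
use std::collections::HashMap;

/// Inverse frequencies for rotary embeddings with base `theta`.
fn inv_freqs(rotary_dim: usize, theta: f32) -> Vec<f32> {
    (0..rotary_dim)
        .step_by(2)
        .map(|i| 1.0 / theta.powf(i as f32 / rotary_dim as f32))
        .collect()
}

/// (cos, sin) tables of shape [max_length][rotary_dim / 2] for one RoPE base.
fn cos_sin(max_length: usize, inv: &[f32]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let cos: Vec<Vec<f32>> = (0..max_length)
        .map(|p| inv.iter().map(|f| (p as f32 * f).cos()).collect())
        .collect();
    let sin: Vec<Vec<f32>> = (0..max_length)
        .map(|p| inv.iter().map(|f| (p as f32 * f).sin()).collect())
        .collect();
    (cos, sin)
}

fn main() {
    let (rotary_dim, max_length) = (64, 512);
    let mut rotary_cache = HashMap::new();
    // Global layers: larger base; local layers: standard base (assumed values).
    rotary_cache.insert(false, cos_sin(max_length, &inv_freqs(rotary_dim, 160_000.0)));
    rotary_cache.insert(true, cos_sin(max_length, &inv_freqs(rotary_dim, 10_000.0)));
    let (cos, _sin) = &rotary_cache[&true];
    println!("local cos table: {} x {}", cos.len(), cos[0].len());
}
```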
@@ -465,28 +480,7 @@ impl FlashModernBertModel {
}
}
Pool::Splade => {
let splade_head = self.splade.as_ref().unwrap();
let relu_log = splade_head.forward(&outputs)?;

if batch_size > 1 {
let results: Result<Vec<Tensor>> = batch
.pooled_indices
.into_iter()
.map(|i| {
let i = i as usize;
let start = batch.cumulative_seq_lengths[i];
let len = batch.cumulative_seq_lengths[i + 1] - start;

relu_log
.narrow(0, start as usize, len as usize)?
.max_keepdim(0)
})
.collect();

Some(Tensor::cat(&results?, 0)?)
} else {
Some(relu_log.max_keepdim(0)?)
}
unreachable!();
}
}
} else {
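The Splade pooling arm collapses to `unreachable!()`: the assumption is that a `splade` pool is rejected before a forward pass can ever run (the guard itself is not shown in this diff). A self-contained sketch of that reject-at-load, unreachable-at-forward pattern, with hypothetical stand-in types:

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum Pool {
    Cls,
    Mean,
    Splade,
}

struct Model {
    pool: Pool,
}

impl Model {
    /// Reject unsupported pooling up front so `forward` never has to handle it.
    fn load(pool: Pool) -> Result<Self, String> {
        if pool == Pool::Splade {
            return Err("`splade` pooling is not supported for this model".into());
        }
        Ok(Self { pool })
    }

    fn forward(&self) -> &'static str {
        match self.pool {
            Pool::Cls => "cls pooling",
            Pool::Mean => "mean pooling",
            // Guaranteed impossible by the check in `load`.
            Pool::Splade => unreachable!(),
        }
    }
}

fn main() {
    assert!(Model::load(Pool::Splade).is_err());
    println!("{}", Model::load(Pool::Cls).unwrap().forward());
    println!("{}", Model::load(Pool::Mean).unwrap().forward());
}
```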
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_nomic.rs
@@ -83,6 +83,7 @@ impl NomicAttention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(D::Minus2)?;

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_qwen2.rs
@@ -113,6 +113,7 @@ impl Qwen2Attention {
self.softmax_scale,
false,
None,
None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

2 changes: 1 addition & 1 deletion backends/candle/src/models/modernbert.rs
@@ -77,7 +77,7 @@ impl ModernBertEmbeddings {
}
}

struct ModernBertMLP {
pub struct ModernBertMLP {
wi: Linear,
wo: Linear,
activation: Option<HiddenAct>,
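`ModernBertMLP` becomes `pub` so the flash model can reuse it (it now appears in the `flash_modernbert.rs` imports above). For reference, ModernBERT's MLP is a gated unit: `wi` projects to twice the intermediate size, the result is split in half, the activated first half gates the second, and `wo` projects back to the hidden size. A rough scalar sketch of that data flow, with a ReLU standing in for the configured activation:

```rust
/// Gated MLP over a single token: `wi` projects to 2 * intermediate, the two
/// halves combine as act(input) * gate, then `wo` projects back to hidden.
fn gated_mlp(
    x: &[f32],       // [hidden]
    wi: &[Vec<f32>], // [2 * intermediate][hidden]
    wo: &[Vec<f32>], // [hidden][intermediate]
    act: impl Fn(f32) -> f32,
) -> Vec<f32> {
    let proj: Vec<f32> = wi
        .iter()
        .map(|row| row.iter().zip(x).map(|(w, x)| w * x).sum())
        .collect();
    let (input, gate) = proj.split_at(proj.len() / 2);
    let gated: Vec<f32> = input.iter().zip(gate).map(|(i, g)| act(*i) * g).collect();
    wo.iter()
        .map(|row| row.iter().zip(&gated).map(|(w, h)| w * h).sum())
        .collect()
}

fn main() {
    // Tiny shapes for illustration: hidden = 2, intermediate = 3.
    let x = vec![1.0_f32, 2.0];
    let wi = vec![vec![0.1, 0.2]; 6]; // 2 * intermediate rows
    let wo = vec![vec![0.5, 0.5, 0.5]; 2];
    println!("{:?}", gated_mlp(&x, &wi, &wo, |v| v.max(0.0)));
}
```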
