feature: flashmodernbert
kozistr committed Dec 25, 2024
1 parent ddac414 commit 8cc7120
Showing 1 changed file with 120 additions and 71 deletions.
191 changes: 120 additions & 71 deletions backends/candle/src/models/flash_modernbert.rs
@@ -2,6 +2,7 @@ use crate::flash_attn::flash_attn_varlen;
use crate::layers::{LayerNorm, Linear};
use crate::models::modernbert::{
ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
ModernBertMLP,
};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};
@@ -12,10 +13,6 @@ struct ModernBertAttention {
wqkv: Linear,
wo: Linear,

local_attention: (i64, i64),
cos: Tensor,
sin: Tensor,

num_attention_heads: usize,
attention_head_size: usize,
softmax_scale: f64,
@@ -25,37 +22,45 @@ struct ModernBertAttention {

impl ModernBertAttention {
pub fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
let wi_weight = vb
.pp("Wi")
.get((config.hidden_size, config.intermediate_size * 2), "weight")?;
let wi_bias = vb
.pp("Wi")
.get((config.intermediate_size * 2,), "bias")
.ok();
let wi = Linear::new(wi_weight, wi_bias, None);

let wo_weight = vb
.pp("Wo")
.get((config.intermediate_size * 2, config.hidden_size), "weight")?;
let wo_bias = vb.pp("Wo").get((config.hidden_size,), "bias").ok();
let attention_head_size = config.hidden_size / config.num_attention_heads;
let hidden_size = config.hidden_size;

let wqkv_weight = vb
.pp("Wqkv")
.get((hidden_size * 3, hidden_size), "weight")?;
let wqkv_bias = if config.attention_bias {
vb.pp("Wqkv").get(hidden_size * 3, "bias").ok()
} else {
None
};
let wqkv: Linear = Linear::new(wqkv_weight, wqkv_bias, None);

let wo_weight = vb.pp("Wo").get((hidden_size, hidden_size), "weight")?;
let wo_bias = if config.attention_bias {
vb.pp("Wo").get(hidden_size, "bias").ok()
} else {
None
};
let wo = Linear::new(wo_weight, wo_bias, None);

let activation = Some(config.hidden_activation.clone());
let softmax_scale = 1. / (attention_head_size as f64).sqrt();

Ok(Self {
wi,
wqkv,
wo,
activation,
intermediate_size: config.intermediate_size,
span: tracing::span!(tracing::Level::TRACE, "mlp"),
num_attention_heads: config.num_attention_heads,
attention_head_size,
softmax_scale,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}

pub fn forward(
&self,
hidden_states: &Tensor,
cu_seqlens: &Tensor,
cos: &Tensor,
sin: &Tensor,
max_s: usize,
) -> Result<Tensor> {
let _enter = self.span.enter();
@@ -73,9 +78,8 @@ impl ModernBertAttention {
let key_layer = &qkv[1].contiguous()?;
let value_layer = &qkv[2];

let query_layer =
apply_rotary(query_layer, &self.cos, &self.sin, self.attention_head_size)?;
let key_layer = apply_rotary(key_layer, &self.cos, &self.sin, self.attention_head_size)?;
let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;

let attention = flash_attn_varlen(
&query_layer,
Expand All @@ -88,8 +92,7 @@ impl ModernBertAttention {
max_s,
self.softmax_scale,
false,
self.local_attention[0],
self.local_attention[1],
self.local_attention,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;
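
A minimal, framework-free sketch of the scalar parameters this attention path depends on: the per-head size and softmax scale computed in load above, plus a symmetric (left, right) window pair, assuming the sliding window is split evenly around the query token as the half-window computation in get_silding_window_mask further down suggests. The helper name and example sizes are illustrative, not part of the diff.

fn attention_params(
    hidden_size: usize,
    num_attention_heads: usize,
    local_attention: usize,
) -> (usize, f64, (i64, i64)) {
    // Per-head dimensionality, e.g. 768 / 12 = 64.
    let attention_head_size = hidden_size / num_attention_heads;
    // Scale applied inside flash attention: 1 / sqrt(head_size).
    let softmax_scale = 1.0 / (attention_head_size as f64).sqrt();
    // A window of `local_attention` tokens centred on the query becomes a
    // symmetric (left, right) pair of half-window sizes.
    let half = (local_attention / 2) as i64;
    (attention_head_size, softmax_scale, (half, half))
}

fn main() {
    let (head_size, scale, window) = attention_params(768, 12, 128);
    println!("head_size={head_size} scale={scale} window={window:?}");
}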

@@ -110,7 +113,7 @@ struct ModernBertEncoderLayer {

impl ModernBertEncoderLayer {
pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
let attn_norm = if index > 0 {
let attn_norm = if index != 0 {
Some(LayerNorm::load(
vb.pp("attn_norm"),
config.hidden_size,
@@ -120,7 +123,7 @@ impl ModernBertEncoderLayer {
None
};

let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;
let attn = ModernBertAttention::load(vb.pp("attn"), config)?;

let mlp_norm = LayerNorm::load(
vb.pp("mlp_norm"),
@@ -143,30 +146,38 @@ impl ModernBertEncoderLayer {
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
silding_attention_mask: &Tensor,
cu_seqlens: &Tensor,
cos: &Tensor,
sin: &Tensor,
max_s: usize,
) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.clone();
let residual = hidden_states.clone();

if let Some(attn_norm) = &self.attn_norm {
hidden_states = attn_norm.forward(&hidden_states, None)?;
}
let attn_norm = if let Some(attn_norm) = &self.attn_norm {
attn_norm.forward(hidden_states, None)?
} else {
hidden_states.clone()
};

let attn_outputs = self.attn.forward(&attn_norm, cu_seqlens, cos, sin, max_s)?;

let hidden_states = residual.add(&attn_outputs)?;

let hidden_states =
self.attn
.forward(&hidden_states, attention_mask, silding_attention_mask)?;
let mlp_output = self
.mlp
.forward(&self.mlp_norm.forward(&hidden_states, None)?)?;

hidden_states.broadcast_add(&mlp_output)
hidden_states.add(&mlp_output)
}
}

struct ModernBertEncoder {
layers: Vec<ModernBertEncoderLayer>,

global_attn_every_n_layers: usize,

span: tracing::Span,
}

@@ -178,22 +189,29 @@ impl ModernBertEncoder {

let span = tracing::span!(tracing::Level::TRACE, "encoder");

Ok(ModernBertEncoder { layers, span })
Ok(ModernBertEncoder {
layers,
global_attn_every_n_layers: config.global_attn_every_n_layers,
span,
})
}

fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
silding_attention_mask: &Tensor,
cu_seqlens: &Tensor,
rotary_cache: &HashMap<bool, (Tensor, Tensor)>,
max_s: usize,
) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.clone();

for layer in self.layers.iter() {
hidden_states =
layer.forward(&hidden_states, attention_mask, silding_attention_mask)?;
for (index, layer) in self.layers.iter().enumerate() {
let use_local_attention = index % self.global_attn_every_n_layers != 0;
let (cos, sin) = &rotary_cache[&use_local_attention];

hidden_states = layer.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
}

Ok(hidden_states)
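
For orientation, a small sketch of the alternation the loop above encodes: layers whose index is a multiple of global_attn_every_n_layers (including layer 0) attend globally, every other layer uses the local sliding window. Plain Rust, runnable on its own; the value 3 below is only an example, not taken from the diff.

fn uses_local_attention(layer_index: usize, global_attn_every_n_layers: usize) -> bool {
    layer_index % global_attn_every_n_layers != 0
}

fn main() {
    let every_n = 3; // example value only
    let pattern: Vec<bool> = (0..9)
        .map(|i| uses_local_attention(i, every_n))
        .collect();
    // Layers 0, 3 and 6 come out false (global); the rest true (local).
    println!("{pattern:?}");
}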
@@ -208,6 +226,10 @@ pub struct FlashModernBertModel {
classifier: Option<Box<dyn ClassificationHead + Send>>,

local_attention: usize,
rotary_dim: usize,
rotary_cache: HashMap<bool, (Tensor, Tensor)>,
pad_token_id: u32,
num_attention_heads: usize,

device: Device,
dtype: DType,
@@ -251,18 +273,45 @@ impl FlashModernBertModel {
let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
let final_norm = LayerNorm::load(
vb.pp("final_norm"),
vb.pp("model.final_norm"),
config.hidden_size,
config.norm_eps as f32,
)?;

let rotary_dim = config.hidden_size / config.num_attention_heads;
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();

for use_local_attention in [true, false] {
let rope_theta = if use_local_attention {
config.local_rope_theta
} else {
config.global_rope_theta
};

let max_position_embeddings = if use_local_attention {
config.max_position_embeddings
} else {
config.local_attention
};

let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

rotary_cache.insert(use_local_attention, (cos, sin));
}
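
For reference, the math the two cached tables hold, written out in plain Rust without candle and assuming get_inv_freqs and get_cos_sin follow the standard rotary-embedding recipe: one inverse frequency per pair of rotary dimensions, then per-position cosines and sines of position times frequency. A sketch only; values and function bodies are illustrative.

fn inv_freqs(rotary_dim: usize, rope_theta: f32) -> Vec<f32> {
    // One frequency per pair of rotary dimensions: 1 / theta^(2i / rotary_dim).
    (0..rotary_dim)
        .step_by(2)
        .map(|i| 1.0 / rope_theta.powf(i as f32 / rotary_dim as f32))
        .collect()
}

fn cos_sin(max_positions: usize, inv_freqs: &[f32]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut cos = Vec::with_capacity(max_positions);
    let mut sin = Vec::with_capacity(max_positions);
    for pos in 0..max_positions {
        let angles: Vec<f32> = inv_freqs.iter().map(|f| pos as f32 * f).collect();
        cos.push(angles.iter().map(|a| a.cos()).collect());
        sin.push(angles.iter().map(|a| a.sin()).collect());
    }
    (cos, sin)
}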

Ok(Self {
embeddings,
encoder,
final_norm,
pool,
classifier,
local_attention: config.local_attention,
rotary_dim,
rotary_cache,
pad_token_id: config.pad_token_id as u32,
num_attention_heads: config.num_attention_heads,
device: vb.device().clone(),
dtype: vb.dtype(),
span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -273,6 +322,7 @@ impl FlashModernBertModel {
&self,
attention_mask: Option<&Tensor>,
input_shape: &Shape,
num_attention_heads: usize,
) -> Result<Tensor> {
let extended_attention_mask = if let Some(attention_mask) = attention_mask {
attention_mask.squeeze(2)?
@@ -283,16 +333,9 @@ impl FlashModernBertModel {
.unsqueeze(1)?
.to_dtype(self.dtype)?;

let min_value = match self.dtype {
DType::F32 => f32::MIN as f64,
_ => -65504.0_f64, // f16 minumum value
};

let extended_attention_mask = ((1.0 - extended_attention_mask)? * min_value)?;

let (bs, seq_len) = input_shape.dims2()?;
let extended_attention_mask =
extended_attention_mask.broadcast_as((bs, 1, seq_len, seq_len))?;
extended_attention_mask.broadcast_as((bs, num_attention_heads, seq_len, seq_len))?;

Ok(extended_attention_mask)
}
@@ -302,28 +345,24 @@ impl FlashModernBertModel {
attention_mask: &Tensor,
local_attention: usize,
) -> Result<Tensor> {
let attention_mask = attention_mask.to_dtype(DType::U8)?;
let mask_shape = attention_mask.shape();
let (_, _, seq_len, _) = mask_shape.dims4()?;

let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
let rows = rows.broadcast_as((seq_len, seq_len))?;

let distance = (&rows - &rows.t()?)?.abs()?;

let window_size = local_attention / 2;
let window_mask = distance
.le(window_size as i64)?
.unsqueeze(0)?
.unsqueeze(0)?;

let dtype = attention_mask.dtype();
let min_value = match dtype {
DType::F32 => f32::MIN as f64,
_ => -65504.0, // f16 minimum value
};
.unsqueeze(0)?
.broadcast_as(mask_shape)?;

let inverted_window_mask = window_mask.eq(0_i64)?;
let min_value_tensor = Tensor::full(min_value, mask_shape, attention_mask.device())?;
let sliding_window_mask =
attention_mask.where_cond(&inverted_window_mask, &min_value_tensor)?;
let zero_tensor = Tensor::zeros_like(&attention_mask)?;
let sliding_window_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;

Ok(sliding_window_mask)
}
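
A plain-Rust sketch of the visibility rule the mask above implements: position j is visible from position i only when |i - j| <= local_attention / 2. The boolean matrix below stands in for the tensor version built with arange, le and where_cond; sizes in main are illustrative.

fn sliding_window_mask(seq_len: usize, local_attention: usize) -> Vec<Vec<bool>> {
    let half = (local_attention / 2) as i64;
    (0..seq_len as i64)
        .map(|i| (0..seq_len as i64).map(|j| (i - j).abs() <= half).collect())
        .collect()
}

fn main() {
    // With a window of 4 tokens, token 0 can see tokens 0..=2 but not token 3.
    let mask = sliding_window_mask(6, 4);
    assert!(mask[0][2]);
    assert!(!mask[0][3]);
}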
@@ -335,22 +374,32 @@ impl FlashModernBertModel {
let shape = batch.input_ids.len();

let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
let cu_seqlens = Tensor::from_vec(
batch.cumulative_seq_lengths.clone(),
batch_size + 1,
&self.device,
)?;
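
As a side note on the packed layout that flash_attn_varlen consumes: the batch's sequences are concatenated along a single token axis, and cumulative_seq_lengths (batch_size + 1 entries, as above) marks where each sequence starts and ends. A small plain-Rust sketch; the example lengths are not from the diff.

fn cu_seqlens(seq_lengths: &[u32]) -> Vec<u32> {
    let mut out = Vec::with_capacity(seq_lengths.len() + 1);
    let mut total = 0u32;
    out.push(0);
    for &len in seq_lengths {
        total += len;
        out.push(total);
    }
    out
}

fn main() {
    // Three packed sequences of lengths 5, 3 and 7 give boundaries [0, 5, 8, 15];
    // tokens 5..8 belong to the second sequence.
    assert_eq!(cu_seqlens(&[5, 3, 7]), vec![0, 5, 8, 15]);
}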

let global_attention_mask =
self.get_global_attention_mask(attention_mask.as_ref(), input_ids.shape())?;
let silding_attention_mask =
self.get_silding_window_mask(&global_attention_mask, self.local_attention)?;
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
for use_local_attention in [true, false] {
let (cos, sin) = &self.rotary_cache[&use_local_attention];

let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;

let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;

rotary_cache.insert(use_local_attention, (cos, sin));
}

let hidden_states = self.embeddings.forward(&input_ids)?;
let hidden_states = self.encoder.forward(
&hidden_states,
&global_attention_mask,
&silding_attention_mask,
&cu_seqlens,
&rotary_cache,
batch.max_length as usize,
)?;
let outputs = self.final_norm.forward(&hidden_states, None)?;
