feature: flashmodernbert
kozistr committed Dec 25, 2024
1 parent 9de3025 commit d29dec6
Showing 1 changed file with 0 additions and 55 deletions.
55 changes: 0 additions & 55 deletions backends/candle/src/models/flash_modernbert.rs
@@ -240,11 +240,8 @@ pub struct FlashModernBertModel {
pool: Pool,
classifier: Option<Box<dyn ClassificationHead + Send>>,

local_attention: usize,
rotary_dim: usize,
rotary_cache: HashMap<bool, (Tensor, Tensor)>,
pad_token_id: u32,
num_attention_heads: usize,

device: Device,
dtype: DType,
@@ -322,66 +319,14 @@ impl FlashModernBertModel {
final_norm,
pool,
classifier,
local_attention: config.local_attention,
rotary_dim,
rotary_cache,
pad_token_id: config.pad_token_id as u32,
num_attention_heads: config.num_attention_heads,
device: vb.device().clone(),
dtype: vb.dtype(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}

fn get_global_attention_mask(
&self,
attention_mask: Option<&Tensor>,
input_shape: &Shape,
num_attention_heads: usize,
) -> Result<Tensor> {
let extended_attention_mask = if let Some(attention_mask) = attention_mask {
attention_mask.squeeze(2)?
} else {
Tensor::ones(input_shape, DType::F32, &self.device)?
}
.unsqueeze(1)?
.unsqueeze(1)?
.to_dtype(self.dtype)?;

let (bs, seq_len) = input_shape.dims2()?;
let extended_attention_mask =
extended_attention_mask.broadcast_as((bs, num_attention_heads, seq_len, seq_len))?;

Ok(extended_attention_mask)
}

fn get_silding_window_mask(
&self,
attention_mask: &Tensor,
local_attention: usize,
) -> Result<Tensor> {
let attention_mask = attention_mask.to_dtype(DType::U8)?;
let mask_shape = attention_mask.shape();
let (_, _, seq_len, _) = mask_shape.dims4()?;

let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
let rows = rows.broadcast_as((seq_len, seq_len))?;

let distance = (&rows - &rows.t()?)?.abs()?;

let window_size = local_attention / 2;
let window_mask = distance
.le(window_size as i64)?
.unsqueeze(0)?
.unsqueeze(0)?
.broadcast_as(mask_shape)?;

let zero_tensor = Tensor::zeros_like(&attention_mask)?;
let sliding_window_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;

Ok(sliding_window_mask)
}

pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
let _enter = self.span.enter();

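For reference, the removed get_silding_window_mask helper restricts attention to token pairs within a fixed distance of each other. Below is a minimal standalone sketch of that sliding-window masking step, assuming the same candle Tensor API used in the diff above; the sliding_window_mask function name is hypothetical and not part of this commit.

use candle::{Device, Result, Tensor};

// Hypothetical standalone helper mirroring the removed masking logic.
fn sliding_window_mask(seq_len: usize, local_attention: usize, device: &Device) -> Result<Tensor> {
    // Position indices broadcast to a (seq_len, seq_len) grid.
    let rows = Tensor::arange(0, seq_len as i64, device)?.unsqueeze(0)?;
    let rows = rows.broadcast_as((seq_len, seq_len))?;

    // Absolute distance |i - j| between query position i and key position j.
    let distance = (&rows - &rows.t()?)?.abs()?;

    // Token pairs within half the local window may attend; `le` yields a 0/1 (u8) mask
    // that the caller can broadcast and combine with the padding attention mask.
    let window_size = local_attention / 2;
    distance.le(window_size as i64)
}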
