Skip to content

Commit

Permalink
feature: flashmodernbert
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Dec 25, 2024
1 parent 3253fc7 commit 9de3025
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ impl BertAttention {
self.softmax_scale,
false,
None,
+ None,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

Expand Down
7 changes: 4 additions & 3 deletions backends/candle/src/models/flash_modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ impl FlashModernBertModel {

let batch_size = batch.len();
let shape = batch.input_ids.len();
+ let max_length = batch.max_length as usize;

let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
Expand All @@ -403,8 +404,8 @@ impl FlashModernBertModel {
let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;

- let cos = cos.reshape((batch_size, 1, batch.max_length, self.rotary_dim))?;
- let sin = sin.reshape((batch_size, 1, batch.max_length, self.rotary_dim))?;
+ let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
+ let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;

rotary_cache.insert(use_local_attention, (cos, sin));
}
Expand All @@ -414,7 +415,7 @@ impl FlashModernBertModel {
&hidden_states,
&cu_seqlens,
&rotary_cache,
- batch.max_length as usize,
+ max_length,
)?;
let outputs = self.final_norm.forward(&hidden_states, None)?;

Expand Down
2 changes: 1 addition & 1 deletion backends/candle/src/models/modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ impl ModernBertModel {
} else {
(
batch.input_ids,
- vec![batch.max_length as f32],
+ vec![max_length as f32],
batch.position_ids,
None,
)
Expand Down

0 comments on commit 9de3025

Please sign in to comment.