
Commit ad25832
fix: modernbert
kozistr committed Dec 25, 2024
1 parent 3b0701e commit ad25832
Showing 9 changed files with 3,237 additions and 43 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -24,7 +24,7 @@ hf-hub = { version = "0.3.2", features = ["tokio", "online"], default-features =
 metrics = "0.23"
 nohash-hasher = "0.2"
 num_cpus = "1.16.0"
-tokenizers = { version = "0.19.1", default-features = false, features = ["onig", "esaxx_fast"] }
+tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "esaxx_fast"] }
 tokio = { version = "1.25", features = ["rt", "rt-multi-thread", "parking_lot", "sync", "signal"] }
 tracing = "0.1"
 serde = { version = "1.0", features = ["serde_derive"] }

19 changes: 0 additions & 19 deletions backends/candle/src/lib.rs
@@ -383,25 +383,6 @@ impl CandleBackend {
                     FlashQwen2Model::load(vb, &config, model_type).s()?,
                 ))
             }
-            #[cfg(feature = "cuda")]
-            (Config::ModernBert(config), Device::Cuda(_)) => {
-                if cfg!(feature = "flash-attn")
-                    && dtype == DType::F16
-                    && &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        == "true"
-                {
-                    return Err(BackendError::Start(
-                        "ModernBert does not support flash attention".to_string(),
-                    ));
-                }
-
-                tracing::info!("Starting ModernBert model on {:?}", device);
-                Ok(Box::new(
-                    ModernBERTModel::load(vb, &config, model_type).s()?,
-                ))
-            }
         };

         Ok(Self {
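
Note on the removed arm: before this commit, ModernBert on CUDA was dispatched here, but refused to start when flash attention was requested for an F16 model. The runtime part of that gate is an environment-variable check. A minimal standalone sketch of it, where the variable name and "True" default come from the deleted code and the helper name is hypothetical:

use std::env;

/// Mirrors the gate in the deleted match arm: USE_FLASH_ATTENTION
/// defaults to "True" and is compared case-insensitively.
fn flash_attention_requested() -> bool {
    env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "True".to_string())
        .to_lowercase()
        == "true"
}

fn main() {
    if flash_attention_requested() {
        eprintln!("flash attention requested");
    } else {
        eprintln!("flash attention disabled");
    }
}
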
28 changes: 12 additions & 16 deletions backends/candle/src/models/modernbert.rs
@@ -255,7 +255,6 @@ impl ModernBertAttention {
             let attn_weights = query_layer.matmul(&key_layer.t()?)?;
             let attn_weights = (attn_weights * self.softmax_scale)?;
             let attn_weights = attn_weights.add(attention_mask)?;
-
             let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
             attn_weights.matmul(&value_layer.contiguous()?)
         }?;
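
The hunk above is the unfused attention path: QK^T, scale, additive mask, softmax over the last dimension, then the weighted sum against V. A self-contained sketch of the same sequence of candle ops; the function name, toy shapes, zero mask, and 1/sqrt(d) scale are illustrative assumptions:

use candle_core::{DType, Device, Result, Tensor};

/// Unfused scaled dot-product attention over [batch, heads, seq, head_dim].
/// The mask is additive: 0.0 keeps a position, a large negative value hides it.
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, mask: &Tensor, scale: f64) -> Result<Tensor> {
    let weights = (q.matmul(&k.t()?)? * scale)?; // [b, h, s, s]
    let weights = weights.broadcast_add(mask)?;
    let weights = candle_nn::ops::softmax_last_dim(&weights)?;
    weights.matmul(&v.contiguous()?) // [b, h, s, d]
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, s, d) = (1, 2, 4, 8);
    let q = Tensor::randn(0f32, 1f32, (b, h, s, d), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b, h, s, d), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b, h, s, d), &dev)?;
    let mask = Tensor::zeros((b, 1, s, s), DType::F32, &dev)?;
    let out = attention(&q, &k, &v, &mask, 1.0 / (d as f64).sqrt())?;
    println!("{:?}", out.shape()); // [1, 2, 4, 8]
    Ok(())
}
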
@@ -277,7 +276,7 @@ struct ModernBertEncoderLayer

 impl ModernBertEncoderLayer {
     pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
-        let attn_norm = if index > 0 {
+        let attn_norm = if index != 0 {
             Some(LayerNorm::load(
                 vb.pp("attn_norm"),
                 config.hidden_size,
@@ -316,20 +315,23 @@
     ) -> Result<Tensor> {
         let _enter = self.span.enter();

-        let mut hidden_states = hidden_states.clone();
+        let residual = hidden_states.clone();

-        if let Some(attn_norm) = &self.attn_norm {
-            hidden_states = attn_norm.forward(&hidden_states, None)?;
-        }
+        let attn_norm = if let Some(attn_norm) = &self.attn_norm {
+            attn_norm.forward(hidden_states, None)?
+        } else {
+            hidden_states.clone()
+        };
+
+        let attn_outputs = self.attn.forward(&attn_norm, attention_mask, cos, sin)?;
+
+        let hidden_states = residual.add(&attn_outputs)?;

-        let hidden_states = self
-            .attn
-            .forward(&hidden_states, attention_mask, cos, sin)?;
         let mlp_output = self
             .mlp
             .forward(&self.mlp_norm.forward(&hidden_states, None)?)?;

-        hidden_states.broadcast_add(&mlp_output)
+        hidden_states.add(&mlp_output)
     }
 }

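The rewritten forward pass makes both residual connections explicit: keep a residual copy, apply the optional pre-attention norm functionally (layer 0 has no attn_norm, hence the Option), add the attention output back, then add the MLP output. Because both operands of each addition now have the same [batch, seq, hidden] shape, plain add replaces broadcast_add. A generic sketch of that pre-norm residual pattern, with hypothetical names and identity closures standing in for the real sublayers:

use candle_core::{Device, Result, Tensor};

/// out = h + mlp(norm(h)), where h = x + attn(norm(x)).
/// A single `norm` stands in for the model's separate attn_norm/mlp_norm.
fn pre_norm_block(
    x: &Tensor,
    norm: impl Fn(&Tensor) -> Result<Tensor>,
    attn: impl Fn(&Tensor) -> Result<Tensor>,
    mlp: impl Fn(&Tensor) -> Result<Tensor>,
) -> Result<Tensor> {
    let residual = x.clone();
    let hidden = residual.add(&attn(&norm(x)?)?)?; // first residual connection
    hidden.add(&mlp(&norm(&hidden)?)?) // second residual connection
}

fn main() -> Result<()> {
    let x = Tensor::randn(0f32, 1f32, (2, 4, 8), &Device::Cpu)?;
    let id = |t: &Tensor| -> Result<Tensor> { Ok(t.clone()) };
    let out = pre_norm_block(&x, id, id, id)?;
    println!("{:?}", out.shape()); // [2, 4, 8]
    Ok(())
}
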
@@ -714,22 +716,17 @@ impl ModernBertModel {
         };

         let pooled_embeddings = match self.pool {
-            // CLS pooling
             Pool::Cls => outputs.i((.., 0))?,
-            // Last token pooling is not supported for this model
             Pool::LastToken | Pool::Splade => unreachable!(),
-            // Mean pooling
             Pool::Mean => {
                 if let Some(ref attention_mask) = attention_mask {
                     let mut attention_mask = attention_mask.clone();

                     if let Some(pooled_indices) = pooled_indices {
-                        // Select values in the batch
                         attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
                         input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
                     };

-                    // Mask padded values
                     outputs = outputs.broadcast_mul(&attention_mask)?;
                 }

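In the pooling hunk, Pool::Cls keeps the hidden state of token 0, while Pool::Mean zeroes out padded positions with the attention mask and, further down, divides by the true input lengths. A self-contained sketch of masked mean pooling with the same candle ops; the function name, shapes, and 0/1 mask layout are illustrative assumptions:

use candle_core::{Device, Result, Tensor};

/// Masked mean pooling: zero padded tokens, sum over the sequence axis,
/// divide by the per-sequence token count.
fn mean_pool(outputs: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // outputs: [batch, seq, hidden]; attention_mask: [batch, seq, 1] of 0.0/1.0
    let masked = outputs.broadcast_mul(attention_mask)?;
    let summed = masked.sum(1)?; // [batch, hidden]
    let lengths = attention_mask.sum(1)?; // [batch, 1]
    summed.broadcast_div(&lengths)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let outputs = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?;
    // The second sequence has two padded positions.
    let mask = Tensor::new(
        &[[[1f32], [1.0], [1.0], [1.0]], [[1f32], [1.0], [0.0], [0.0]]],
        &dev,
    )?;
    let pooled = mean_pool(&outputs, &mask)?;
    println!("{:?}", pooled.shape()); // [2, 8]
    Ok(())
}
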
@@ -742,7 +739,6 @@
         };

         let raw_embeddings = if has_raw_requests {
-            // Reshape outputs
             let (b, l, h) = outputs.shape().dims3()?;
             let outputs = outputs.reshape((b * l, h))?;

6 changes: 3 additions & 3 deletions backends/candle/tests/common.rs
@@ -197,7 +197,7 @@ pub fn load_tokenizer(model_root: &Path) -> Result<Tokenizer> {
             // We are forced to clone since `Tokenizer` does not have a `get_mut` for `pre_tokenizer`
             let mut m = m.clone();
             m.set_prepend_scheme(PrependScheme::First);
-            tokenizer.with_pre_tokenizer(PreTokenizerWrapper::Metaspace(m));
+            tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Metaspace(m)));
         } else if let PreTokenizerWrapper::Sequence(s) = pre_tokenizer {
             let pre_tokenizers = s.get_pre_tokenizers();
             // Check if we have a Metaspace pre tokenizer in the sequence
@@ -222,9 +222,9 @@ pub fn load_tokenizer(model_root: &Path) -> Result<Tokenizer> {
                 }
                 new_pre_tokenizers.push(pre_tokenizer);
             }
-            tokenizer.with_pre_tokenizer(PreTokenizerWrapper::Sequence(Sequence::new(
+            tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Sequence(Sequence::new(
                 new_pre_tokenizers,
-            )));
+            ))));
         }
     }
 }
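
The two test changes track an API change in the tokenizers 0.21 bump from Cargo.toml above: the with_pre_tokenizer setter now takes an Option, so call sites wrap their pre-tokenizer in Some(...) and can pass None to clear it. A minimal sketch of the new call shape, assuming tokenizers 0.21; the BPE model and Whitespace pre-tokenizer are arbitrary choices:

use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::pre_tokenizers::PreTokenizerWrapper;
use tokenizers::Tokenizer;

fn main() {
    // The model is irrelevant to the setter API; BPE::default() keeps it small.
    let mut tokenizer = Tokenizer::new(BPE::default());

    // 0.19.1: tokenizer.with_pre_tokenizer(wrapper);
    // 0.21.0: the argument is an Option.
    tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Whitespace(Whitespace {})));
    tokenizer.with_pre_tokenizer(None::<PreTokenizerWrapper>);
}
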
[Diffs for the remaining changed files are not rendered above.]