feature: ModernBert
kozistr committed Dec 24, 2024
1 parent 57d8fc8 commit 3b0701e
Showing 5 changed files with 1,350 additions and 7 deletions.
16 changes: 11 additions & 5 deletions backends/candle/src/layers/layer_norm.rs
@@ -4,7 +4,7 @@ use candle_nn::VarBuilder;
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
     epsilon: f32,
     span: tracing::Span,
 }
@@ -17,7 +17,8 @@ impl LayerNorm {
                 .or_else(|_| vb.get(hidden_size, "gamma"))?,
             bias: vb
                 .get(hidden_size, "bias")
-                .or_else(|_| vb.get(hidden_size, "beta"))?,
+                .or_else(|_| vb.get(hidden_size, "beta"))
+                .ok(),
             epsilon,
             span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
         })
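The hinge of this hunk is `Result::ok`: instead of propagating a missing `bias`/`beta` tensor with `?`, the fallback chain now ends in `None`. The same pattern in miniature, with a plain function standing in for `VarBuilder::get` (a sketch, not the crate's API):

    fn get(name: &str) -> Result<Vec<f32>, String> {
        // Stand-in for VarBuilder::get: only "weight" exists in this fake store.
        if name == "weight" {
            Ok(vec![1.0; 8])
        } else {
            Err(format!("tensor {name} not found"))
        }
    }

    fn main() {
        // Try "bias", fall back to "beta", settle for None if both are absent.
        let bias: Option<Vec<f32>> = get("bias").or_else(|_| get("beta")).ok();
        assert!(bias.is_none());
        println!("bias present: {}", bias.is_some());
    }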
@@ -49,7 +50,12 @@ impl LayerNorm {
                 let hidden_states = hidden_states_normed
                     .to_dtype(hidden_states_dtype)?
                     .broadcast_mul(&self.weight)?;
-                hidden_states.broadcast_add(&self.bias)
+
+                if let Some(bias) = &self.bias {
+                    hidden_states.broadcast_add(bias)
+                } else {
+                    Ok(hidden_states)
+                }
             }
             Device::Cuda(_) => {
                 #[cfg(feature = "cuda")]
@@ -66,12 +72,12 @@ impl LayerNorm {
                         &hidden_states,
                         &residual,
                         &self.weight,
-                        Some(&self.bias),
+                        &self.bias,
                         self.epsilon,
                     )?;
                     Ok(result)
                 } else {
-                    layer_norm(&hidden_states, &self.weight, Some(&self.bias), self.epsilon)
+                    layer_norm(&hidden_states, &self.weight, &self.bias, self.epsilon)
                 }?;
                 result.reshape(original_shape)
             }
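Taken together, these hunks make the bias optional end to end: it is loaded as `Option<Tensor>`, the CPU path adds it only when present, and the CUDA kernels now receive `&self.bias` directly (implying their signatures take an `Option` as well). Below is a minimal self-contained sketch of the resulting CPU forward pass against candle; the struct is a stand-in rather than the crate's internal type, and `epsilon` is widened to `f64` so the scalar add compiles with candle's operator overloads:

    use candle_core::{DType, Device, Result, Tensor, D};

    struct LayerNorm {
        weight: Tensor,
        bias: Option<Tensor>, // None for bias-free checkpoints such as ModernBERT's
        epsilon: f64,
    }

    impl LayerNorm {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // Normalize over the last (hidden) dimension.
            let mean = xs.mean_keepdim(D::Minus1)?;
            let centered = xs.broadcast_sub(&mean)?;
            let var = centered.sqr()?.mean_keepdim(D::Minus1)?;
            let normed = centered.broadcast_div(&(var + self.epsilon)?.sqrt()?)?;
            let scaled = normed.broadcast_mul(&self.weight)?;
            // Add the bias only if the weight file actually contained one.
            match &self.bias {
                Some(bias) => scaled.broadcast_add(bias),
                None => Ok(scaled),
            }
        }
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let ln = LayerNorm {
            weight: Tensor::ones(8, DType::F32, &dev)?,
            bias: None, // e.g. a checkpoint that ships no LayerNorm bias
            epsilon: 1e-5,
        };
        let xs = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;
        println!("{:?}", ln.forward(&xs)?.shape());
        Ok(())
    }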
31 changes: 29 additions & 2 deletions backends/candle/src/lib.rs
@@ -12,8 +12,8 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
-    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
-    Qwen2Config,
+    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
+    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -63,6 +63,8 @@ enum Config {
     Qwen2(Qwen2Config),
     #[serde(rename = "mpnet")]
     MPNet(MPNetConfig),
+    #[serde(rename(deserialize = "modernbert"))]
+    ModernBert(ModernBertConfig),
 }
 
 pub struct CandleBackend {
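The explicit `rename(deserialize = "modernbert")` is what routes a `config.json` containing `"model_type": "modernbert"` to the new variant. A small sketch of the mechanism, assuming the enum is internally tagged on `model_type` with `rename_all = "snake_case"` (both attributes are assumptions; they sit above the part of the enum shown in this hunk):

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    #[serde(tag = "model_type", rename_all = "snake_case")]
    enum Config {
        Bert,
        // snake_case would expect "modern_bert", so the upstream
        // "modernbert" spelling needs an explicit deserialize-side rename.
        #[serde(rename(deserialize = "modernbert"))]
        ModernBert,
    }

    fn main() -> serde_json::Result<()> {
        let cfg: Config = serde_json::from_str(r#"{ "model_type": "modernbert" }"#)?;
        println!("{cfg:?}"); // prints: ModernBert
        Ok(())
    }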
@@ -233,6 +235,12 @@ impl CandleBackend {
                 tracing::info!("Starting MPNet model on {:?}", device);
                 Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
             }
+            (Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting ModernBert model on {:?}", device);
+                Ok(Box::new(
+                    ModernBertModel::load(vb, &config, model_type).s()?,
+                ))
+            }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
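Arm order matters in this match: tuple patterns are tried top to bottom, so a device wildcard in the new arm would also capture `Device::Cuda(_)` and shadow the cfg-gated CUDA arm added further down. A toy illustration of the hazard (simplified enums, not the crate's types):

    enum Device { Cpu, Metal, Cuda }
    enum Config { ModernBert, MPNet }

    fn dispatch(config: &Config, device: &Device) -> &'static str {
        match (config, device) {
            // Spelling out Cpu | Metal leaves Cuda free to reach the next arm.
            (Config::ModernBert, Device::Cpu | Device::Metal) => "ModernBert (CPU/Metal)",
            (Config::ModernBert, Device::Cuda) => "ModernBert (CUDA)",
            // MPNet has no CUDA-specific path, so a wildcard is fine for it.
            (Config::MPNet, _) => "MPNet",
        }
    }

    fn main() {
        println!("{}", dispatch(&Config::ModernBert, &Device::Cuda));
    }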
@@ -375,6 +383,25 @@ impl CandleBackend {
                     FlashQwen2Model::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::ModernBert(config), Device::Cuda(_)) => {
+                if cfg!(feature = "flash-attn")
+                    && dtype == DType::F16
+                    && &std::env::var("USE_FLASH_ATTENTION")
+                        .unwrap_or("True".to_string())
+                        .to_lowercase()
+                        == "true"
+                {
+                    return Err(BackendError::Start(
+                        "ModernBert does not support flash attention".to_string(),
+                    ));
+                }
+
+                tracing::info!("Starting ModernBert model on {:?}", device);
+                Ok(Box::new(
+                    ModernBertModel::load(vb, &config, model_type).s()?,
+                ))
+            }
         };
 
         Ok(Self {
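The CUDA arm fails fast rather than silently degrading: if the flash-attn kernels are compiled in, the dtype is F16, and `USE_FLASH_ATTENTION` (default `True`) has not been disabled, startup is refused with a `BackendError::Start`. A condensed sketch of that gate using only the standard library (the helper name is illustrative):

    use std::env;

    // Flash attention is on unless USE_FLASH_ATTENTION is explicitly disabled.
    fn flash_attention_requested() -> bool {
        env::var("USE_FLASH_ATTENTION")
            .unwrap_or_else(|_| "True".to_string())
            .eq_ignore_ascii_case("true")
    }

    fn main() {
        if cfg!(feature = "flash-attn") && flash_attention_requested() {
            // Mirror the error above: refuse to start instead of silently
            // falling back to another attention implementation.
            eprintln!("ModernBert does not support flash attention; set USE_FLASH_ATTENTION=False");
            std::process::exit(1);
        }
        println!("Starting ModernBert model on CUDA");
    }

Running with `USE_FLASH_ATTENTION=False` takes the non-flash ModernBert path instead.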
(Diff for the remaining three changed files not shown.)
