feature: GTEModel
kozistr committed Nov 30, 2024
1 parent 0bfeb7e commit 9a3676c
Showing 5 changed files with 593 additions and 130 deletions.
10 changes: 5 additions & 5 deletions backends/candle/src/lib.rs
@@ -11,7 +11,7 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
     JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
@@ -218,10 +218,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
-            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
-                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
-                    .to_string(),
-            )),
+            (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting GTE model on {:?}", device);
+                Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+            }
             (Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
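
For readers skimming the hunk above: before this commit the (Config::Gte, Device::Cpu | Device::Metal) arm returned a BackendError::Start, and it now constructs the non-flash GTEModel. The following is a minimal, self-contained sketch of that dispatch shape; Config, Device, and GteModel here are stand-ins for illustration, not the real backend types.

// Stand-in types for illustration only; the real backend matches on its own
// Config/Device enums and returns a boxed Model trait object.
#[derive(Debug)]
enum Device {
    Cpu,
    Metal,
    Cuda,
}

enum Config {
    Gte,
    Qwen2,
}

struct GteModel;

impl GteModel {
    fn load() -> Result<Self, String> {
        Ok(GteModel)
    }
}

fn start(config: Config, device: Device) -> Result<Box<GteModel>, String> {
    match (config, &device) {
        // Previously this arm returned Err("GTE is only supported on Cuda devices ...").
        (Config::Gte, Device::Cpu | Device::Metal) => {
            println!("Starting GTE model on {:?}", device);
            Ok(Box::new(GteModel::load()?))
        }
        _ => Err("unsupported model/device combination in this sketch".to_string()),
    }
}

fn main() {
    // The GTE arm now succeeds on CPU; Qwen2 on CPU still errors, as in the real match.
    assert!(start(Config::Gte, Device::Cpu).is_ok());
    assert!(start(Config::Qwen2, Device::Cpu).is_err());
}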
115 changes: 5 additions & 110 deletions backends/candle/src/models/flash_gte.rs
@@ -1,6 +1,8 @@
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::models::{
+    GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, GTEMLP,
+};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -93,60 +95,7 @@ impl GTEAttention {
     }
 }

-struct GTEMLP {
-    up_gate_proj: Linear,
-    down_proj: Linear,
-
-    act: HiddenAct,
-    intermediate_size: usize,
-
-    span: tracing::Span,
-}
-
-impl GTEMLP {
-    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let intermediate_size = config.intermediate_size;
-
-        let up_gate_proj_weight = vb
-            .pp("up_gate_proj")
-            .get((intermediate_size * 2, config.hidden_size), "weight")?;
-
-        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
-
-        let down_proj_weight = vb
-            .pp("down_proj")
-            .get((config.hidden_size, intermediate_size), "weight")?;
-        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
-        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
-
-        Ok(Self {
-            up_gate_proj,
-            down_proj,
-            intermediate_size,
-            act: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "mlp"),
-        })
-    }
-
-    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
-        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
-        let gate_states =
-            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
-    }
-}
-
-struct GTELayer {
+pub struct GTELayer {
     attention: GTEAttention,
     mlp: GTEMLP,
     attention_layer_norm: LayerNorm,
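
The GTEMLP removed from this file is now imported from crate::models (see the new use block above), presumably so the non-flash GTEModel can reuse it. Its forward pass is a fused-gate MLP: a single up_gate_proj matmul yields both the "up" and "gate" halves, the activation is applied only to the gate half, and the elementwise product goes through down_proj. Below is a minimal sketch of that computation written directly against candle-core (the repo aliases the crate as candle), with random stand-in weights, the down_proj bias omitted, and SiLU hard-coded as the activation.

use candle_core::{Device, Result, Tensor};

// Fused up/gate MLP: hidden (tokens, h) -> (tokens, h).
// up_gate_w packs both projections as a (2 * intermediate, h) matrix.
fn gated_mlp(
    hidden: &Tensor,
    up_gate_w: &Tensor,
    down_w: &Tensor,
    intermediate: usize,
) -> Result<Tensor> {
    let up_gate = hidden.matmul(&up_gate_w.t()?)?; // (tokens, 2 * intermediate)
    let up = up_gate.narrow(1, 0, intermediate)?; // first half: "up" states
    let gate = up_gate.narrow(1, intermediate, intermediate)?; // second half: "gate" states
    let gated = (gate.silu()? * up)?; // activation on the gate, elementwise product
    gated.matmul(&down_w.t()?) // back to hidden size (bias omitted in this sketch)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let hidden = Tensor::randn(0f32, 1f32, (4, 8), &dev)?;
    let up_gate_w = Tensor::randn(0f32, 1f32, (32, 8), &dev)?; // intermediate = 16
    let down_w = Tensor::randn(0f32, 1f32, (8, 16), &dev)?;
    let out = gated_mlp(&hidden, &up_gate_w, &down_w, 16)?;
    println!("{:?}", out.dims()); // [4, 8]
    Ok(())
}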
@@ -183,9 +132,7 @@ impl GTELayer {
         max_s: usize,
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attn_output = self
-            .attention
-            .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
+        let attn_output = self.attention.forward(&hidden_states, cos, sin)?;
         let normed_attn_res_output = self
             .attention_layer_norm
             .forward(&attn_output, Some(hidden_states))?;
@@ -198,58 +145,6 @@ impl GTELayer {
     }
 }

-pub struct GTEClassificationHead {
-    pooler: Option<Linear>,
-    classifier: Linear,
-    span: tracing::Span,
-}
-
-impl GTEClassificationHead {
-    #[allow(dead_code)]
-    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let n_classes = match &config.id2label {
-            None => candle::bail!("`id2label` must be set for classifier models"),
-            Some(id2label) => id2label.len(),
-        };
-
-        let pooler = if let Ok(pooler_weight) = vb
-            .pp("pooler.dense")
-            .get((config.hidden_size, config.hidden_size), "weight")
-        {
-            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
-            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
-        } else {
-            None
-        };
-
-        let classifier_weight = vb
-            .pp("classifier")
-            .get((n_classes, config.hidden_size), "weight")?;
-        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
-        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
-
-        Ok(Self {
-            classifier,
-            pooler,
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-
-    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let mut hidden_states = hidden_states.unsqueeze(1)?;
-        if let Some(pooler) = self.pooler.as_ref() {
-            hidden_states = pooler.forward(&hidden_states)?;
-            hidden_states = hidden_states.tanh()?;
-        }
-
-        let hidden_states = self.classifier.forward(&hidden_states)?;
-        let hidden_states = hidden_states.squeeze(1)?;
-        Ok(hidden_states)
-    }
-}
-
 pub struct FlashGTEModel {
     word_embeddings: Embedding,
     token_type_embeddings: Option<Embedding>,
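
Likewise, GTEClassificationHead now lives in the shared models module and is imported at the top of this file. The computation it performs is an optional pooler dense layer followed by tanh, then a linear classifier over the pooled hidden state. Below is a candle-core sketch of just that math, with hypothetical stand-in weights (the unsqueeze/squeeze bookkeeping of the real forward is omitted).

use candle_core::{DType, Device, Result, Tensor};

// hidden: (batch, hidden_size) pooled representation.
// pooler: optional (weight (h, h), bias (h,)); classifier: (n_classes, h) weight and (n_classes,) bias.
fn classify(
    hidden: &Tensor,
    pooler: Option<(&Tensor, &Tensor)>,
    classifier_w: &Tensor,
    classifier_b: &Tensor,
) -> Result<Tensor> {
    let mut h = hidden.clone();
    if let Some((w, b)) = pooler {
        // pooler.dense followed by tanh, as in the removed head above
        h = h.matmul(&w.t()?)?.broadcast_add(b)?.tanh()?;
    }
    // classifier logits: (batch, n_classes)
    h.matmul(&classifier_w.t()?)?.broadcast_add(classifier_b)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let hidden = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (8, 8), &dev)?;
    let b = Tensor::zeros(8, DType::F32, &dev)?;
    let cw = Tensor::randn(0f32, 1f32, (3, 8), &dev)?;
    let cb = Tensor::zeros(3, DType::F32, &dev)?;
    let logits = classify(&hidden, Some((&w, &b)), &cw, &cb)?;
    println!("{:?}", logits.dims()); // [2, 3]
    Ok(())
}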