
feature: cls head for gte arch
kozistr committed Nov 17, 2024
1 parent 2fbf552 commit 2d1e2af
Showing 4 changed files with 73 additions and 5 deletions.
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_bert.rs
@@ -529,6 +529,7 @@ impl Model for FlashBertModel {
     fn is_padded(&self) -> bool {
         false
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
29 changes: 25 additions & 4 deletions backends/candle/src/models/flash_gte.rs
@@ -1,6 +1,8 @@
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::models::{
+    GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling,
+};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -205,6 +207,7 @@ pub struct FlashGTEModel {
     embeddings_norm: LayerNorm,
     cos_cache: Tensor,
     sin_cache: Tensor,
+    classifier: Option<Box<dyn ClassificationHead + Send>>,
     pool: Pool,
     pub device: Device,

@@ -233,11 +236,15 @@ impl FlashGTEModel {
             candle::bail!("Only `PositionEmbeddingType::Rope` is supported");
         }

-        let pool = match model_type {
+        let (pool, classifier) = match model_type {
             ModelType::Classifier => {
-                candle::bail!("`classifier` model type is not supported for GTE")
+                let pool = Pool::Cls;
+
+                let classifier: Box<dyn ClassificationHead + Send> =
+                    Box::new(GTEClassificationHead::load(vb.clone(), config)?);
+                (pool, Some(classifier))
             }
-            ModelType::Embedding(pool) => pool,
+            ModelType::Embedding(pool) => (pool, None),
         };

         let word_embeddings = Embedding::new(
@@ -292,6 +299,7 @@ impl FlashGTEModel {
             embeddings_norm,
             cos_cache,
             sin_cache,
+            classifier,
             pool,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -457,7 +465,20 @@ impl Model for FlashGTEModel {
     fn is_padded(&self) -> bool {
         false
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
 }
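
The new `predict` path stores the head as `Option<Box<dyn ClassificationHead + Send>>`: it is populated only for `ModelType::Classifier`, and `predict` bails when it is absent, so embedding-only models are unaffected. A minimal standalone sketch of this optional-head dispatch pattern (toy `Batch`/`Tensor` stand-ins and a `String` error type, not the real candle/TEI types):

    // Toy stand-ins for the real TEI/candle types (assumptions for illustration).
    struct Batch;
    struct Tensor(Vec<f32>);

    trait ClassificationHead {
        fn forward(&self, pooled: &Tensor) -> Result<Tensor, String>;
    }

    // A dummy head that "projects" pooled embeddings to a single logit.
    struct DummyHead;

    impl ClassificationHead for DummyHead {
        fn forward(&self, pooled: &Tensor) -> Result<Tensor, String> {
            Ok(Tensor(vec![pooled.0.iter().sum()]))
        }
    }

    struct Model {
        classifier: Option<Box<dyn ClassificationHead + Send>>,
    }

    impl Model {
        fn forward(&self, _batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>), String> {
            // Pretend pooled embeddings came out of the encoder.
            Ok((Some(Tensor(vec![0.1, 0.2])), None))
        }

        // Mirrors the shape of the `predict` added in the diff above.
        fn predict(&self, batch: Batch) -> Result<Tensor, String> {
            match &self.classifier {
                None => Err("`predict` is not implemented for this model".into()),
                Some(classifier) => {
                    let (pooled, _raw) = self.forward(batch)?;
                    let pooled = pooled.expect("pooled embeddings missing");
                    classifier.forward(&pooled)
                }
            }
        }
    }

    fn main() {
        let embed_only = Model { classifier: None };
        assert!(embed_only.predict(Batch).is_err());

        let with_head = Model {
            classifier: Some(Box::new(DummyHead)),
        };
        assert!(with_head.predict(Batch).is_ok());
    }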
46 changes: 46 additions & 0 deletions backends/candle/src/models/gte.rs
@@ -1,6 +1,10 @@
 use crate::layers::HiddenAct;
+use crate::layers::Linear;
 use crate::models::PositionEmbeddingType;
+use candle::{Result, Tensor};
+use candle_nn::VarBuilder;
 use serde::Deserialize;
+use std::collections::HashMap;

 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct NTKScaling {
@@ -32,4 +36,46 @@ pub struct GTEConfig {
     pub logn_attention_scale: bool,
     #[serde(default)]
     pub logn_attention_clip1: bool,
+    pub id2label: Option<HashMap<String, String>>,
 }
+
+pub trait ClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
+}
+
+pub struct GTEClassificationHead {
+    classifier: Linear,
+    span: tracing::Span,
+}
+
+impl GTEClassificationHead {
+    #[allow(dead_code)]
+    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let n_classes = match &config.id2label {
+            None => candle::bail!("`id2label` must be set for classifier models"),
+            Some(id2label) => id2label.len(),
+        };
+
+        let classifier_weight = vb
+            .pp("classifier")
+            .get((n_classes, config.hidden_size), "weight")?;
+        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
+        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
+
+        Ok(Self {
+            classifier,
+            span: tracing::span!(tracing::Level::TRACE, "classifier"),
+        })
+    }
+}
+
+impl ClassificationHead for GTEClassificationHead {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = hidden_states.unsqueeze(1)?;
+        let hidden_states = self.classifier.forward(&hidden_states)?;
+        let hidden_states = hidden_states.squeeze(1)?;
+        Ok(hidden_states)
+    }
+}
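
Shape-wise, the head maps CLS-pooled embeddings of shape `(batch, hidden_size)` through a `(n_classes, hidden_size)` linear layer, with the unsqueeze/squeeze round-trip shown above, yielding `(batch, n_classes)` logits. A minimal sketch of that round-trip, assuming the standalone `candle-core`/`candle-nn` crates (`candle_nn::Linear` stands in for TEI's own `Linear` wrapper, which takes an extra activation argument):

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // In the real head, n_classes = config.id2label.len().
        let (batch, hidden_size, n_classes) = (2, 8, 3);

        // classifier.weight: (n_classes, hidden_size), classifier.bias: (n_classes,)
        let weight = Tensor::randn(0f32, 1f32, (n_classes, hidden_size), &device)?;
        let bias = Tensor::zeros(n_classes, DType::F32, &device)?;
        let classifier = Linear::new(weight, Some(bias));

        // CLS-pooled embeddings: (batch, hidden_size)
        let pooled = Tensor::randn(0f32, 1f32, (batch, hidden_size), &device)?;

        // (batch, hidden) -> (batch, 1, hidden) -> (batch, 1, n_classes) -> (batch, n_classes)
        let logits = classifier.forward(&pooled.unsqueeze(1)?)?.squeeze(1)?;
        assert_eq!(logits.dims(), &[batch, n_classes]);
        Ok(())
    }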
2 changes: 1 addition & 1 deletion backends/candle/src/models/mod.rs
@@ -40,7 +40,7 @@ pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
 use candle::{Result, Tensor};
 pub use distilbert::{DistilBertConfig, DistilBertModel};
 #[allow(unused_imports)]
-pub use gte::{GTEConfig, NTKScaling, RopeScaling};
+pub use gte::{GTEClassificationHead, GTEConfig, NTKScaling, RopeScaling};
 pub use jina::JinaBertModel;
 pub use jina_code::JinaCodeBertModel;
 pub use mistral::MistralConfig;
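
For context, the new `id2label` field follows the HuggingFace `config.json` convention, and `GTEClassificationHead::load` derives the class count from it. A minimal sketch of that derivation (using `serde`/`serde_json` and a pared-down stand-in for `GTEConfig`, an assumption for illustration):

    use serde::Deserialize;
    use std::collections::HashMap;

    // Pared-down stand-in for GTEConfig; only the field relevant here.
    #[derive(Debug, Deserialize)]
    struct MiniConfig {
        id2label: Option<HashMap<String, String>>,
    }

    fn main() {
        // HuggingFace-style config.json snippet for a 2-class classifier.
        let raw = r#"{ "id2label": { "0": "NEGATIVE", "1": "POSITIVE" } }"#;
        let config: MiniConfig = serde_json::from_str(raw).unwrap();

        // Mirrors GTEClassificationHead::load: fail when id2label is missing.
        let n_classes = config
            .id2label
            .as_ref()
            .expect("`id2label` must be set for classifier models")
            .len();
        assert_eq!(n_classes, 2);
    }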
