Skip to content

Commit

Permalink
feature: support classification head
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Feb 6, 2025
1 parent f0e491a commit 30d829c
Showing 1 changed file with 76 additions and 4 deletions.
80 changes: 76 additions & 4 deletions backends/candle/src/models/distilbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

#[derive(Debug, Clone, PartialEq, Deserialize)]
Expand All @@ -16,6 +17,8 @@ pub struct DistilBertConfig {
pub max_position_embeddings: usize,
pub pad_token_id: usize,
pub model_type: Option<String>,
pub classifier_dropout: Option<f64>,
pub id2label: Option<HashMap<String, String>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -318,6 +321,56 @@ impl DistilBertEncoder {
}
}

/// Common interface for sequence-classification heads.
///
/// Implementations map pooled hidden states to per-class logits.
pub trait ClassificationHead {
    /// Computes logits from pooled hidden states.
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
}

/// DistilBERT sequence-classification head: a `pre_classifier` linear layer
/// followed by ReLU and a final `classifier` projection to class logits.
pub struct DistilBertClassificationHead {
    // Linear layer of shape (dim, dim) applied before the ReLU.
    pre_classifier: Linear,
    // Final projection of shape (n_classes, dim) producing the logits.
    classifier: Linear,
    // Tracing span used to instrument the forward pass.
    span: tracing::Span,
}

impl DistilBertClassificationHead {
    /// Builds the classification head from the model weights.
    ///
    /// Reads the `pre_classifier` and `classifier` linear layers from `vb`.
    /// The output dimension is the number of entries in `config.id2label`;
    /// fails if that mapping is absent, since classifier models require it.
    pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        // The number of target classes comes from the label mapping.
        let n_classes = match &config.id2label {
            Some(id2label) => id2label.len(),
            None => candle::bail!("`id2label` must be set for classifier models"),
        };

        // (dim, dim) projection applied before the ReLU.
        let pre_vb = vb.pp("pre_classifier");
        let pre_classifier = Linear::new(
            pre_vb.get((config.dim, config.dim), "weight")?,
            Some(pre_vb.get(config.dim, "bias")?),
            None,
        );

        // (n_classes, dim) projection producing the logits.
        let cls_vb = vb.pp("classifier");
        let classifier = Linear::new(
            cls_vb.get((n_classes, config.dim), "weight")?,
            Some(cls_vb.get(n_classes, "bias")?),
            None,
        );

        Ok(Self {
            pre_classifier,
            classifier,
            span: tracing::span!(tracing::Level::TRACE, "classifier"),
        })
    }
}

impl ClassificationHead for DistilBertClassificationHead {
    /// Runs the classification head on pooled hidden states, returning logits.
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();

        // Insert a length-1 middle dimension, apply
        // pre_classifier -> ReLU -> classifier, then drop that dimension again.
        let states = hidden_states.unsqueeze(1)?;
        let states = self.pre_classifier.forward(&states)?.relu()?;
        let logits = self.classifier.forward(&states)?;
        logits.squeeze(1)
    }
}

#[derive(Debug)]
pub struct DistilBertSpladeHead {
vocab_transform: Linear,
Expand Down Expand Up @@ -368,11 +421,11 @@ impl DistilBertSpladeHead {
}
}

#[derive(Debug)]
pub struct DistilBertModel {
embeddings: DistilBertEmbeddings,
encoder: DistilBertEncoder,
pool: Pool,
classifier: Option<Box<dyn ClassificationHead + Send>>,
splade: Option<DistilBertSpladeHead>,

num_attention_heads: usize,
Expand All @@ -385,15 +438,21 @@ pub struct DistilBertModel {

impl DistilBertModel {
pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result<Self> {
let pool = match model_type {
let (pool, classifier) = match model_type {
// Classifier models always use CLS pooling
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for DistilBert")
let pool = Pool::Cls;

let classifier: Box<dyn ClassificationHead + Send> =
Box::new(DistilBertClassificationHead::load(vb.clone(), config)?);
(pool, Some(classifier))
}
ModelType::Embedding(pool) => {
if pool == Pool::LastToken {
candle::bail!("`last_token` is not supported for DistilBert");
}
pool

(pool, None)
}
};

Expand Down Expand Up @@ -424,6 +483,7 @@ impl DistilBertModel {
embeddings,
encoder,
pool,
classifier,
splade,
num_attention_heads: config.n_heads,
device: vb.device().clone(),
Expand Down Expand Up @@ -660,4 +720,16 @@ impl Model for DistilBertModel {
/// Embeds a batch by delegating to `forward`, returning the pooled and/or
/// raw (per-token) embeddings depending on the batch's requests.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
    self.forward(batch)
}

/// Runs the classification head over a batch and returns the logits.
///
/// Errors if this model was not loaded as a classifier (no head present).
fn predict(&self, batch: Batch) -> Result<Tensor> {
    // Guard: only classifier models carry a classification head.
    let Some(classifier) = &self.classifier else {
        candle::bail!("`predict` is not implemented for this model")
    };

    let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
    // Classifier models always pool, so pooled embeddings must be present.
    let pooled_embeddings =
        pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
    classifier.forward(&pooled_embeddings)
}
}

0 comments on commit 30d829c

Please sign in to comment.