Skip to content

Commit

Permalink
update: pooler
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Nov 17, 2024
1 parent 2d1e2af commit 6293b0d
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion backends/candle/src/models/gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub trait ClassificationHead {
}

pub struct GTEClassificationHead {
pooler: Option<Linear>,
classifier: Linear,
span: tracing::Span,
}
Expand All @@ -56,6 +57,16 @@ impl GTEClassificationHead {
Some(id2label) => id2label.len(),
};

let pooler = if let Ok(pooler_weight) = vb
.pp("new.pooler.dense")
.get((config.hidden_size, config.hidden_size), "weight")
{
let pooler_bias = vb.pp("new.pooler.dense").get(config.hidden_size, "bias")?;
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
} else {
None
};

let classifier_weight = vb
.pp("classifier")
.get((n_classes, config.hidden_size), "weight")?;
Expand All @@ -64,6 +75,7 @@ impl GTEClassificationHead {

Ok(Self {
classifier,
pooler,
span: tracing::span!(tracing::Level::TRACE, "classifier"),
})
}
Expand All @@ -73,7 +85,12 @@ impl ClassificationHead for GTEClassificationHead {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let hidden_states = hidden_states.unsqueeze(1)?;
let mut hidden_states = hidden_states.unsqueeze(1)?;
if let Some(pooler) = self.pooler.as_ref() {
hidden_states = pooler.forward(&hidden_states)?;
hidden_states = hidden_states.tanh()?;
}

let hidden_states = self.classifier.forward(&hidden_states)?;
let hidden_states = hidden_states.squeeze(1)?;
Ok(hidden_states)
Expand Down

0 comments on commit 6293b0d

Please sign in to comment.