From 6293b0d941a610e5afbe2612677c84322b7c787d Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 17 Nov 2024 15:47:34 +0900
Subject: [PATCH] update: pooler

---
 backends/candle/src/models/gte.rs | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs
index 8dc5b344..f70da311 100644
--- a/backends/candle/src/models/gte.rs
+++ b/backends/candle/src/models/gte.rs
@@ -44,6 +44,7 @@ pub trait ClassificationHead {
 }
 
 pub struct GTEClassificationHead {
+    pooler: Option<Linear>,
     classifier: Linear,
     span: tracing::Span,
 }
@@ -56,6 +57,16 @@ impl GTEClassificationHead {
             Some(id2label) => id2label.len(),
         };
 
+        let pooler = if let Ok(pooler_weight) = vb
+            .pp("new.pooler.dense")
+            .get((config.hidden_size, config.hidden_size), "weight")
+        {
+            let pooler_bias = vb.pp("new.pooler.dense").get(config.hidden_size, "bias")?;
+            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
+        } else {
+            None
+        };
+
         let classifier_weight = vb
             .pp("classifier")
             .get((n_classes, config.hidden_size), "weight")?;
@@ -64,6 +75,7 @@ impl GTEClassificationHead {
 
         Ok(Self {
             classifier,
+            pooler,
             span: tracing::span!(tracing::Level::TRACE, "classifier"),
         })
     }
@@ -73,7 +85,12 @@ impl ClassificationHead for GTEClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
 
-        let hidden_states = hidden_states.unsqueeze(1)?;
+        let mut hidden_states = hidden_states.unsqueeze(1)?;
+        if let Some(pooler) = self.pooler.as_ref() {
+            hidden_states = pooler.forward(&hidden_states)?;
+            hidden_states = hidden_states.tanh()?;
+        }
+
         let hidden_states = self.classifier.forward(&hidden_states)?;
         let hidden_states = hidden_states.squeeze(1)?;
         Ok(hidden_states)