Skip to content

Commit

Permalink
fix(ort): fix mean pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jul 8, 2024
1 parent e496fe7 commit dc3aa1e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
3 changes: 2 additions & 1 deletion backends/candle/src/models/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ impl BertModel {
let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
let type_ids = Tensor::from_vec(type_ids, shape, &self.device)?;
let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
let input_lengths =
let mut input_lengths =
Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;

let embedding_output = self
Expand Down Expand Up @@ -847,6 +847,7 @@ impl BertModel {
if let Some(pooled_indices) = pooled_indices {
// Select values in the batch
attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
input_lengths = input_lengths.index_select(&pooled_indices, 0)?;
};

// Mask padded values
Expand Down
6 changes: 5 additions & 1 deletion backends/ort/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,22 @@ impl Backend for OrtBackend {
Pool::Mean => {
if masking {
let mut attention_mask = attention_mask;
let mut input_lengths = input_lengths;

if let Some(indices) = indices {
// Select values in the batch
attention_mask = attention_mask.select(Axis(0), &indices);
input_lengths = input_lengths.select(Axis(0), &indices);
};

// Cast and reshape
let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));

// Mask padded values
outputs = outputs.mul(attention_mask);
outputs.sum_axis(Axis(1)).div(input_lengths)
outputs
.sum_axis(Axis(1))
.div(input_lengths.insert_axis(Axis(1)))
} else {
outputs.mean_axis(Axis(1)).unwrap()
}
Expand Down
2 changes: 1 addition & 1 deletion load_tests/load.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export const options = {
executor: 'constant-arrival-rate',
duration: '30s',
preAllocatedVUs: 5000,
rate: 10,
rate: 50,
timeUnit: '1s',
gracefulStop: '1s',
},
Expand Down

0 comments on commit dc3aa1e

Please sign in to comment.