diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index a36a1555..32880d44 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -804,7 +804,7 @@ impl BertModel { let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?; let type_ids = Tensor::from_vec(type_ids, shape, &self.device)?; let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?; - let input_lengths = + let mut input_lengths = Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?; let embedding_output = self @@ -847,6 +847,7 @@ impl BertModel { if let Some(pooled_indices) = pooled_indices { // Select values in the batch attention_mask = attention_mask.index_select(&pooled_indices, 0)?; + input_lengths = input_lengths.index_select(&pooled_indices, 0)?; }; // Mask padded values diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index 6951d3b4..9573f6b0 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -209,10 +209,12 @@ impl Backend for OrtBackend { Pool::Mean => { if masking { let mut attention_mask = attention_mask; + let mut input_lengths = input_lengths; if let Some(indices) = indices { // Select values in the batch attention_mask = attention_mask.select(Axis(0), &indices); + input_lengths = input_lengths.select(Axis(0), &indices); }; // Cast and reshape @@ -220,7 +222,9 @@ impl Backend for OrtBackend { // Mask padded values outputs = outputs.mul(attention_mask); - outputs.sum_axis(Axis(1)).div(input_lengths) + outputs + .sum_axis(Axis(1)) + .div(input_lengths.insert_axis(Axis(1))) } else { outputs.mean_axis(Axis(1)).unwrap() } diff --git a/load_tests/load.js b/load_tests/load.js index b3705476..867f9fdb 100644 --- a/load_tests/load.js +++ b/load_tests/load.js @@ -27,7 +27,7 @@ export const options = { executor: 'constant-arrival-rate', duration: '30s', preAllocatedVUs: 5000, - rate: 10, + rate: 50, timeUnit: '1s', gracefulStop: '1s', },