Skip to content

Commit

Permalink
Adjust zero point when converting batched matmul to non-batched
Browse files Browse the repository at this point in the history
When a batched matmul `[A, M, K] x [K, N]` is converted to a non-batched matmul
with LHS shape `[A * M, K]`, the LHS zero point vector needs to be broadcast
(repeated for each of the `A` batches) to match the new row count of `A * M`.

This fixes an error when running the Segment Anything demo with a quantized
image encoder.
  • Loading branch information
robertknight committed Feb 3, 2025
1 parent f5fcba9 commit ad04fbd
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ where
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous_in(pool).auto_return(pool);
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());

// Broadcast zero point to match new row count.
let a_quant: Option<Vec<LhsT>> = a_quant.map(|a_quant| {
a_quant
.zero_point
.iter()
.copied()
.cycle()
.take(a_matrix.size(0))
.collect()
});

let mut output = matmul_impl(
pool,
a_matrix.view(),
Expand All @@ -230,7 +242,9 @@ where
strategy,
bias,
alpha,
a_quant,
a_quant.as_ref().map(|zero_point| QuantParams {
zero_point: zero_point.as_slice(),
}),
b_quant,
)?;
output.reshape(out_shape);
Expand Down Expand Up @@ -1058,14 +1072,22 @@ mod tests {
expected_err: None,
},
// Batched (3D) LHS input with vector zero points
Case {
a: Tensor::zeros(&[3, 2, 2]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: Some(Tensor::from([1, 2])),
b_zero_point: Some(Tensor::from([3, 4])),
expected_err: None,
},
// Vector zero points with a non-batched (2D) LHS input
Case {
a: Tensor::from([[1, 2], [3, 4]]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: Some(Tensor::from([1, 2])),
b_zero_point: Some(Tensor::from([3, 4])),
expected_err: None,
},
// A input which is a row vector
// An input which is a row vector
Case {
a: Tensor::from([[1, 2, 3, 4]]),
b: Tensor::from([[5, 6], [7, 8], [9, 10], [11, 12]]),
Expand Down

0 comments on commit ad04fbd

Please sign in to comment.