Skip to content

Commit

Permalink
fix: correct batch inversion implementation (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Feb 8, 2025
1 parent 5ad862d commit 399d21a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/fns/constrained_ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ global TWO_POW_120: Field = 0x1000000000000000000000000000000;
*/

pub(crate) fn limbs_to_field<let N: u32, let MOD_BITS: u32>(
params: P<N, MOD_BITS>,
_params: P<N, MOD_BITS>,
limbs: [Field; N],
) -> Field {
let TWO_POW_120 = 0x1000000000000000000000000000000;
Expand All @@ -58,7 +58,7 @@ pub(crate) fn limbs_to_field<let N: u32, let MOD_BITS: u32>(
}

pub(crate) fn from_field<let N: u32, let MOD_BITS: u32>(
params: P<N, MOD_BITS>,
_params: P<N, MOD_BITS>,
field: Field,
) -> [Field; N] {
// Safety: we check that the resulting limbs represent the intended field element
Expand Down
2 changes: 1 addition & 1 deletion src/fns/unconstrained_ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ pub(crate) unconstrained fn __batch_invert<let N: u32, let MOD_BITS: u32, let M:
) -> [[Field; N]; M] {
// TODO: ugly! Will fail if input slice is empty
let mut accumulator: [Field; N] = __one::<N>();
let mut temporaries: [[Field; N]; N] = std::mem::zeroed();
let mut temporaries: [[Field; N]; M] = std::mem::zeroed();
for i in 0..M {
temporaries[i] = accumulator;
if (!__is_zero(x[i])) {
Expand Down
8 changes: 2 additions & 6 deletions src/tests/bignum_test.nr
Original file line number Diff line number Diff line change
Expand Up @@ -950,9 +950,7 @@ where
#[test]
unconstrained fn test_batch_inversion_BN381(seeds: [[u8; 2]; 3]) {
let fields = seeds.map(|seed| BN381::derive_from_seed(seed));
unsafe {
test_batch_inversion(fields)
}
test_batch_inversion(fields)
}

unconstrained fn test_batch_inversion_slice<BN>(fields: [BN])
Expand All @@ -969,7 +967,5 @@ where
#[test]
unconstrained fn test_batch_inversion_slice_BN381(seeds: [[u8; 2]; 3]) {
let fields = seeds.map(|seed| BN381::derive_from_seed(seed)).as_slice();
unsafe {
test_batch_inversion_slice(fields)
}
test_batch_inversion_slice(fields)
}

0 comments on commit 399d21a

Please sign in to comment.