Skip to content

Commit

Permalink
add pedantic_solving parameter wherever needed, default to true when …
Browse files Browse the repository at this point in the history
…testing or constant-folding, add tests for pedantic_solving, test to ensure allowed bingint moduli are prime
  • Loading branch information
michaeljklein committed Dec 5, 2024
1 parent b5ee668 commit f65a0b0
Show file tree
Hide file tree
Showing 32 changed files with 645 additions and 235 deletions.
55 changes: 55 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 8 additions & 6 deletions acvm-repo/acvm/src/pwg/blackbox/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ impl AcvmBigIntSolver {
modulus: &[u8],
output: u32,
initial_witness: &mut WitnessMap<F>,
pedantic_solving: bool,
) -> Result<(), OpcodeResolutionError<F>> {
let bytes = inputs
.iter()
.map(|input| input_to_value(initial_witness, *input, false).unwrap().to_u128() as u8)
.collect::<Vec<u8>>();
self.bigint_solver.bigint_from_bytes(&bytes, modulus, output)?;
self.bigint_solver.bigint_from_bytes(&bytes, modulus, output, pedantic_solving)?;
Ok(())
}

Expand All @@ -38,7 +39,11 @@ impl AcvmBigIntSolver {
input: u32,
outputs: &[Witness],
initial_witness: &mut WitnessMap<F>,
pedantic_solving: bool,
) -> Result<(), OpcodeResolutionError<F>> {
if pedantic_solving && outputs.len() != 32 {
panic!("--pedantic-solving: bigint_to_bytes: outputs.len() != 32: {}", outputs.len());
}
let mut bytes = self.bigint_solver.bigint_to_bytes(input)?;
while bytes.len() < outputs.len() {
bytes.push(0);
Expand All @@ -55,12 +60,9 @@ impl AcvmBigIntSolver {
rhs: u32,
output: u32,
func: BlackBoxFunc,
pedantic_solving: bool,
) -> Result<(), OpcodeResolutionError<F>> {
self.bigint_solver.bigint_op(lhs, rhs, output, func)?;
self.bigint_solver.bigint_op(lhs, rhs, output, func, pedantic_solving)?;
Ok(())
}

pub(crate) fn is_valid_modulus(&self, modulus: &[u8]) -> bool {
self.bigint_solver.is_valid_modulus(modulus)
}
}
3 changes: 2 additions & 1 deletion acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub(super) fn multi_scalar_mul<F: AcirField>(
points: &[FunctionInput<F>],
scalars: &[FunctionInput<F>],
outputs: (Witness, Witness, Witness),
pedantic_solving: bool,
) -> Result<(), OpcodeResolutionError<F>> {
let points: Result<Vec<_>, _> =
points.iter().map(|input| input_to_value(initial_witness, *input, false)).collect();
Expand All @@ -31,7 +32,7 @@ pub(super) fn multi_scalar_mul<F: AcirField>(
}
// Call the backend's multi-scalar multiplication function
let (res_x, res_y, is_infinite) =
backend.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)?;
backend.multi_scalar_mul(&points, &scalars_lo, &scalars_hi, pedantic_solving)?;

// Insert the resulting point into the witness map
insert_value(&outputs.0, res_x, initial_witness)?;
Expand Down
12 changes: 8 additions & 4 deletions acvm-repo/acvm/src/pwg/blackbox/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ pub(super) fn and<F: AcirField>(
lhs: &FunctionInput<F>,
rhs: &FunctionInput<F>,
output: &Witness,
pedantic_solving: bool,
) -> Result<(), OpcodeResolutionError<F>> {
assert_eq!(
lhs.num_bits(),
rhs.num_bits(),
"number of bits specified for each input must be the same"
);
solve_logic_opcode(initial_witness, lhs, rhs, *output, |left, right| {
solve_logic_opcode(initial_witness, lhs, rhs, *output, pedantic_solving, |left, right| {
bit_and(left, right, lhs.num_bits())
})
}
Expand All @@ -32,13 +33,14 @@ pub(super) fn xor<F: AcirField>(
lhs: &FunctionInput<F>,
rhs: &FunctionInput<F>,
output: &Witness,
pedantic_solving: bool,
) -> Result<(), OpcodeResolutionError<F>> {
assert_eq!(
lhs.num_bits(),
rhs.num_bits(),
"number of bits specified for each input must be the same"
);
solve_logic_opcode(initial_witness, lhs, rhs, *output, |left, right| {
solve_logic_opcode(initial_witness, lhs, rhs, *output, pedantic_solving, |left, right| {
bit_xor(left, right, lhs.num_bits())
})
}
Expand All @@ -49,11 +51,13 @@ fn solve_logic_opcode<F: AcirField>(
a: &FunctionInput<F>,
b: &FunctionInput<F>,
result: Witness,
pedantic_solving: bool,
logic_op: impl Fn(F, F) -> F,
) -> Result<(), OpcodeResolutionError<F>> {
// TODO(https://github.com/noir-lang/noir/issues/5985): re-enable these once we figure out how to combine these with existing
// TODO(https://github.com/noir-lang/noir/issues/5985): re-enable these by
// default once we figure out how to combine these with existing
// noirc_frontend/noirc_evaluator overflow error messages
let skip_bitsize_checks = true;
let skip_bitsize_checks = !pedantic_solving;
let w_l_value = input_to_value(initial_witness, *a, skip_bitsize_checks)?;
let w_r_value = input_to_value(initial_witness, *b, skip_bitsize_checks)?;
let assignment = logic_op(w_l_value, w_r_value);
Expand Down
48 changes: 20 additions & 28 deletions acvm-repo/acvm/src/pwg/blackbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,19 @@ pub(crate) fn solve<F: AcirField>(
));
}

// TODO: check input value sizes when pedantic_solving
match bb_func {
BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => {
solve_aes128_encryption_opcode(initial_witness, inputs, iv, key, outputs)
}
BlackBoxFuncCall::AND { lhs, rhs, output } => and(initial_witness, lhs, rhs, output),
BlackBoxFuncCall::XOR { lhs, rhs, output } => xor(initial_witness, lhs, rhs, output),
BlackBoxFuncCall::RANGE { input } => solve_range_opcode(initial_witness, input, pedantic_solving),
BlackBoxFuncCall::AND { lhs, rhs, output } => {
and(initial_witness, lhs, rhs, output, pedantic_solving)
}
BlackBoxFuncCall::XOR { lhs, rhs, output } => {
xor(initial_witness, lhs, rhs, output, pedantic_solving)
}
BlackBoxFuncCall::RANGE { input } => {
solve_range_opcode(initial_witness, input, pedantic_solving)
}
BlackBoxFuncCall::Blake2s { inputs, outputs } => {
solve_generic_256_hash_opcode(initial_witness, inputs, None, outputs, blake2s)
}
Expand Down Expand Up @@ -149,11 +154,7 @@ pub(crate) fn solve<F: AcirField>(
*output,
),
BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => {
if pedantic_solving && points.len() != scalars.len() {
// TODO: better error or ICE
panic!("MultiScalarMul")
}
multi_scalar_mul(backend, initial_witness, points, scalars, *outputs)
multi_scalar_mul(backend, initial_witness, points, scalars, *outputs, pedantic_solving)
}
BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => {
embedded_curve_add(backend, initial_witness, **input1, **input2, *outputs)
Expand All @@ -163,31 +164,22 @@ pub(crate) fn solve<F: AcirField>(
BlackBoxFuncCall::BigIntAdd { lhs, rhs, output }
| BlackBoxFuncCall::BigIntSub { lhs, rhs, output }
| BlackBoxFuncCall::BigIntMul { lhs, rhs, output }
| BlackBoxFuncCall::BigIntDiv { lhs, rhs, output } => {
bigint_solver.bigint_op(*lhs, *rhs, *output, bb_func.get_black_box_func())
}
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus, output } => {
if pedantic_solving && (!bigint_solver.is_valid_modulus(modulus) || inputs.len() > 32) {
// TODO: better error or ICE
panic!("BigIntFromLeBytes")
}
bigint_solver.bigint_from_bytes(inputs, modulus, *output, initial_witness)
}
| BlackBoxFuncCall::BigIntDiv { lhs, rhs, output } => bigint_solver.bigint_op(
*lhs,
*rhs,
*output,
bb_func.get_black_box_func(),
pedantic_solving,
),
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus, output } => bigint_solver
.bigint_from_bytes(inputs, modulus, *output, initial_witness, pedantic_solving),
BlackBoxFuncCall::BigIntToLeBytes { input, outputs } => {
if pedantic_solving && outputs.len() != 32 {
// TODO: better error or ICE
panic!("BigIntToLeBytes")
}
bigint_solver.bigint_to_bytes(*input, outputs, initial_witness)
bigint_solver.bigint_to_bytes(*input, outputs, initial_witness, pedantic_solving)
}
BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => {
solve_sha_256_permutation_opcode(initial_witness, inputs, hash_values, outputs)
}
BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len } => {
if pedantic_solving && inputs.len() != outputs.len() {
// TODO: better error or ICE
panic!("Poseidon2Permutation")
}
solve_poseidon2_permutation_opcode(backend, initial_witness, inputs, outputs, *len)
}
}
Expand Down
12 changes: 11 additions & 1 deletion acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {
acir_index: usize,
brillig_function_id: BrilligFunctionId,
profiling_active: bool,
pedantic_solving: bool,
) -> Result<Self, OpcodeResolutionError<F>> {
let vm = Self::setup_brillig_vm(
initial_witness,
Expand All @@ -75,6 +76,7 @@ impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {
brillig_bytecode,
bb_solver,
profiling_active,
pedantic_solving,
)?;
Ok(Self { vm, acir_index, function_id: brillig_function_id })
}
Expand All @@ -86,6 +88,7 @@ impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {
brillig_bytecode: &'b [BrilligOpcode<F>],
bb_solver: &'b B,
profiling_active: bool,
pedantic_solving: bool,
) -> Result<VM<'b, F, B>, OpcodeResolutionError<F>> {
// Set input values
let mut calldata: Vec<F> = Vec::new();
Expand Down Expand Up @@ -133,7 +136,14 @@ impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {

// Instantiate a Brillig VM given the solved calldata
// along with the Brillig bytecode.
let vm = VM::new(calldata, brillig_bytecode, vec![], bb_solver, profiling_active);
let vm = VM::new(
calldata,
brillig_bytecode,
vec![],
bb_solver,
profiling_active,
pedantic_solving,
);
Ok(vm)
}

Expand Down
29 changes: 24 additions & 5 deletions acvm-repo/acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ impl<F: AcirField> MemoryOpSolver<F> {

// Fetch whether or not the predicate is false (e.g. equal to zero)
let opcode_location = ErrorLocation::Unresolved;
let skip_operation = is_predicate_false(initial_witness, predicate, pedantic_solving, &opcode_location)?;
let skip_operation =
is_predicate_false(initial_witness, predicate, pedantic_solving, &opcode_location)?;

if is_read_operation {
// `value_read = arr[memory_index]`
Expand Down Expand Up @@ -152,7 +153,10 @@ mod tests {
block_solver.init(&init, &initial_witness).unwrap();

for op in trace {
block_solver.solve_memory_op(&op, &mut initial_witness, &None).unwrap();
let pedantic_solving = true;
block_solver
.solve_memory_op(&op, &mut initial_witness, &None, pedantic_solving)
.unwrap();
}

assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128));
Expand All @@ -177,7 +181,10 @@ mod tests {
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver.solve_memory_op(&op, &mut initial_witness, &None).err();
let pedantic_solving = true;
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &None, pedantic_solving)
.err();
}
}

Expand Down Expand Up @@ -210,8 +217,14 @@ mod tests {
let mut err = None;
for op in invalid_trace {
if err.is_none() {
let pedantic_solving = true;
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
.solve_memory_op(
&op,
&mut initial_witness,
&Some(Expression::zero()),
pedantic_solving,
)
.err();
}
}
Expand Down Expand Up @@ -242,8 +255,14 @@ mod tests {
let mut err = None;
for op in invalid_trace {
if err.is_none() {
let pedantic_solving = true;
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
.solve_memory_op(
&op,
&mut initial_witness,
&Some(Expression::zero()),
pedantic_solving,
)
.err();
}
}
Expand Down
Loading

0 comments on commit f65a0b0

Please sign in to comment.