Skip to content

Commit

Permalink
fix: set index and value to 0 for array_get with predicate (#4971)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves issue on noir-protocol-circuit and should also resolve issues
described in #4716.
The issue it resolves is that array_get under a predicate are done at
index 0 if the predicate is false, but since arrays are not homogenous,
it may contain some value that overflow the type of the array_get.

## Summary\*
When getting value from an array_get, I multiply it with the predicate
to avoid any possible overflow if element at index 0 has not the same
type as the expected one.

## Additional Context
If the array is simple (not nested), I get the offset to a compatible
type and use it when computing the predicate_index.
If the first element of the array has compatible size, then the
predicate_index will be correct
If not, we fallback to the multiplication of the value with the
predicate.


## Documentation\*

Check one:
- [X] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <[email protected]>
  • Loading branch information
guipublic and TomAFrench authored May 7, 2024
1 parent f3f1150 commit c49d3a9
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 8 deletions.
69 changes: 61 additions & 8 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -957,13 +957,39 @@ impl<'a> Context<'a> {
return Ok(());
}

let (new_index, new_value) =
self.convert_array_operation_inputs(array, dfg, index, store_value)?;
// Get an offset such that the type of the array at the offset is the same as the type at the 'index'
// If we find one, we will use it when computing the index under the enable_side_effect predicate
// If not, array_get(..) will use a fallback costing one multiplication in the worst case.
// cf. https://github.com/noir-lang/noir/pull/4971
let array_id = dfg.resolve(array);
let array_typ = dfg.type_of_value(array_id);
// For simplicity we compute the offset only for simple arrays
let is_simple_array = dfg.instruction_results(instruction).len() == 1
&& can_omit_element_sizes_array(&array_typ);
let offset = if is_simple_array {
let result_type = dfg.type_of_value(dfg.instruction_results(instruction)[0]);
match array_typ {
Type::Array(item_type, _) | Type::Slice(item_type) => item_type
.iter()
.enumerate()
.find_map(|(index, typ)| (result_type == *typ).then_some(index)),
_ => None,
}
} else {
None
};
let (new_index, new_value) = self.convert_array_operation_inputs(
array,
dfg,
index,
store_value,
offset.unwrap_or_default(),
)?;

if let Some(new_value) = new_value {
self.array_set(instruction, new_index, new_value, dfg, mutable_array_set)?;
} else {
self.array_get(instruction, array, new_index, dfg)?;
self.array_get(instruction, array, new_index, dfg, offset.is_none())?;
}

Ok(())
Expand Down Expand Up @@ -1053,7 +1079,7 @@ impl<'a> Context<'a> {
/// - new_index is the index of the array. ACIR memory operations work with a flat memory, so we fully flattened the specified index
/// in case we have a nested array. The index for SSA array operations only represents the flattened index of the current array.
/// Thus internal array element type sizes need to be computed to accurately transform the index.
/// - predicate_index is 0, or the index if the predicate is true
/// - predicate_index is offset, or the index if the predicate is true
/// - new_value is the optional value when the operation is an array_set
/// When there is a predicate, it is predicate*value + (1-predicate)*dummy, where dummy is the value of the array at the requested index.
/// It is a dummy value because in the case of a false predicate, the value stored at the requested index will be itself.
Expand All @@ -1063,14 +1089,18 @@ impl<'a> Context<'a> {
dfg: &DataFlowGraph,
index: ValueId,
store_value: Option<ValueId>,
offset: usize,
) -> Result<(AcirVar, Option<AcirValue>), RuntimeError> {
let (array_id, array_typ, block_id) = self.check_array_is_initialized(array, dfg)?;

let index_var = self.convert_numeric_value(index, dfg)?;
let index_var = self.get_flattened_index(&array_typ, array_id, index_var, dfg)?;

let predicate_index =
self.acir_context.mul_var(index_var, self.current_side_effects_enabled_var)?;
// predicate_index = index*predicate + (1-predicate)*offset
let offset = self.acir_context.add_constant(offset);
let sub = self.acir_context.sub_var(index_var, offset)?;
let pred = self.acir_context.mul_var(sub, self.current_side_effects_enabled_var)?;
let predicate_index = self.acir_context.add_var(pred, offset)?;

let new_value = if let Some(store) = store_value {
let store_value = self.convert_value(store, dfg);
Expand Down Expand Up @@ -1171,12 +1201,14 @@ impl<'a> Context<'a> {
}

/// Generates a read opcode for the array
/// `index_side_effect == false` means that we ensured `var_index` will have a type matching the value in the array
fn array_get(
&mut self,
instruction: InstructionId,
array: ValueId,
mut var_index: AcirVar,
dfg: &DataFlowGraph,
mut index_side_effect: bool,
) -> Result<AcirValue, RuntimeError> {
let (array_id, _, block_id) = self.check_array_is_initialized(array, dfg)?;
let results = dfg.instruction_results(instruction);
Expand All @@ -1195,7 +1227,7 @@ impl<'a> Context<'a> {
self.data_bus.call_data_map[&array_id] as i128,
));
let new_index = self.acir_context.add_var(offset, bus_index)?;
return self.array_get(instruction, call_data, new_index, dfg);
return self.array_get(instruction, call_data, new_index, dfg, index_side_effect);
}
}

Expand All @@ -1204,7 +1236,28 @@ impl<'a> Context<'a> {
!res_typ.contains_slice_element(),
"ICE: Nested slice result found during ACIR generation"
);
let value = self.array_get_value(&res_typ, block_id, &mut var_index)?;
let mut value = self.array_get_value(&res_typ, block_id, &mut var_index)?;

if let AcirValue::Var(value_var, typ) = &value {
let array_id = dfg.resolve(array_id);
let array_typ = dfg.type_of_value(array_id);
if let (Type::Numeric(numeric_type), AcirType::NumericType(num)) =
(array_typ.first(), typ)
{
if numeric_type.bit_size() <= num.bit_size() {
// first element is compatible
index_side_effect = false;
}
}
// Fallback to multiplication if the index side_effects have not already been handled
if index_side_effect {
// Set the value to 0 if current_side_effects is 0, to ensure it fits in any value type
value = AcirValue::Var(
self.acir_context.mul_var(*value_var, self.current_side_effects_enabled_var)?,
typ.clone(),
);
}
}

self.define_result(dfg, instruction, value.clone());

Expand Down
8 changes: 8 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ impl Type {
other => panic!("element_types: Expected array or slice, found {other}"),
}
}

pub(crate) fn first(&self) -> Type {
match self {
Type::Numeric(_) | Type::Function => self.clone(),
Type::Reference(typ) => typ.first(),
Type::Slice(element_types) | Type::Array(element_types, _) => element_types[0].first(),
}
}
}

/// Composite Types are essentially flattened struct or tuple types.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_struct_array_conditional"
version = "0.1.0"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
y = 1
z = 1

[[x]]
value = "0x23de33be058ce5504e1ade738db8bdacfe268fa9dbde777092bf1d38519bdf59"
counter = "10"
dummy = "0"

[[x]]
value = "3"
counter = "2"
dummy = "0"

[[x]]
value = "2"
counter = "0"
dummy = "0"

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
struct foo {
value: Field,
counter: u8,
dummy: u8,
}
struct bar {
dummy: [u8;3],
value: Field,
counter: u8,
}
struct bar_field {
dummy: [Field;3],
value: Field,
counter: u8,
}
fn main(x: [foo; 3], y: u32, z: u32) -> pub u8 {
let a = [y, z, x[y].counter as u32];
let mut b = [bar { value: 0, counter: 0, dummy: [0; 3] }; 3];
let mut c = [bar_field { value: 0, counter: 0, dummy: [0; 3] }; 3];
for i in 0..3 {
b[i].value = x[i].value;
b[i].counter = x[i].counter;
b[i].dummy[0] = x[i].dummy;
c[i].value = x[i].value;
c[i].counter = x[i].counter;
c[i].dummy[0] = x[i].dummy as Field;
}
if z == 0 {
// offset
assert(y as u8 < x[y].counter);
assert(y <= a[y]);
// first element is compatible
assert(y as u8 < b[y].counter);
// fallback
assert(y as u8 < c[y].counter);
}
x[0].counter
}

0 comments on commit c49d3a9

Please sign in to comment.