Skip to content

Commit

Permalink
fix: Keep inc_rc for array inputs during preprocessing (#7163)
Browse files Browse the repository at this point in the history
Co-authored-by: Maxim Vezenov <[email protected]>
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2025
1 parent ff55a77 commit 29d2d8a
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 32 deletions.
71 changes: 57 additions & 14 deletions compiler/noirc_evaluator/src/ssa/opt/die.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,16 @@ impl Context {
let block = &function.dfg[block_id];
self.mark_terminator_values_as_used(function, block);

let instructions_len = block.instructions().len();

let mut rc_tracker = RcTracker::default();
rc_tracker.mark_terminator_arrays_as_used(function, block);

let instructions_len = block.instructions().len();

// Indexes of instructions that might be out of bounds.
// We'll remove those, but before that we'll insert bounds checks for them.
let mut possible_index_out_of_bounds_indexes = Vec::new();

// Going in reverse so we know if a result of an instruction was used.
for (instruction_index, instruction_id) in block.instructions().iter().rev().enumerate() {
let instruction = &function.dfg[*instruction_id];

Expand Down Expand Up @@ -241,6 +243,8 @@ impl Context {
}
}

/// Go through the RC instructions collected when we figured out which values were unused;
/// for each RC that refers to an unused value, remove the RC as well.
fn remove_rc_instructions(&self, dfg: &mut DataFlowGraph) {
let unused_rc_values_by_block: HashMap<BasicBlockId, HashSet<InstructionId>> =
self.rc_instructions.iter().fold(HashMap::default(), |mut acc, (rc, block)| {
Expand Down Expand Up @@ -580,10 +584,12 @@ struct RcTracker {
// with the same value but no array set in between.
// If we see an inc/dec RC pair within a block we can safely remove both instructions.
rcs_with_possible_pairs: HashMap<Type, Vec<RcInstruction>>,
// Tracks repeated RC instructions: if there are two `inc_rc` for the same value in a row, the 2nd one is redundant.
rc_pairs_to_remove: HashSet<InstructionId>,
// We also separately track all IncrementRc instructions and all array types which have been mutably borrowed.
// If an array is the same type as one of those non-mutated array types, we can safely remove all IncrementRc instructions on that array.
inc_rcs: HashMap<ValueId, HashSet<InstructionId>>,
// When tracking mutations we consider arrays with the same type as all being possibly mutated.
mutated_array_types: HashSet<Type>,
// The SSA often creates patterns where after simplifications we end up with repeat
// IncrementRc instructions on the same value. We track whether the previous instruction was an IncrementRc,
Expand All @@ -593,9 +599,19 @@ struct RcTracker {
}

impl RcTracker {
/// Record the type of every array or slice value referenced by the block's
/// terminator in `mutated_array_types`. Values of these types are treated as
/// potentially mutated, which conservatively keeps their `inc_rc` instructions
/// from being removed (e.g. an array returned to a caller may be mutated there).
fn mark_terminator_arrays_as_used(&mut self, function: &Function, block: &BasicBlock) {
    block.unwrap_terminator().for_each_value(|value| {
        let value_type = function.dfg.type_of_value(value);
        match value_type {
            // Only reference-counted kinds matter here; scalars are ignored.
            Type::Array(_, _) | Type::Slice(_) => {
                self.mutated_array_types.insert(value_type);
            }
            _ => (),
        }
    });
}

fn track_inc_rcs_to_remove(&mut self, instruction_id: InstructionId, function: &Function) {
let instruction = &function.dfg[instruction_id];

// Deduplicate IncRC instructions.
if let Instruction::IncrementRc { value } = instruction {
if let Some(previous_value) = self.previous_inc_rc {
if previous_value == *value {
Expand All @@ -604,13 +620,16 @@ impl RcTracker {
}
self.previous_inc_rc = Some(*value);
} else {
// Reset the deduplication.
self.previous_inc_rc = None;
}

// DIE loops over a block in reverse order, so we insert an RC instruction for possible removal
// when we see a DecrementRc and check whether it was possibly mutated when we see an IncrementRc.
match instruction {
Instruction::IncrementRc { value } => {
// Get any RC instruction recorded further down the block for this array;
// if it exists and not marked as mutated, then both RCs can be removed.
if let Some(inc_rc) =
pop_rc_for(*value, function, &mut self.rcs_with_possible_pairs)
{
Expand All @@ -619,7 +638,7 @@ impl RcTracker {
self.rc_pairs_to_remove.insert(instruction_id);
}
}

// Remember that this array was RC'd by this instruction.
self.inc_rcs.entry(*value).or_default().insert(instruction_id);
}
Instruction::DecrementRc { value } => {
Expand All @@ -632,12 +651,12 @@ impl RcTracker {
}
Instruction::ArraySet { array, .. } => {
let typ = function.dfg.type_of_value(*array);
// We mark all RCs that refer to arrays with a matching type as the one being set, as possibly mutated.
if let Some(dec_rcs) = self.rcs_with_possible_pairs.get_mut(&typ) {
for dec_rc in dec_rcs {
dec_rc.possibly_mutated = true;
}
}

self.mutated_array_types.insert(typ);
}
Instruction::Store { value, .. } => {
Expand All @@ -648,6 +667,9 @@ impl RcTracker {
}
}
Instruction::Call { arguments, .. } => {
// Treat any array-type arguments to calls as possible sources of mutation.
// During the preprocessing of functions in isolation we don't want to
// get rid of IncRCs arrays that can potentially be mutated outside.
for arg in arguments {
let typ = function.dfg.type_of_value(*arg);
if matches!(&typ, Type::Array(..) | Type::Slice(..)) {
Expand All @@ -659,6 +681,7 @@ impl RcTracker {
}
}

/// Get all RC instructions which work on arrays whose type has not been marked as mutated.
fn get_non_mutated_arrays(&self, dfg: &DataFlowGraph) -> HashSet<InstructionId> {
self.inc_rcs
.keys()
Expand Down Expand Up @@ -857,16 +880,6 @@ mod test {

#[test]
fn keep_inc_rc_on_borrowed_array_set() {
// brillig(inline) fn main f0 {
// b0(v0: [u32; 2]):
// inc_rc v0
// v3 = array_set v0, index u32 0, value u32 1
// inc_rc v0
// inc_rc v0
// inc_rc v0
// v4 = array_get v3, index u32 1
// return v4
// }
let src = "
brillig(inline) fn main f0 {
b0(v0: [u32; 2]):
Expand Down Expand Up @@ -951,6 +964,36 @@ mod test {
assert_normalized_ssa_equals(ssa, expected);
}

// Regression test for the fix in this commit: `inc_rc` instructions on an array
// that is returned via the block terminator must survive dead-instruction
// elimination (DIE), because the terminator marks the array's type as possibly
// mutated outside the function.
#[test]
fn do_not_remove_inc_rcs_for_arrays_in_terminator() {
// Input SSA: four `inc_rc v0` instructions, with v0 returned by the terminator.
let src = "
brillig(inline) fn main f0 {
b0(v0: [Field; 2]):
inc_rc v0
inc_rc v0
inc_rc v0
v2 = array_get v0, index u32 0 -> Field
inc_rc v0
return v0, v2
}
";

let ssa = Ssa::from_str(src).unwrap();

// Expected: back-to-back duplicate `inc_rc`s are deduplicated, but one
// `inc_rc` before the array_get and one after are kept, since v0 appears
// in the terminator and so its type counts as possibly mutated.
let expected = "
brillig(inline) fn main f0 {
b0(v0: [Field; 2]):
inc_rc v0
v2 = array_get v0, index u32 0 -> Field
inc_rc v0
return v0, v2
}
";

let ssa = ssa.dead_instruction_elimination();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn do_not_remove_inc_rc_if_used_as_call_arg() {
// We do not want to remove inc_rc instructions on values
Expand Down
26 changes: 11 additions & 15 deletions compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ impl Ssa {
// Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them.
let bottom_up = inlining::compute_bottom_up_order(&self);

// As a heuristic to avoid optimizing functions near the entry point, find a cutoff weight.
let total_weight =
bottom_up.iter().fold(0usize, |acc, (_, (_, w))| (acc.saturating_add(*w)));
let mean_weight = total_weight / bottom_up.len();
let cutoff_weight = mean_weight;

// Preliminary inlining decisions.
let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness);

Expand All @@ -36,19 +30,21 @@ impl Ssa {
};

for (id, (own_weight, transitive_weight)) in bottom_up {
// Skip preprocessing heavy functions that gained most of their weight from transitive accumulation.
let function = &self.functions[&id];

// Skip preprocessing heavy functions that gained most of their weight from transitive accumulation, which tend to be near the entry.
// These can be processed later by the regular SSA passes.
if transitive_weight >= cutoff_weight && transitive_weight > own_weight * 2 {
continue;
}
let is_heavy = transitive_weight > own_weight * 10;

// Functions which are inline targets will be processed in later passes.
// Here we want to treat the functions which will be inlined into them.
if let Some(info) = inline_infos.get(&id) {
if info.is_inline_target() {
continue;
}
let is_target =
inline_infos.get(&id).map(|info| info.is_inline_target()).unwrap_or_default();

if is_heavy || is_target {
continue;
}
let function = &self.functions[&id];

// Start with an inline pass.
let mut function = function.inlined(&self, &should_inline_call);
// Help unrolling determine bounds.
Expand Down
7 changes: 7 additions & 0 deletions test_programs/execution_success/regression_11294/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_11294"
version = "0.1.0"
type = "bin"
authors = [""]

[dependencies]
47 changes: 47 additions & 0 deletions test_programs/execution_success/regression_11294/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0c78b411fc893c51d446c08daa5741b9ba6103126c9e450bed90fcde8793168a"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000002"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000007"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
Loading

0 comments on commit 29d2d8a

Please sign in to comment.