diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index c3b771d910..be40957fc2 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -64,12 +64,27 @@ impl Ssa { let inline_targets = inline_infos.iter().filter_map(|(id, info)| info.is_inline_target().then_some(*id)); + let should_inline_call = |callee: &Function| -> bool { + match callee.runtime() { + RuntimeType::Acir(_) => { + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be preserved. + let preserve_function = + !inline_no_predicates_functions && callee.is_no_predicates(); + !preserve_function + } + RuntimeType::Brillig(_) => { + // We inline inline if the function called wasn't ruled out as too costly or recursive. + InlineInfo::should_inline(inline_infos, callee.id()) + } + } + }; + // NOTE: Functions are processed independently of each other, with the final mapping replacing the original, // instead of inlining the "leaf" functions, moving up towards the entry point. self.functions = btree_map(inline_targets, |entry_point| { let function = &self.functions[&entry_point]; - let new_function = - function.inlined(&self, inline_no_predicates_functions, inline_infos); + let new_function = function.inlined(&self, &should_inline_call); (entry_point, new_function) }); self @@ -81,46 +96,8 @@ impl Function { pub(super) fn inlined( &self, ssa: &Ssa, - inline_no_predicates_functions: bool, - inline_infos: &InlineInfos, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Function { - let caller_runtime = self.runtime(); - - let should_inline_call = - |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { - // Do not inline self-recursive functions on the top level. - // Inlining a self-recursive function works when there is something to inline into - // by importing all the recursive blocks, but for the entry function there is no wrapper. - if called_func_id == self.id() { - return false; - } - let callee = &ssa.functions[&called_func_id]; - - match callee.runtime() { - RuntimeType::Acir(inline_type) => { - // If the called function is acir, we inline if it's not an entry point - - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be preserved. - let preserve_function = - !inline_no_predicates_functions && callee.is_no_predicates(); - - !inline_type.is_entry_point() && !preserve_function - } - RuntimeType::Brillig(_) => { - if caller_runtime.is_acir() { - // We never inline a brillig function into an ACIR function. - return false; - } - // We inline inline if the function called wasn't ruled out as too costly or recursive. - inline_infos - .get(&called_func_id) - .map(|info| info.should_inline) - .unwrap_or_default() - } - } - }; - InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call) } } @@ -146,6 +123,9 @@ struct InlineContext { /// inline into. The same goes for ValueIds, InstructionIds, and for storing other data like /// parameter to argument mappings. struct PerFunctionContext<'function> { + /// The function that we are inlining calls into. + entry_function: &'function Function, + /// The source function is the function we're currently inlining into the function being built. source_function: &'function Function, @@ -218,6 +198,10 @@ impl InlineInfo { || self.is_recursive || !self.should_inline } + + pub(super) fn should_inline(inline_infos: &InlineInfos, called_func_id: FunctionId) -> bool { + inline_infos.get(&called_func_id).map(|info| info.should_inline).unwrap_or_default() + } } type InlineInfos = BTreeMap; @@ -519,7 +503,7 @@ fn mark_brillig_functions_to_retain( inline_no_predicates_functions: bool, aggressiveness: i64, times_called: &HashMap, - inline_infos: &mut BTreeMap, + inline_infos: &mut InlineInfos, ) { let brillig_entry_points = inline_infos .iter() @@ -574,11 +558,12 @@ impl InlineContext { fn inline_all( mut self, ssa: &Ssa, - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Function { let entry_point = &ssa.functions[&self.entry_point]; - let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals); + let mut context = + PerFunctionContext::new(&mut self, entry_point, entry_point, &ssa.globals); context.inlining_entry = true; for (_, value) in entry_point.dfg.globals.values_iter() { @@ -617,7 +602,7 @@ impl InlineContext { ssa: &Ssa, id: FunctionId, arguments: &[ValueId], - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Vec { self.recursion_level += 1; @@ -629,7 +614,8 @@ impl InlineContext { ); } - let mut context = PerFunctionContext::new(self, source_function, &ssa.globals); + let entry_point = &ssa.functions[&self.entry_point]; + let mut context = PerFunctionContext::new(self, entry_point, source_function, &ssa.globals); let parameters = source_function.parameters(); assert_eq!(parameters.len(), arguments.len()); @@ -651,11 +637,13 @@ impl<'function> PerFunctionContext<'function> { /// the arguments of the destination function. fn new( context: &'function mut InlineContext, + entry_function: &'function Function, source_function: &'function Function, globals: &'function Function, ) -> Self { Self { context, + entry_function, source_function, blocks: HashMap::default(), values: HashMap::default(), @@ -777,7 +765,7 @@ impl<'function> PerFunctionContext<'function> { fn inline_blocks( &mut self, ssa: &Ssa, - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Vec { let mut seen_blocks = HashSet::new(); let mut block_queue = VecDeque::new(); @@ -844,7 +832,7 @@ impl<'function> PerFunctionContext<'function> { &mut self, ssa: &Ssa, block_id: BasicBlockId, - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) { let mut side_effects_enabled: Option = None; @@ -853,19 +841,29 @@ impl<'function> PerFunctionContext<'function> { match &self.source_function.dfg[*id] { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(func_id) => { - if should_inline_call(self, ssa, func_id) { - self.inline_function(ssa, *id, func_id, arguments, should_inline_call); - - // This is only relevant during handling functions with `InlineType::NoPredicates` as these - // can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`, - // resulting in predicates not being applied properly. - // - // Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects` - // within the function being inlined whilst the source function has not encountered one yet. - // In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the - // function being inlined will be to turn off predicates rather than to create one. - if let Some(condition) = side_effects_enabled { - self.context.builder.insert_enable_side_effects_if(condition); + if let Some(callee) = self.should_inline_call(ssa, func_id) { + if should_inline_call(callee) { + self.inline_function( + ssa, + *id, + func_id, + arguments, + should_inline_call, + ); + + // This is only relevant during handling functions with `InlineType::NoPredicates` as these + // can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`, + // resulting in predicates not being applied properly. + // + // Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects` + // within the function being inlined whilst the source function has not encountered one yet. + // In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the + // function being inlined will be to turn off predicates rather than to create one. + if let Some(condition) = side_effects_enabled { + self.context.builder.insert_enable_side_effects_if(condition); + } + } else { + self.push_instruction(*id); } } else { self.push_instruction(*id); @@ -882,6 +880,38 @@ impl<'function> PerFunctionContext<'function> { } } + fn should_inline_call<'a>( + &self, + ssa: &'a Ssa, + called_func_id: FunctionId, + ) -> Option<&'a Function> { + // Do not inline self-recursive functions on the top level. + // Inlining a self-recursive function works when there is something to inline into + // by importing all the recursive blocks, but for the entry function there is no wrapper. + if self.entry_function.id() == called_func_id { + return None; + } + + let callee = &ssa.functions[&called_func_id]; + + match callee.runtime() { + RuntimeType::Acir(inline_type) => { + // If the called function is acir, we inline if it's not an entry point + if inline_type.is_entry_point() { + return None; + } + } + RuntimeType::Brillig(_) => { + if self.entry_function.runtime().is_acir() { + // We never inline a brillig function into an ACIR function. + return None; + } + } + } + + Some(callee) + } + /// Inline a function call and remember the inlined return values in the values map fn inline_function( &mut self, @@ -889,7 +919,7 @@ impl<'function> PerFunctionContext<'function> { call_id: InstructionId, function: FunctionId, arguments: &[ValueId], - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) { let old_results = self.source_function.dfg.instruction_results(call_id); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs index 439c2da5a2..a2011eb5ec 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -1,8 +1,11 @@ //! Pre-process functions before inlining them into others. -use crate::ssa::Ssa; +use crate::ssa::{ + ir::function::{Function, RuntimeType}, + Ssa, +}; -use super::inlining; +use super::inlining::{self, InlineInfo}; impl Ssa { /// Run pre-processing steps on functions in isolation. @@ -19,6 +22,19 @@ impl Ssa { // Preliminary inlining decisions. let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness); + let should_inline_call = |callee: &Function| -> bool { + match callee.runtime() { + RuntimeType::Acir(_) => { + // Functions marked to not have predicates should be preserved. + !callee.is_no_predicates() + } + RuntimeType::Brillig(_) => { + // We inline inline if the function called wasn't ruled out as too costly or recursive. + InlineInfo::should_inline(&inline_infos, callee.id()) + } + } + }; + for (id, (own_weight, transitive_weight)) in bottom_up { // Skip preprocessing heavy functions that gained most of their weight from transitive accumulation. // These can be processed later by the regular SSA passes. @@ -34,7 +50,7 @@ impl Ssa { } let function = &self.functions[&id]; // Start with an inline pass. - let mut function = function.inlined(&self, false, &inline_infos); + let mut function = function.inlined(&self, &should_inline_call); // Help unrolling determine bounds. function.as_slice_optimization(); // Prepare for unrolling