From 17b7160a64f3e917d29083292b6b6476b1810409 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Wed, 15 Jan 2025 23:51:30 +0000 Subject: [PATCH] chore: allow passing custom conditions to inlining pass --- .../noirc_evaluator/src/ssa/opt/inlining.rs | 124 +++++++++--------- 1 file changed, 59 insertions(+), 65 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 8e0614d15de..b5cbc90e30d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -46,29 +46,48 @@ impl Ssa { /// This step should run after runtime separation, since it relies on the runtime of the called functions being final. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn inline_functions(self, aggressiveness: i64) -> Ssa { - Self::inline_functions_inner(self, aggressiveness, false) + let inline_sources = get_functions_to_inline_into(&self, false, aggressiveness); + Self::inline_functions_inner(self, &inline_sources, false) } // Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa { - Self::inline_functions_inner(self, aggressiveness, true) + let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness); + Self::inline_functions_inner(self, &inline_sources, true) } fn inline_functions_inner( mut self, - aggressiveness: i64, + inline_sources: &BTreeSet, inline_no_predicates_functions: bool, ) -> Ssa { - let inline_sources = - get_functions_to_inline_into(&self, inline_no_predicates_functions, aggressiveness); - self.functions = btree_map(&inline_sources, |entry_point| { - let new_function = InlineContext::new( - &self, - *entry_point, - inline_no_predicates_functions, - inline_sources.clone(), - ) - .inline_all(&self); + // Note that we clear all functions other than those in `inline_sources`. + // If we decide to do partial inlining then we should change this to preserve those functions which still exist. + self.functions = btree_map(inline_sources, |entry_point| { + let should_inline_call = + |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { + let function = &ssa.functions[&called_func_id]; + + match function.runtime() { + RuntimeType::Acir(inline_type) => { + // If the called function is acir, we inline if it's not an entry point + + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be preserved. + let preserve_function = + !inline_no_predicates_functions && function.is_no_predicates(); + !inline_type.is_entry_point() && !preserve_function + } + RuntimeType::Brillig(_) => { + // If the called function is brillig, we inline only if it's into brillig and the function is not recursive + ssa.functions[entry_point].runtime().is_brillig() + && !inline_sources.contains(&called_func_id) + } + } + }; + + let new_function = + InlineContext::new(&self, *entry_point).inline_all(&self, &should_inline_call); (*entry_point, new_function) }); self @@ -88,16 +107,6 @@ struct InlineContext { // The FunctionId of the entry point function we're inlining into in the old, unmodified Ssa. entry_point: FunctionId, - - /// Whether the inlining pass should inline any functions marked with [`InlineType::NoPredicates`] - /// or whether these should be preserved as entrypoint functions. - /// - /// This is done as we delay inlining of functions with the attribute `#[no_predicates]` until after - /// the control flow graph has been flattened. - inline_no_predicates_functions: bool, - - // These are the functions of the program that we shouldn't inline. - functions_not_to_inline: BTreeSet, } /// The per-function inlining context contains information that is only valid for one function. @@ -355,32 +364,23 @@ impl InlineContext { /// The function being inlined into will always be the main function, although it is /// actually a copy that is created in case the original main is still needed from a function /// that could not be inlined calling it. - fn new( - ssa: &Ssa, - entry_point: FunctionId, - inline_no_predicates_functions: bool, - functions_not_to_inline: BTreeSet, - ) -> Self { + fn new(ssa: &Ssa, entry_point: FunctionId) -> Self { let source = &ssa.functions[&entry_point]; let mut builder = FunctionBuilder::new(source.name().to_owned(), entry_point); builder.set_runtime(source.runtime()); builder.current_function.set_globals(source.dfg.globals.clone()); - Self { - builder, - recursion_level: 0, - entry_point, - call_stack: CallStackId::root(), - inline_no_predicates_functions, - functions_not_to_inline, - } + Self { builder, recursion_level: 0, entry_point, call_stack: CallStackId::root() } } /// Start inlining the entry point function and all functions reachable from it. - fn inline_all(mut self, ssa: &Ssa) -> Function { + fn inline_all( + mut self, + ssa: &Ssa, + should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + ) -> Function { let entry_point = &ssa.functions[&self.entry_point]; - // let globals = self.globals; let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals); context.inlining_entry = true; @@ -401,7 +401,7 @@ impl InlineContext { } context.blocks.insert(context.source_function.entry_block(), entry_block); - context.inline_blocks(ssa); + context.inline_blocks(ssa, should_inline_call); // translate databus values let databus = entry_point.dfg.data_bus.map_values(|t| context.translate_value(t)); @@ -420,6 +420,7 @@ impl InlineContext { ssa: &Ssa, id: FunctionId, arguments: &[ValueId], + should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, ) -> Vec { self.recursion_level += 1; @@ -440,7 +441,7 @@ impl InlineContext { let current_block = context.context.builder.current_block(); context.blocks.insert(source_function.entry_block(), current_block); - let return_values = context.inline_blocks(ssa); + let return_values = context.inline_blocks(ssa, should_inline_call); self.recursion_level -= 1; return_values } @@ -568,7 +569,11 @@ impl<'function> PerFunctionContext<'function> { } /// Inline all reachable blocks within the source_function into the destination function. - fn inline_blocks(&mut self, ssa: &Ssa) -> Vec { + fn inline_blocks( + &mut self, + ssa: &Ssa, + should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + ) -> Vec { let mut seen_blocks = HashSet::new(); let mut block_queue = VecDeque::new(); block_queue.push_back(self.source_function.entry_block()); @@ -585,7 +590,7 @@ impl<'function> PerFunctionContext<'function> { self.context.builder.switch_to_block(translated_block_id); seen_blocks.insert(source_block_id); - self.inline_block_instructions(ssa, source_block_id); + self.inline_block_instructions(ssa, source_block_id, should_inline_call); if let Some((block, values)) = self.handle_terminator_instruction(source_block_id, &mut block_queue) @@ -630,7 +635,12 @@ impl<'function> PerFunctionContext<'function> { /// Inline each instruction in the given block into the function being inlined into. /// This may recurse if it finds another function to inline if a call instruction is within this block. - fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) { + fn inline_block_instructions( + &mut self, + ssa: &Ssa, + block_id: BasicBlockId, + should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + ) { let mut side_effects_enabled: Option = None; let block = &self.source_function.dfg[block_id]; @@ -638,8 +648,8 @@ impl<'function> PerFunctionContext<'function> { match &self.source_function.dfg[*id] { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(func_id) => { - if self.should_inline_call(ssa, func_id) { - self.inline_function(ssa, *id, func_id, arguments); + if should_inline_call(self, ssa, func_id) { + self.inline_function(ssa, *id, func_id, arguments, should_inline_call); // This is only relevant during handling functions with `InlineType::NoPredicates` as these // can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`, @@ -667,24 +677,6 @@ impl<'function> PerFunctionContext<'function> { } } - fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool { - let function = &ssa.functions[&called_func_id]; - - if let RuntimeType::Acir(inline_type) = function.runtime() { - // If the called function is acir, we inline if it's not an entry point - - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be preserved. - let preserve_function = - !self.context.inline_no_predicates_functions && function.is_no_predicates(); - !inline_type.is_entry_point() && !preserve_function - } else { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive - matches!(ssa.functions[&self.context.entry_point].runtime(), RuntimeType::Brillig(_)) - && !self.context.functions_not_to_inline.contains(&called_func_id) - } - } - /// Inline a function call and remember the inlined return values in the values map fn inline_function( &mut self, @@ -692,6 +684,7 @@ impl<'function> PerFunctionContext<'function> { call_id: InstructionId, function: FunctionId, arguments: &[ValueId], + should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, ) { let old_results = self.source_function.dfg.instruction_results(call_id); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); @@ -707,7 +700,8 @@ impl<'function> PerFunctionContext<'function> { .extend_call_stack(self.context.call_stack, &call_stack); self.context.call_stack = new_call_stack; - let new_results = self.context.inline_function(ssa, function, &arguments); + let new_results = + self.context.inline_function(ssa, function, &arguments, should_inline_call); self.context.call_stack = self .context .builder