Skip to content

Commit

Permalink
chore: let Function::inlined take a should_inline_call function (#…
Browse files Browse the repository at this point in the history
…7149)

Co-authored-by: Tom French <[email protected]>
  • Loading branch information
asterite and TomAFrench authored Jan 23, 2025
1 parent dce2c7d commit 2da0a4f
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 65 deletions.
154 changes: 92 additions & 62 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,27 @@ impl Ssa {
let inline_targets =
inline_infos.iter().filter_map(|(id, info)| info.is_inline_target().then_some(*id));

let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(_) => {
// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && callee.is_no_predicates();
!preserve_function
}
RuntimeType::Brillig(_) => {
// We inline inline if the function called wasn't ruled out as too costly or recursive.
InlineInfo::should_inline(inline_infos, callee.id())
}
}
};

// NOTE: Functions are processed independently of each other, with the final mapping replacing the original,
// instead of inlining the "leaf" functions, moving up towards the entry point.
self.functions = btree_map(inline_targets, |entry_point| {
let function = &self.functions[&entry_point];
let new_function =
function.inlined(&self, inline_no_predicates_functions, inline_infos);
let new_function = function.inlined(&self, &should_inline_call);
(entry_point, new_function)
});
self
Expand All @@ -81,46 +96,8 @@ impl Function {
pub(super) fn inlined(
&self,
ssa: &Ssa,
inline_no_predicates_functions: bool,
inline_infos: &InlineInfos,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Function {
let caller_runtime = self.runtime();

let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if called_func_id == self.id() {
return false;
}
let callee = &ssa.functions[&called_func_id];

match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && callee.is_no_predicates();

!inline_type.is_entry_point() && !preserve_function
}
RuntimeType::Brillig(_) => {
if caller_runtime.is_acir() {
// We never inline a brillig function into an ACIR function.
return false;
}
// We inline inline if the function called wasn't ruled out as too costly or recursive.
inline_infos
.get(&called_func_id)
.map(|info| info.should_inline)
.unwrap_or_default()
}
}
};

InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call)
}
}
Expand All @@ -146,6 +123,9 @@ struct InlineContext {
/// inline into. The same goes for ValueIds, InstructionIds, and for storing other data like
/// parameter to argument mappings.
struct PerFunctionContext<'function> {
/// The function that we are inlining calls into.
entry_function: &'function Function,

/// The source function is the function we're currently inlining into the function being built.
source_function: &'function Function,

Expand Down Expand Up @@ -218,6 +198,10 @@ impl InlineInfo {
|| self.is_recursive
|| !self.should_inline
}

pub(super) fn should_inline(inline_infos: &InlineInfos, called_func_id: FunctionId) -> bool {
inline_infos.get(&called_func_id).map(|info| info.should_inline).unwrap_or_default()
}
}

type InlineInfos = BTreeMap<FunctionId, InlineInfo>;
Expand Down Expand Up @@ -519,7 +503,7 @@ fn mark_brillig_functions_to_retain(
inline_no_predicates_functions: bool,
aggressiveness: i64,
times_called: &HashMap<FunctionId, usize>,
inline_infos: &mut BTreeMap<FunctionId, InlineInfo>,
inline_infos: &mut InlineInfos,
) {
let brillig_entry_points = inline_infos
.iter()
Expand Down Expand Up @@ -574,11 +558,12 @@ impl InlineContext {
fn inline_all(
mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Function {
let entry_point = &ssa.functions[&self.entry_point];

let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals);
let mut context =
PerFunctionContext::new(&mut self, entry_point, entry_point, &ssa.globals);
context.inlining_entry = true;

for (_, value) in entry_point.dfg.globals.values_iter() {
Expand Down Expand Up @@ -617,7 +602,7 @@ impl InlineContext {
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Vec<ValueId> {
self.recursion_level += 1;

Expand All @@ -629,7 +614,8 @@ impl InlineContext {
);
}

let mut context = PerFunctionContext::new(self, source_function, &ssa.globals);
let entry_point = &ssa.functions[&self.entry_point];
let mut context = PerFunctionContext::new(self, entry_point, source_function, &ssa.globals);

let parameters = source_function.parameters();
assert_eq!(parameters.len(), arguments.len());
Expand All @@ -651,11 +637,13 @@ impl<'function> PerFunctionContext<'function> {
/// the arguments of the destination function.
fn new(
context: &'function mut InlineContext,
entry_function: &'function Function,
source_function: &'function Function,
globals: &'function Function,
) -> Self {
Self {
context,
entry_function,
source_function,
blocks: HashMap::default(),
values: HashMap::default(),
Expand Down Expand Up @@ -777,7 +765,7 @@ impl<'function> PerFunctionContext<'function> {
fn inline_blocks(
&mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = VecDeque::new();
Expand Down Expand Up @@ -844,7 +832,7 @@ impl<'function> PerFunctionContext<'function> {
&mut self,
ssa: &Ssa,
block_id: BasicBlockId,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) {
let mut side_effects_enabled: Option<ValueId> = None;

Expand All @@ -853,19 +841,29 @@ impl<'function> PerFunctionContext<'function> {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
if should_inline_call(self, ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments, should_inline_call);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
if let Some(callee) = self.should_inline_call(ssa, func_id) {
if should_inline_call(callee) {
self.inline_function(
ssa,
*id,
func_id,
arguments,
should_inline_call,
);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
}
} else {
self.push_instruction(*id);
}
} else {
self.push_instruction(*id);
Expand All @@ -882,14 +880,46 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call<'a>(
&self,
ssa: &'a Ssa,
called_func_id: FunctionId,
) -> Option<&'a Function> {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if self.entry_function.id() == called_func_id {
return None;
}

let callee = &ssa.functions[&called_func_id];

match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point
if inline_type.is_entry_point() {
return None;
}
}
RuntimeType::Brillig(_) => {
if self.entry_function.runtime().is_acir() {
// We never inline a brillig function into an ACIR function.
return None;
}
}
}

Some(callee)
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
ssa: &Ssa,
call_id: InstructionId,
function: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
Expand Down
22 changes: 19 additions & 3 deletions compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
//! Pre-process functions before inlining them into others.
use crate::ssa::Ssa;
use crate::ssa::{
ir::function::{Function, RuntimeType},
Ssa,
};

use super::inlining;
use super::inlining::{self, InlineInfo};

impl Ssa {
/// Run pre-processing steps on functions in isolation.
Expand All @@ -19,6 +22,19 @@ impl Ssa {
// Preliminary inlining decisions.
let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness);

let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(_) => {
// Functions marked to not have predicates should be preserved.
!callee.is_no_predicates()
}
RuntimeType::Brillig(_) => {
// We inline inline if the function called wasn't ruled out as too costly or recursive.
InlineInfo::should_inline(&inline_infos, callee.id())
}
}
};

for (id, (own_weight, transitive_weight)) in bottom_up {
// Skip preprocessing heavy functions that gained most of their weight from transitive accumulation.
// These can be processed later by the regular SSA passes.
Expand All @@ -34,7 +50,7 @@ impl Ssa {
}
let function = &self.functions[&id];
// Start with an inline pass.
let mut function = function.inlined(&self, false, &inline_infos);
let mut function = function.inlined(&self, &should_inline_call);
// Help unrolling determine bounds.
function.as_slice_optimization();
// Prepare for unrolling
Expand Down

0 comments on commit 2da0a4f

Please sign in to comment.