Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: let Function::inlined take a should_inline_call function #7149

Merged
merged 9 commits into from
Jan 23, 2025
153 changes: 90 additions & 63 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,59 +68,35 @@ impl Ssa {
// instead of inlining the "leaf" functions, moving up towards the entry point.
self.functions = btree_map(inline_targets, |entry_point| {
let function = &self.functions[&entry_point];
let new_function =
function.inlined(&self, inline_no_predicates_functions, inline_infos);
(entry_point, new_function)
});
self
}
}

impl Function {
/// Create a new function which has the functions called by this one inlined into its body.
pub(super) fn inlined(
&self,
ssa: &Ssa,
inline_no_predicates_functions: bool,
inline_infos: &InlineInfos,
) -> Function {
let caller_runtime = self.runtime();

let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if called_func_id == self.id() {
return false;
}
let callee = &ssa.functions[&called_func_id];

let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

RuntimeType::Acir(_) => {
// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && callee.is_no_predicates();

!inline_type.is_entry_point() && !preserve_function
!preserve_function
}
RuntimeType::Brillig(_) => {
if caller_runtime.is_acir() {
// We never inline a brillig function into an ACIR function.
return false;
}
// We inline inline if the function called wasn't ruled out as too costly or recursive.
inline_infos
.get(&called_func_id)
.map(|info| info.should_inline)
.unwrap_or_default()
InlineInfo::should_inline(inline_infos, callee.id())
}
}
};
let new_function = function.inlined(&self, &should_inline_call);
(entry_point, new_function)
});
self
}
}

impl Function {
/// Create a new function which has the functions called by this one inlined into its body.
pub(super) fn inlined(
&self,
ssa: &Ssa,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Function {
InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call)
}
}
Expand All @@ -146,6 +122,9 @@ struct InlineContext {
/// inline into. The same goes for ValueIds, InstructionIds, and for storing other data like
/// parameter to argument mappings.
struct PerFunctionContext<'function> {
/// The function that we are inlining calls into.
entry_function: &'function Function,

/// The source function is the function we're currently inlining into the function being built.
source_function: &'function Function,

Expand Down Expand Up @@ -205,7 +184,7 @@ pub(super) struct InlineInfo {
is_brillig_entry_point: bool,
is_acir_entry_point: bool,
is_recursive: bool,
should_inline: bool,
pub(super) should_inline: bool,
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
weight: i64,
cost: i64,
}
Expand All @@ -218,6 +197,10 @@ impl InlineInfo {
|| self.is_recursive
|| !self.should_inline
}

pub(super) fn should_inline(inline_infos: &InlineInfos, called_func_id: FunctionId) -> bool {
inline_infos.get(&called_func_id).map(|info| info.should_inline).unwrap_or_default()
}
}

type InlineInfos = BTreeMap<FunctionId, InlineInfo>;
Expand Down Expand Up @@ -519,7 +502,7 @@ fn mark_brillig_functions_to_retain(
inline_no_predicates_functions: bool,
aggressiveness: i64,
times_called: &HashMap<FunctionId, usize>,
inline_infos: &mut BTreeMap<FunctionId, InlineInfo>,
inline_infos: &mut InlineInfos,
) {
let brillig_entry_points = inline_infos
.iter()
Expand Down Expand Up @@ -574,11 +557,12 @@ impl InlineContext {
fn inline_all(
mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Function {
let entry_point = &ssa.functions[&self.entry_point];

let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals);
let mut context =
PerFunctionContext::new(&mut self, entry_point, entry_point, &ssa.globals);
context.inlining_entry = true;

for (_, value) in entry_point.dfg.globals.values_iter() {
Expand Down Expand Up @@ -617,7 +601,7 @@ impl InlineContext {
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Vec<ValueId> {
self.recursion_level += 1;

Expand All @@ -629,7 +613,8 @@ impl InlineContext {
);
}

let mut context = PerFunctionContext::new(self, source_function, &ssa.globals);
let entry_point = &ssa.functions[&self.entry_point];
let mut context = PerFunctionContext::new(self, entry_point, source_function, &ssa.globals);

let parameters = source_function.parameters();
assert_eq!(parameters.len(), arguments.len());
Expand All @@ -651,11 +636,13 @@ impl<'function> PerFunctionContext<'function> {
/// the arguments of the destination function.
fn new(
context: &'function mut InlineContext,
entry_function: &'function Function,
source_function: &'function Function,
globals: &'function Function,
) -> Self {
Self {
context,
entry_function,
source_function,
blocks: HashMap::default(),
values: HashMap::default(),
Expand Down Expand Up @@ -777,7 +764,7 @@ impl<'function> PerFunctionContext<'function> {
fn inline_blocks(
&mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = VecDeque::new();
Expand Down Expand Up @@ -844,7 +831,7 @@ impl<'function> PerFunctionContext<'function> {
&mut self,
ssa: &Ssa,
block_id: BasicBlockId,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) {
let mut side_effects_enabled: Option<ValueId> = None;

Expand All @@ -853,19 +840,29 @@ impl<'function> PerFunctionContext<'function> {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
if should_inline_call(self, ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments, should_inline_call);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
if let Some(callee) = self.should_inline_call(ssa, func_id) {
if should_inline_call(callee) {
self.inline_function(
ssa,
*id,
func_id,
arguments,
should_inline_call,
);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
}
} else {
self.push_instruction(*id);
}
} else {
self.push_instruction(*id);
Expand All @@ -882,14 +879,44 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call<'a>(
&self,
ssa: &'a Ssa,
called_func_id: FunctionId,
) -> Option<&'a Function> {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if self.entry_function.id() == called_func_id {
return None;
}

let callee = &ssa.functions[&called_func_id];
let callee_runtime = callee.runtime();

// Wd never inline one runtime into another
if self.entry_function.runtime().is_acir() != callee_runtime.is_acir() {
return None;
}

if let RuntimeType::Acir(inline_type) = callee_runtime {
// If the called function is acir, we inline if it's not an entry point
if inline_type.is_entry_point() {
return None;
}
}

Some(callee)
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
ssa: &Ssa,
call_id: InstructionId,
function: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
Expand Down
22 changes: 19 additions & 3 deletions compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
//! Pre-process functions before inlining them into others.

use crate::ssa::Ssa;
use crate::ssa::{
ir::function::{Function, RuntimeType},
Ssa,
};

use super::inlining;
use super::inlining::{self, InlineInfo};

impl Ssa {
/// Run pre-processing steps on functions in isolation.
Expand Down Expand Up @@ -34,7 +37,20 @@ impl Ssa {
}
let function = &self.functions[&id];
// Start with an inline pass.
let mut function = function.inlined(&self, false, &inline_infos);
let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(_) => {
// Functions marked to not have predicates should be preserved.
!callee.is_no_predicates()
}
RuntimeType::Brillig(_) => {
// We inline inline if the function called wasn't ruled out as too costly or recursive.
InlineInfo::should_inline(&inline_infos, callee.id())
}
}
};

let mut function = function.inlined(&self, &should_inline_call);
// Help unrolling determine bounds.
function.as_slice_optimization();
// Prepare for unrolling
Expand Down
Loading