Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: allow passing custom conditions to inlining pass #7083

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 59 additions & 65 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,48 @@ impl Ssa {
/// This step should run after runtime separation, since it relies on the runtime of the called functions being final.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn inline_functions(self, aggressiveness: i64) -> Ssa {
Self::inline_functions_inner(self, aggressiveness, false)
let inline_sources = get_functions_to_inline_into(&self, false, aggressiveness);
Self::inline_functions_inner(self, &inline_sources, false)
}

// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points
pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa {
Self::inline_functions_inner(self, aggressiveness, true)
let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness);
Self::inline_functions_inner(self, &inline_sources, true)
}

fn inline_functions_inner(
mut self,
aggressiveness: i64,
inline_sources: &BTreeSet<FunctionId>,
inline_no_predicates_functions: bool,
) -> Ssa {
let inline_sources =
get_functions_to_inline_into(&self, inline_no_predicates_functions, aggressiveness);
self.functions = btree_map(&inline_sources, |entry_point| {
let new_function = InlineContext::new(
&self,
*entry_point,
inline_no_predicates_functions,
inline_sources.clone(),
)
.inline_all(&self);
// Note that we clear all functions other than those in `inline_sources`.
// If we decide to do partial inlining then we should change this to preserve those functions which still exist.
self.functions = btree_map(inline_sources, |entry_point| {
let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
let function = &ssa.functions[&called_func_id];

match function.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && function.is_no_predicates();
!inline_type.is_entry_point() && !preserve_function
}
RuntimeType::Brillig(_) => {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[entry_point].runtime().is_brillig()
&& !inline_sources.contains(&called_func_id)
}
}
};

let new_function =
InlineContext::new(&self, *entry_point).inline_all(&self, &should_inline_call);
(*entry_point, new_function)
});
self
Expand All @@ -88,16 +107,6 @@ struct InlineContext {

// The FunctionId of the entry point function we're inlining into in the old, unmodified Ssa.
entry_point: FunctionId,

/// Whether the inlining pass should inline any functions marked with [`InlineType::NoPredicates`]
/// or whether these should be preserved as entrypoint functions.
///
/// This is done as we delay inlining of functions with the attribute `#[no_predicates]` until after
/// the control flow graph has been flattened.
inline_no_predicates_functions: bool,

// These are the functions of the program that we shouldn't inline.
functions_not_to_inline: BTreeSet<FunctionId>,
}

/// The per-function inlining context contains information that is only valid for one function.
Expand Down Expand Up @@ -355,32 +364,23 @@ impl InlineContext {
/// The function being inlined into will always be the main function, although it is
/// actually a copy that is created in case the original main is still needed from a function
/// that could not be inlined calling it.
fn new(
ssa: &Ssa,
entry_point: FunctionId,
inline_no_predicates_functions: bool,
functions_not_to_inline: BTreeSet<FunctionId>,
) -> Self {
fn new(ssa: &Ssa, entry_point: FunctionId) -> Self {
let source = &ssa.functions[&entry_point];
let mut builder = FunctionBuilder::new(source.name().to_owned(), entry_point);
builder.set_runtime(source.runtime());
builder.current_function.set_globals(source.dfg.globals.clone());

Self {
builder,
recursion_level: 0,
entry_point,
call_stack: CallStackId::root(),
inline_no_predicates_functions,
functions_not_to_inline,
}
Self { builder, recursion_level: 0, entry_point, call_stack: CallStackId::root() }
}

/// Start inlining the entry point function and all functions reachable from it.
fn inline_all(mut self, ssa: &Ssa) -> Function {
fn inline_all(
mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) -> Function {
let entry_point = &ssa.functions[&self.entry_point];

// let globals = self.globals;
let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals);
context.inlining_entry = true;

Expand All @@ -401,7 +401,7 @@ impl InlineContext {
}

context.blocks.insert(context.source_function.entry_block(), entry_block);
context.inline_blocks(ssa);
context.inline_blocks(ssa, should_inline_call);
// translate databus values
let databus = entry_point.dfg.data_bus.map_values(|t| context.translate_value(t));

Expand All @@ -420,6 +420,7 @@ impl InlineContext {
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) -> Vec<ValueId> {
self.recursion_level += 1;

Expand All @@ -440,7 +441,7 @@ impl InlineContext {
let current_block = context.context.builder.current_block();
context.blocks.insert(source_function.entry_block(), current_block);

let return_values = context.inline_blocks(ssa);
let return_values = context.inline_blocks(ssa, should_inline_call);
self.recursion_level -= 1;
return_values
}
Expand Down Expand Up @@ -568,7 +569,11 @@ impl<'function> PerFunctionContext<'function> {
}

/// Inline all reachable blocks within the source_function into the destination function.
fn inline_blocks(&mut self, ssa: &Ssa) -> Vec<ValueId> {
fn inline_blocks(
&mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = VecDeque::new();
block_queue.push_back(self.source_function.entry_block());
Expand All @@ -585,7 +590,7 @@ impl<'function> PerFunctionContext<'function> {
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block_instructions(ssa, source_block_id);
self.inline_block_instructions(ssa, source_block_id, should_inline_call);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
Expand Down Expand Up @@ -630,16 +635,21 @@ impl<'function> PerFunctionContext<'function> {

/// Inline each instruction in the given block into the function being inlined into.
/// This may recurse if it finds another function to inline if a call instruction is within this block.
fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
fn inline_block_instructions(
&mut self,
ssa: &Ssa,
block_id: BasicBlockId,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) {
let mut side_effects_enabled: Option<ValueId> = None;

let block = &self.source_function.dfg[block_id];
for id in block.instructions() {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
if self.should_inline_call(ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments);
if should_inline_call(self, ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments, should_inline_call);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
Expand Down Expand Up @@ -667,31 +677,14 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call(&self, ssa: &Ssa, called_func_id: FunctionId) -> bool {
let function = &ssa.functions[&called_func_id];

if let RuntimeType::Acir(inline_type) = function.runtime() {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!self.context.inline_no_predicates_functions && function.is_no_predicates();
!inline_type.is_entry_point() && !preserve_function
} else {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
matches!(ssa.functions[&self.context.entry_point].runtime(), RuntimeType::Brillig(_))
&& !self.context.functions_not_to_inline.contains(&called_func_id)
}
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
ssa: &Ssa,
call_id: InstructionId,
function: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
) {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
Expand All @@ -707,7 +700,8 @@ impl<'function> PerFunctionContext<'function> {
.extend_call_stack(self.context.call_stack, &call_stack);

self.context.call_stack = new_call_stack;
let new_results = self.context.inline_function(ssa, function, &arguments);
let new_results =
self.context.inline_function(ssa, function, &arguments, should_inline_call);
self.context.call_stack = self
.context
.builder
Expand Down
Loading