Skip to content

Commit

Permalink
feat: inline simple functions (#7160)
Browse files Browse the repository at this point in the history
Co-authored-by: Maxim Vezenov <[email protected]>
  • Loading branch information
asterite and vezenovm authored Jan 23, 2025
1 parent 2d415ca commit c17e228
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 4 deletions.
8 changes: 5 additions & 3 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,16 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
Ok(builder
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)")
.run_pass(Ssa::defunctionalize, "Defunctionalization")
.run_pass(Ssa::inline_simple_functions, "Inlining simple functions")
.run_pass(Ssa::mem2reg, "Mem2Reg (1st)")
.run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs")
.run_pass(
|ssa| ssa.preprocess_functions(options.inliner_aggressiveness),
"Preprocessing Functions",
)
.run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)")
// Run mem2reg with the CFG separated into blocks
.run_pass(Ssa::mem2reg, "Mem2Reg (1st)")
.run_pass(Ssa::mem2reg, "Mem2Reg (2nd)")
.run_pass(Ssa::simplify_cfg, "Simplifying (1st)")
.run_pass(Ssa::as_slice_optimization, "`as_slice` optimization")
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (2nd)")
Expand All @@ -173,11 +175,11 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
"Unrolling",
)?
.run_pass(Ssa::simplify_cfg, "Simplifying (2nd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (2nd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (3rd)")
.run_pass(Ssa::flatten_cfg, "Flattening")
.run_pass(Ssa::remove_bit_shifts, "Removing Bit Shifts")
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
.run_pass(Ssa::mem2reg, "Mem2Reg (3rd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (4th)")
// Run the inlining pass again to handle functions with `InlineType::NoPredicates`.
// Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point.
// This pass must come immediately following `mem2reg` as the succeeding passes
Expand Down
103 changes: 102 additions & 1 deletion compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,34 @@ impl Ssa {
});
self
}

pub(crate) fn inline_simple_functions(mut self: Ssa) -> Ssa {
let should_inline_call = |callee: &Function| {
if let RuntimeType::Acir(_) = callee.runtime() {
// Functions marked to not have predicates should be preserved.
if callee.is_no_predicates() {
return false;
}
}

let entry_block_id = callee.entry_block();
let entry_block = &callee.dfg[entry_block_id];

// Only inline functions with a single block
if entry_block.successors().next().is_some() {
return false;
}

// Only inline functions with 0 or 1 instructions
entry_block.instructions().len() <= 1
};

self.functions = btree_map(self.functions.iter(), |(id, function)| {
(*id, function.inlined(&self, &should_inline_call))
});

self
}
}

impl Function {
Expand Down Expand Up @@ -185,7 +213,7 @@ pub(super) struct InlineInfo {
is_brillig_entry_point: bool,
is_acir_entry_point: bool,
is_recursive: bool,
should_inline: bool,
pub(super) should_inline: bool,
weight: i64,
cost: i64,
}
Expand Down Expand Up @@ -1123,6 +1151,7 @@ mod test {
map::Id,
types::{NumericType, Type},
},
opt::assert_normalized_ssa_equals,
Ssa,
};

Expand Down Expand Up @@ -1597,4 +1626,76 @@ mod test {
);
assert!(tws[3] > max(tws[1], tws[2]), "ideally 'main' has the most weight");
}

#[test]
fn inline_simple_functions_with_zero_instructions() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field):
v2 = call f1(v0) -> Field
v3 = call f1(v0) -> Field
v4 = add v2, v3
return v4
}
acir(inline) fn foo f1 {
b0(v0: Field):
return v0
}
";
let ssa = Ssa::from_str(src).unwrap();

let expected = "
acir(inline) fn main f0 {
b0(v0: Field):
v1 = add v0, v0
return v1
}
acir(inline) fn foo f1 {
b0(v0: Field):
return v0
}
";

let ssa = ssa.inline_simple_functions();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn inline_simple_functions_with_one_instruction() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field):
v2 = call f1(v0) -> Field
v3 = call f1(v0) -> Field
v4 = add v2, v3
return v4
}
acir(inline) fn foo f1 {
b0(v0: Field):
v2 = add v0, Field 1
return v2
}
";
let ssa = Ssa::from_str(src).unwrap();

let expected = "
acir(inline) fn main f0 {
b0(v0: Field):
v2 = add v0, Field 1
v3 = add v0, Field 1
v4 = add v2, v3
return v4
}
acir(inline) fn foo f1 {
b0(v0: Field):
v2 = add v0, Field 1
return v2
}
";

let ssa = ssa.inline_simple_functions();
assert_normalized_ssa_equals(ssa, expected);
}
}

0 comments on commit c17e228

Please sign in to comment.