From c17e228f98e5714fb3a2b56ff4c0a2aa469e77c4 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 23 Jan 2025 16:45:12 -0300 Subject: [PATCH] feat: inline simple functions (#7160) Co-authored-by: Maxim Vezenov --- compiler/noirc_evaluator/src/ssa.rs | 8 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 103 +++++++++++++++++- 2 files changed, 107 insertions(+), 4 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 12ea04daebd..4cefce1d647 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -152,6 +152,8 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Result Result Ssa { + let should_inline_call = |callee: &Function| { + if let RuntimeType::Acir(_) = callee.runtime() { + // Functions marked to not have predicates should be preserved. + if callee.is_no_predicates() { + return false; + } + } + + let entry_block_id = callee.entry_block(); + let entry_block = &callee.dfg[entry_block_id]; + + // Only inline functions with a single block + if entry_block.successors().next().is_some() { + return false; + } + + // Only inline functions with 0 or 1 instructions + entry_block.instructions().len() <= 1 + }; + + self.functions = btree_map(self.functions.iter(), |(id, function)| { + (*id, function.inlined(&self, &should_inline_call)) + }); + + self + } } impl Function { @@ -185,7 +213,7 @@ pub(super) struct InlineInfo { is_brillig_entry_point: bool, is_acir_entry_point: bool, is_recursive: bool, - should_inline: bool, + pub(super) should_inline: bool, weight: i64, cost: i64, } @@ -1123,6 +1151,7 @@ mod test { map::Id, types::{NumericType, Type}, }, + opt::assert_normalized_ssa_equals, Ssa, }; @@ -1597,4 +1626,76 @@ mod test { ); assert!(tws[3] > max(tws[1], tws[2]), "ideally 'main' has the most weight"); } + + #[test] + fn inline_simple_functions_with_zero_instructions() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = call f1(v0) -> Field + v3 = call f1(v0) -> Field + v4 = add v2, v3 + return v4 + } + + acir(inline) fn foo f1 { + b0(v0: Field): + return v0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field): + v1 = add v0, v0 + return v1 + } + acir(inline) fn foo f1 { + b0(v0: Field): + return v0 + } + "; + + let ssa = ssa.inline_simple_functions(); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inline_simple_functions_with_one_instruction() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = call f1(v0) -> Field + v3 = call f1(v0) -> Field + v4 = add v2, v3 + return v4 + } + + acir(inline) fn foo f1 { + b0(v0: Field): + v2 = add v0, Field 1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = add v0, Field 1 + v3 = add v0, Field 1 + v4 = add v2, v3 + return v4 + } + acir(inline) fn foo f1 { + b0(v0: Field): + v2 = add v0, Field 1 + return v2 + } + "; + + let ssa = ssa.inline_simple_functions(); + assert_normalized_ssa_equals(ssa, expected); + } }