diff --git a/dag_in_context/src/add_context.rs b/dag_in_context/src/add_context.rs index e3ea66580..b76c14e25 100644 --- a/dag_in_context/src/add_context.rs +++ b/dag_in_context/src/add_context.rs @@ -1,27 +1,35 @@ //! Adds context to the tree program. //! The `add_context` method recursively adds context to all of the nodes in the tree program //! by remembering the most recent context (ex. DoWhile or If). +//! Mantains the sharing invariant (see restore_sharing_invariant) by using a cache. + +use std::collections::HashMap; use crate::{ ast::{in_context, infunc}, schema::{Assumption, Expr, RcExpr, TreeProgram}, + schema_helpers::AssumptionRef, }; +struct ContextCache { + with_ctx: HashMap<(*const Expr, AssumptionRef), RcExpr>, +} + impl TreeProgram { pub fn add_context(&self) -> TreeProgram { TreeProgram { functions: self .functions .iter() - .map(|f| f.clone().func_add_context()) + .map(|f| f.clone().func_add_ctx()) .collect(), - entry: self.entry.clone().func_add_context(), + entry: self.entry.clone().func_add_ctx(), } } } impl Expr { - pub(crate) fn func_add_context(self: RcExpr) -> RcExpr { + pub(crate) fn func_add_ctx(self: RcExpr) -> RcExpr { let Expr::Function(name, arg_ty, ret_ty, body) = &self.as_ref() else { panic!("Expected Function, got {:?}", self); }; @@ -30,78 +38,105 @@ impl Expr { name.clone(), arg_ty.clone(), ret_ty.clone(), - body.add_context(current_ctx), + body.add_ctx(current_ctx), )) } - fn add_context(self: &RcExpr, current_ctx: Assumption) -> RcExpr { - match self.as_ref() { + pub(crate) fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> RcExpr { + let mut cache = ContextCache { + with_ctx: HashMap::new(), + }; + self.add_ctx_with_cache(current_ctx, &mut cache) + } + + fn add_ctx_with_cache( + self: &RcExpr, + current_ctx: Assumption, + cache: &mut ContextCache, + ) -> RcExpr { + let ctx_ref = current_ctx.to_ref(); + if let Some(expr) = cache + .with_ctx + .get(&((*self).as_ref() as *const Expr, ctx_ref.clone())) + { + return expr.clone(); + } + let res = match self.as_ref() { // leaf nodes are constant, empty, and arg Expr::Const(..) | Expr::Empty(..) | Expr::Arg(..) => { in_context(current_ctx, self.clone()) } // create new contexts for let, loop, and if Expr::DoWhile(inputs, pred_and_body) => { - let new_inputs = inputs.add_context(current_ctx.clone()); + let new_inputs = inputs.add_ctx_with_cache(current_ctx.clone(), cache); let new_ctx = Assumption::InLoop(new_inputs.clone(), pred_and_body.clone()); RcExpr::new(Expr::DoWhile( new_inputs, - pred_and_body.add_context(new_ctx), + pred_and_body.add_ctx_with_cache(new_ctx, cache), )) } Expr::If(pred, input, then_case, else_calse) => { - let new_pred = pred.add_context(current_ctx.clone()); - let new_input = input.add_context(current_ctx.clone()); - let then_ctx = Assumption::InIf(true, new_pred.clone()); - let else_ctx = Assumption::InIf(false, new_pred.clone()); + let new_pred = pred.add_ctx_with_cache(current_ctx.clone(), cache); + let new_input = input.add_ctx_with_cache(current_ctx.clone(), cache); + let then_ctx = Assumption::InIf(true, new_pred.clone(), new_input.clone()); + let else_ctx = Assumption::InIf(false, new_pred.clone(), new_input.clone()); RcExpr::new(Expr::If( new_pred, new_input, - then_case.add_context(then_ctx), - else_calse.add_context(else_ctx), + then_case.add_ctx_with_cache(then_ctx, cache), + else_calse.add_ctx_with_cache(else_ctx, cache), )) } Expr::Switch(case_num, input, branches) => { - let new_case_num = case_num.add_context(current_ctx.clone()); - let new_input = input.add_context(current_ctx.clone()); + let new_case_num = case_num.add_ctx_with_cache(current_ctx.clone(), cache); + let new_input = input.add_ctx_with_cache(current_ctx.clone(), cache); let new_branches = branches .iter() - .map(|b| b.add_context(current_ctx.clone())) + .map(|b| b.add_ctx_with_cache(current_ctx.clone(), cache)) .collect(); RcExpr::new(Expr::Switch(new_case_num, new_input, new_branches)) } // for all other nodes, just add the context to the children Expr::Bop(op, x, y) => RcExpr::new(Expr::Bop( op.clone(), - x.add_context(current_ctx.clone()), - y.add_context(current_ctx), + x.add_ctx_with_cache(current_ctx.clone(), cache), + y.add_ctx_with_cache(current_ctx, cache), )), Expr::Top(op, x, y, z) => RcExpr::new(Expr::Top( op.clone(), - x.add_context(current_ctx.clone()), - y.add_context(current_ctx.clone()), - z.add_context(current_ctx), + x.add_ctx_with_cache(current_ctx.clone(), cache), + y.add_ctx_with_cache(current_ctx.clone(), cache), + z.add_ctx_with_cache(current_ctx, cache), )), - Expr::Uop(op, x) => RcExpr::new(Expr::Uop(op.clone(), x.add_context(current_ctx))), - Expr::Get(e, i) => RcExpr::new(Expr::Get(e.add_context(current_ctx), *i)), + Expr::Uop(op, x) => RcExpr::new(Expr::Uop( + op.clone(), + x.add_ctx_with_cache(current_ctx, cache), + )), + Expr::Get(e, i) => RcExpr::new(Expr::Get(e.add_ctx_with_cache(current_ctx, cache), *i)), Expr::Alloc(id, e, state, ty) => RcExpr::new(Expr::Alloc( *id, - e.add_context(current_ctx.clone()), - state.add_context(current_ctx), + e.add_ctx_with_cache(current_ctx.clone(), cache), + state.add_ctx_with_cache(current_ctx, cache), ty.clone(), )), - Expr::Call(f, arg) => { - RcExpr::new(Expr::Call(f.clone(), arg.add_context(current_ctx.clone()))) - } - Expr::Single(e) => RcExpr::new(Expr::Single(e.add_context(current_ctx))), + Expr::Call(f, arg) => RcExpr::new(Expr::Call( + f.clone(), + arg.add_ctx_with_cache(current_ctx.clone(), cache), + )), + Expr::Single(e) => RcExpr::new(Expr::Single(e.add_ctx_with_cache(current_ctx, cache))), Expr::Concat(x, y) => RcExpr::new(Expr::Concat( - x.add_context(current_ctx.clone()), - y.add_context(current_ctx), + x.add_ctx_with_cache(current_ctx.clone(), cache), + y.add_ctx_with_cache(current_ctx, cache), )), Expr::InContext(..) => { panic!("add_context expects a term without context") } Expr::Function(..) => panic!("Function should have been handled in func_add_context"), - } + }; + cache.with_ctx.insert( + (self.as_ref() as *const Expr, ctx_ref), + res.clone(), + ); + res } } diff --git a/dag_in_context/src/ast.rs b/dag_in_context/src/ast.rs index 53d70c03b..ab705aa36 100644 --- a/dag_in_context/src/ast.rs +++ b/dag_in_context/src/ast.rs @@ -313,8 +313,8 @@ pub fn inloop(e1: RcExpr, e2: RcExpr) -> Assumption { Assumption::InLoop(e1, e2) } -pub fn inif(is_then: bool, pred: RcExpr) -> Assumption { - Assumption::InIf(is_then, pred) +pub fn inif(is_then: bool, pred: RcExpr, input: RcExpr) -> Assumption { + Assumption::InIf(is_then, pred, input) } pub fn infunc(name: &str) -> Assumption { diff --git a/dag_in_context/src/from_egglog.rs b/dag_in_context/src/from_egglog.rs index 455ce522e..69c7cb9ea 100644 --- a/dag_in_context/src/from_egglog.rs +++ b/dag_in_context/src/from_egglog.rs @@ -129,12 +129,12 @@ impl FromEgglog { }; Assumption::InFunc(string.to_string()) } - ("InIf", [is_then, expr]) => { + ("InIf", [is_then, pred_expr, input_expr]) => { let Term::Lit(Literal::Bool(boolean)) = self.termdag.get(*is_then) else { panic!("Invalid boolean: {:?}", is_then) }; - Assumption::InIf(boolean, self.expr_from_egglog(self.termdag.get(*expr))) + Assumption::InIf(boolean, self.expr_from_egglog(self.termdag.get(*pred_expr)), self.expr_from_egglog(self.termdag.get(*input_expr))) } _ => panic!("Invalid assumption: {:?}", assumption), }) @@ -354,7 +354,7 @@ impl FromEgglog { self.program_from_egglog_preserve_ctx_nodes(new_term) } - fn without_ctx_nodes(&mut self, expr: Term) -> Term { + pub fn without_ctx_nodes(&mut self, expr: Term) -> Term { match expr { Term::App(head, children) => match (head.to_string().as_str(), children.as_slice()) { ("InContext", [_assumption, expr]) => { diff --git a/dag_in_context/src/interval_analysis.egg b/dag_in_context/src/interval_analysis.egg index c09524ee2..8192a70ad 100644 --- a/dag_in_context/src/interval_analysis.egg +++ b/dag_in_context/src/interval_analysis.egg @@ -79,7 +79,7 @@ ; Conditionals ; ================================= (rule ( - (= lhs (If cond thn els)) + (= lhs (If cond inputs thn els)) (= thn-ival (ival thn)) (= els-ival (ival els)) ) diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index 1091581ec..4377258f1 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -37,7 +37,8 @@ pub fn prologue() -> String { &optimizations::is_valid::rules().join("\n"), &optimizations::body_contains::rules().join("\n"), &optimizations::purity_analysis::rules().join("\n"), - &optimizations::conditional_invariant_code_motion::rules().join("\n"), + // TODO cond inv code motion with regions + //&optimizations::conditional_invariant_code_motion::rules().join("\n"), include_str!("utility/in_context.egg"), include_str!("utility/context-prop.egg"), include_str!("utility/subst.egg"), diff --git a/dag_in_context/src/optimizations/body_contains.rs b/dag_in_context/src/optimizations/body_contains.rs index eeaeab6e5..c83ce7b08 100644 --- a/dag_in_context/src/optimizations/body_contains.rs +++ b/dag_in_context/src/optimizations/body_contains.rs @@ -31,13 +31,15 @@ fn captured_expr_rule_for_ctor(ctor: Constructor) -> Option { fn subexpr_rule_for_ctor(ctor: Constructor) -> Option { let pat = ctor.construct(|field| field.var()); let actions = ctor.filter_map_fields(|field| { - (field.purpose == Purpose::SubExpr || field.purpose == Purpose::SubListExpr).then(|| { - format!( - "(BodyContains{sort} body {e})", - sort = field.sort().name(), - e = field.var() - ) - }) + (field.purpose == Purpose::SubExpr || field.purpose == Purpose::CapturedSubListExpr).then( + || { + format!( + "(BodyContains{sort} body {e})", + sort = field.sort().name(), + e = field.var() + ) + }, + ) }); (!actions.is_empty()).then(|| { format!( diff --git a/dag_in_context/src/optimizations/conditional_invariant_code_motion.rs b/dag_in_context/src/optimizations/conditional_invariant_code_motion.rs index 2b3acfa81..35cfb7fba 100644 --- a/dag_in_context/src/optimizations/conditional_invariant_code_motion.rs +++ b/dag_in_context/src/optimizations/conditional_invariant_code_motion.rs @@ -1,7 +1,9 @@ -use crate::schema_helpers::{Constructor, ESort, Purpose}; +/*use crate::schema_helpers::{Constructor, ESort, Purpose}; use std::iter; use strum::IntoEnumIterator; +// TODO implement now that we have if regions + fn rules_for_ctor(ctor: Constructor) -> Option { use Constructor::*; if [DoWhile, InContext].contains(&ctor) || ctor.sort() != ESort::Expr { @@ -53,7 +55,7 @@ fn rules_for_ctor(ctor: Constructor) -> Option { ((union (Switch pred exprs) {resulting_switch})) :ruleset conditional-invariant-code-motion) - + (rewrite (If c {ctor_pattern1} {ctor_pattern2}) {resulting_if} :when ((ExprIsValid (If c {ctor_pattern1} {ctor_pattern2}))) @@ -147,3 +149,4 @@ fn test_lift_if() -> crate::Result { vec![], ) } +*/ diff --git a/dag_in_context/src/optimizations/is_valid.rs b/dag_in_context/src/optimizations/is_valid.rs index 9097dc256..76e7d4476 100644 --- a/dag_in_context/src/optimizations/is_valid.rs +++ b/dag_in_context/src/optimizations/is_valid.rs @@ -4,7 +4,7 @@ use strum::IntoEnumIterator; fn rule_for_ctor(ctor: Constructor) -> Option { let actions = ctor.filter_map_fields(|field| match field.purpose { Purpose::Static(_) => None, - Purpose::CapturedExpr | Purpose::SubExpr | Purpose::SubListExpr => Some(format!( + Purpose::CapturedExpr | Purpose::SubExpr | Purpose::CapturedSubListExpr => Some(format!( "({sort}IsValid {var})", sort = field.sort().name(), var = field.var() diff --git a/dag_in_context/src/optimizations/loop_invariant.rs b/dag_in_context/src/optimizations/loop_invariant.rs index 3bb7fe11a..5df3b4b9b 100644 --- a/dag_in_context/src/optimizations/loop_invariant.rs +++ b/dag_in_context/src/optimizations/loop_invariant.rs @@ -51,7 +51,7 @@ fn is_invariant_rule_for_ctor(ctor: Constructor) -> Option { let is_inv_ctor = ctor .filter_map_fields(|field| match field.purpose { Purpose::Static(_) | Purpose::CapturedExpr => None, - Purpose::SubExpr | Purpose::SubListExpr => { + Purpose::SubExpr | Purpose::CapturedSubListExpr => { let var = field.var(); let sort = field.sort().name(); Some(format!("(= true (is-inv-{sort} loop {var}))")) diff --git a/dag_in_context/src/optimizations/purity_analysis.rs b/dag_in_context/src/optimizations/purity_analysis.rs index 0ce989d2e..cda5690f0 100644 --- a/dag_in_context/src/optimizations/purity_analysis.rs +++ b/dag_in_context/src/optimizations/purity_analysis.rs @@ -39,7 +39,7 @@ fn purity_rules_for_ctor(ctor: Constructor) -> String { Purpose::Static(Sort::BinaryOp) | Purpose::Static(Sort::UnaryOp) | Purpose::SubExpr - | Purpose::SubListExpr + | Purpose::CapturedSubListExpr | Purpose::CapturedExpr => Some(format!( "({sort}IsPure {var})", sort = field.sort().name(), diff --git a/dag_in_context/src/optimizations/switch_rewrites.egg b/dag_in_context/src/optimizations/switch_rewrites.egg index 6eb83c6d3..808eacebc 100644 --- a/dag_in_context/src/optimizations/switch_rewrites.egg +++ b/dag_in_context/src/optimizations/switch_rewrites.egg @@ -1,19 +1,21 @@ (ruleset switch_rewrite) -(rewrite (If (Bop (And) a b) X Y) - (If a (If b X Y) Y) - :when ((ExprIsPure b)) - :ruleset switch_rewrite) -(rewrite (If (Bop (Or) a b) X Y) - (If a X (If b X Y)) - :when ((ExprIsPure b)) - :ruleset switch_rewrite) - -(rewrite (If (Const (Bool true) ty) thn els) - thn - :ruleset switch_rewrite) - -(rewrite (If (Const (Bool false) ty) thn els) - els - :ruleset switch_rewrite) +;; TODO rewrite for if regions +;;(rewrite (If (Bop (And) a b) X Y) +;; (If a (If b X Y) Y) +;; :when ((ExprIsPure b)) +;; :ruleset switch_rewrite) +;; +;;(rewrite (If (Bop (Or) a b) X Y) +;; (If a X (If b X Y)) +;; :when ((ExprIsPure b)) +;; :ruleset switch_rewrite) +;; +;;(rewrite (If (Const (Bool true) ty) thn els) +;; thn +;; :ruleset switch_rewrite) +;; +;;(rewrite (If (Const (Bool false) ty) thn els) +;; els +;; :ruleset switch_rewrite) diff --git a/dag_in_context/src/schedule.egg b/dag_in_context/src/schedule.egg index 4ad76a72f..d0e2bac25 100644 --- a/dag_in_context/src/schedule.egg +++ b/dag_in_context/src/schedule.egg @@ -13,7 +13,8 @@ (saturate context-prop) (saturate context-helpers) context)) - conditional-invariant-code-motion + ;; TODO enable when conditional inv code motion works again + ;;conditional-invariant-code-motion switch_rewrite loop-simplify )) diff --git a/dag_in_context/src/schema.egg b/dag_in_context/src/schema.egg index 0415115bb..c3e7e5bdf 100644 --- a/dag_in_context/src/schema.egg +++ b/dag_in_context/src/schema.egg @@ -165,8 +165,8 @@ (InFunc String) ; Branch of the switch and what the predicate is (InSwitch i64 Expr) - ; If the predicate was true, and what the predicate is - (InIf bool Expr) + ; If the predicate was true, and what the predicate is, and what the input is + (InIf bool Expr Expr) ; Other assumptions are possible, but not supported yet. ; For example: ; A boolean predicate is true. diff --git a/dag_in_context/src/schema.rs b/dag_in_context/src/schema.rs index ef4b829b5..b8446706e 100644 --- a/dag_in_context/src/schema.rs +++ b/dag_in_context/src/schema.rs @@ -73,9 +73,10 @@ pub type RcExpr = Rc; pub enum Assumption { InLoop(RcExpr, RcExpr), InFunc(String), - InIf(bool, RcExpr), + InIf(bool, RcExpr, RcExpr), } + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum Expr { Const(Constant, Type), diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index dc2ab9552..489ca679b 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -1,9 +1,12 @@ -use std::fmt::{Display, Formatter}; +use std::{ + fmt::{Display, Formatter}, + rc::Rc, +}; use strum_macros::EnumIter; use crate::{ ast::{base, boolt, intt}, - schema::{BinaryOp, Constant, Expr, RcExpr, TernaryOp, TreeProgram, Type, UnaryOp}, + schema::{Assumption, BinaryOp, Constant, Expr, RcExpr, TernaryOp, TreeProgram, Type, UnaryOp}, }; /// Display for Constant implements a @@ -294,10 +297,10 @@ pub enum Constructor { #[derive(Clone, Copy, Debug, PartialEq)] pub(crate) enum Purpose { - Static(Sort), // some int, bool, order that parameterizes constructor - SubExpr, // subexpression, e.g. Add's summand - SubListExpr, // sublistexpr, e.g. Switch's branch lsit - CapturedExpr, // a body's outputs + Static(Sort), // some int, bool, order that parameterizes constructor + SubExpr, // subexpression, e.g. Add's summand + CapturedSubListExpr, // a swtich's branches + CapturedExpr, // a body's outputs } impl Purpose { @@ -305,7 +308,7 @@ impl Purpose { match self { Purpose::SubExpr => Sort::Expr, Purpose::CapturedExpr => Sort::Expr, - Purpose::SubListExpr => Sort::ListExpr, + Purpose::CapturedSubListExpr => Sort::ListExpr, Purpose::Static(sort) => sort, } } @@ -352,7 +355,7 @@ impl Constructor { } pub(crate) fn fields(&self) -> Vec { - use Purpose::{CapturedExpr, Static, SubExpr, SubListExpr}; + use Purpose::{CapturedExpr, CapturedSubListExpr, Static, SubExpr}; let f = |purpose, name| Field { purpose, name }; match self { Constructor::Function => { @@ -388,10 +391,19 @@ impl Constructor { vec![f(SubExpr, "x")] } Constructor::Switch => { - vec![f(SubExpr, "pred"), f(SubListExpr, "branches")] + vec![ + f(SubExpr, "pred"), + f(SubExpr, "inputs"), + f(CapturedSubListExpr, "branches"), + ] } Constructor::If => { - vec![f(SubExpr, "pred"), f(SubExpr, "then"), f(SubExpr, "else")] + vec![ + f(SubExpr, "pred"), + f(SubExpr, "input"), + f(CapturedExpr, "then"), + f(CapturedExpr, "else"), + ] } Constructor::DoWhile => { vec![f(SubExpr, "in"), f(CapturedExpr, "pred-and-output")] @@ -401,7 +413,7 @@ impl Constructor { vec![f(Static(Sort::String), "func"), f(SubExpr, "arg")] } Constructor::Empty => vec![f(Static(Sort::Type), "ty")], - Constructor::Cons => vec![f(SubExpr, "hd"), f(SubListExpr, "tl")], + Constructor::Cons => vec![f(SubExpr, "hd"), f(CapturedSubListExpr, "tl")], Constructor::Nil => vec![], Constructor::Alloc => vec![ f(Static(Sort::I64), "id"), @@ -485,3 +497,25 @@ impl UnaryOp { } } } + +/// used to hash an assumption +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum AssumptionRef { + InLoop(*const Expr, *const Expr), + InFunc(String), + InIf(bool, *const Expr, *const Expr), +} + +impl Assumption { + pub fn to_ref(&self) -> AssumptionRef { + match self { + Assumption::InLoop(inputs, pred_and_body) => { + AssumptionRef::InLoop(Rc::as_ptr(inputs), Rc::as_ptr(pred_and_body)) + } + Assumption::InFunc(name) => AssumptionRef::InFunc(name.clone()), + Assumption::InIf(b, pred, input) => { + AssumptionRef::InIf(*b, Rc::as_ptr(pred), Rc::as_ptr(input)) + } + } + } +} diff --git a/dag_in_context/src/to_egglog.rs b/dag_in_context/src/to_egglog.rs index 7ae5f98d4..9a98df02f 100644 --- a/dag_in_context/src/to_egglog.rs +++ b/dag_in_context/src/to_egglog.rs @@ -109,10 +109,11 @@ impl Assumption { let name_lit = term_dag.lit(Literal::String(name.into())); term_dag.app("InFunc".into(), vec![name_lit]) } - Assumption::InIf(is_then, pred) => { + Assumption::InIf(is_then, pred, input) => { let pred = pred.to_egglog_internal(term_dag); let is_then = term_dag.lit(Literal::Bool(*is_then)); - term_dag.app("InIf".into(), vec![is_then, pred]) + let input = input.to_egglog_internal(term_dag); + term_dag.app("InIf".into(), vec![is_then, pred, input]) } } } @@ -343,6 +344,7 @@ fn convert_to_egglog_switch() { test_expr_parses_to( expr, "(Switch (Const (Int 1) (Base (IntT))) + (Empty (Base (IntT))) (Cons (Concat (Single (Const (Int 1) (Base (IntT)))) (Single (Const (Int 2) (Base (IntT))))) (Cons diff --git a/dag_in_context/src/type_analysis.egg b/dag_in_context/src/type_analysis.egg index 4af64c271..2e1d3c528 100644 --- a/dag_in_context/src/type_analysis.egg +++ b/dag_in_context/src/type_analysis.egg @@ -84,39 +84,35 @@ (HasArgType e2 ty)) ((HasArgType lhs ty)) :ruleset type-analysis) -(rule ((HasArgType (Switch pred (Cons branch (Nil))) ty)) - ( - (HasArgType pred ty) - (HasArgType branch ty) - ) - :ruleset type-analysis) -(rule ((= lhs (Switch pred (Cons branch rest))) +(rule ((= lhs (Switch pred inputs (Cons branch rest))) (HasArgType pred ty)) ((HasArgType lhs ty)) :ruleset type-analysis) -(rule ((= lhs (Switch pred (Cons branch rest))) - (HasArgType branch ty)) - ((HasArgType lhs ty)) +(rule ((= lhs (Switch pred inputs (Cons branch rest))) + (HasArgType branch ty) + (HasType inputs ty2) + (!= ty ty2)) + ((panic "switch branches then branch has incorrect input type")) :ruleset type-analysis) ;; demand with one fewer branches -(rule ((= lhs (Switch pred (Cons branch rest)))) - ((Switch pred rest)) - :ruleset type-analysis) -(rule ((= lhs (Switch pred (Cons branch rest))) - (HasArgType (Switch pred rest) ty)) - ((HasArgType lhs ty)) +(rule ((= lhs (Switch pred inputs (Cons branch rest)))) + ((Switch pred inputs rest)) :ruleset type-analysis) -(rule ((= lhs (If c t e)) +(rule ((= lhs (If c i t e)) (HasArgType c ty)) ((HasArgType lhs ty)) :ruleset type-analysis) -(rule ((= lhs (If c t e)) - (HasArgType t ty)) - ((HasArgType lhs ty)) +(rule ((= lhs (If c i t e)) + (HasType i ty) + (HasArgType t ty2) + (!= ty ty2)) + ((panic "if branches then branch has incorrect input type")) :ruleset type-analysis) -(rule ((= lhs (If c t e)) - (HasArgType e ty)) - ((HasArgType lhs ty)) +(rule ((= lhs (If c i t e)) + (HasType i ty) + (HasArgType e ty2) + (!= ty ty2)) + ((panic "if branches else branch has incorrect input type")) :ruleset type-analysis) (rule ((= lhs (DoWhile ins body)) (HasArgType ins ty)) @@ -435,11 +431,11 @@ ; ================================= ; Control flow ; ================================= -(rule ((= lhs (If pred then else))) +(rule ((= lhs (If pred inputs then else))) ((ExpectType pred (Base (BoolT)) "If predicate must be boolean")) :ruleset type-analysis) (rule ( - (= lhs (If pred then else)) + (= lhs (If pred inputs then else)) (HasType pred (Base (BoolT))) (HasType then ty) (HasType else ty) @@ -448,7 +444,7 @@ :ruleset type-analysis) (rule ( - (= lhs (If pred then else)) + (= lhs (If pred inputs then else)) (HasType pred (Base (BoolT))) (HasType then tya) (HasType else tyb) @@ -459,13 +455,13 @@ -(rule ((= lhs (Switch pred branches))) +(rule ((= lhs (Switch pred inputs branches))) ((ExpectType pred (Base (IntT)) "Switch predicate must be integer")) :ruleset type-analysis) ; base case: single branch switch has type of branch (rule ( - (= lhs (Switch pred (Cons branch (Nil)))) + (= lhs (Switch pred inputs (Cons branch (Nil)))) (HasType pred (Base (IntT))) (HasType branch ty) ) @@ -473,24 +469,24 @@ :ruleset type-analysis) ; recursive case: peel off a layer -(rule ((Switch pred (Cons branch rest))) - ((Switch pred rest)) +(rule ((Switch pred inputs (Cons branch rest))) + ((Switch pred inputs rest)) :ruleset type-analysis) (rule ( - (= lhs (Switch pred (Cons branch rest))) + (= lhs (Switch pred inputs (Cons branch rest))) (HasType pred (Base (IntT))) (HasType branch ty) - (HasType (Switch pred rest) ty) ; rest of the branches also have type ty + (HasType (Switch pred inputs rest) ty) ; rest of the branches also have type ty ) ((HasType lhs ty)) :ruleset type-analysis) (rule ( - (= lhs (Switch pred (Cons branch rest))) + (= lhs (Switch pred inputs (Cons branch rest))) (HasType pred (Base (IntT))) (HasType branch tya) - (HasType (Switch pred rest) tyb) + (HasType (Switch pred inputs rest) tyb) (!= tya tyb) ) ((panic "switch branches had different types")) diff --git a/dag_in_context/src/typechecker.rs b/dag_in_context/src/typechecker.rs index b1514d93d..cc6f91646 100644 --- a/dag_in_context/src/typechecker.rs +++ b/dag_in_context/src/typechecker.rs @@ -148,9 +148,12 @@ impl<'a> TypeChecker<'a> { inloop(inputs_with_types, body_with_types) } Assumption::InFunc(name) => infunc(&name), - Assumption::InIf(branch, pred) => { - let pred_with_types = self.add_arg_types_to_expr(pred.clone(), arg_tys); - inif(branch, pred_with_types.1) + Assumption::InIf(branch, pred, input) => { + let outer_types = arg_tys.popped(); + let pred_with_types = self.add_arg_types_to_expr(pred.clone(), &outer_types); + let input_with_types = self.add_arg_types_to_expr(input.clone(), &outer_types); + + inif(branch, pred_with_types.1, input_with_types.1) } } } diff --git a/dag_in_context/src/utility/context-prop.egg b/dag_in_context/src/utility/context-prop.egg index 16d76841a..02d919210 100644 --- a/dag_in_context/src/utility/context-prop.egg +++ b/dag_in_context/src/utility/context-prop.egg @@ -37,8 +37,9 @@ (Concat (InContext ctx e1) (InContext ctx e2)) :ruleset context-prop) -(rewrite (InContext ctx (If cond then else)) - (If (InContext ctx cond) (InContext ctx then) (InContext ctx else)) +;; then and else are new regions +(rewrite (InContext ctx (If cond inputs then else)) + (If (InContext ctx cond) (InContext ctx inputs) then else) :ruleset context-prop) diff --git a/dag_in_context/src/utility/context_of.egg b/dag_in_context/src/utility/context_of.egg index aa1676bbe..648bfb506 100644 --- a/dag_in_context/src/utility/context_of.egg +++ b/dag_in_context/src/utility/context_of.egg @@ -46,11 +46,14 @@ (rule ((Single x) (ContextOf x ctx)) ((ContextOf (Single x) ctx)) :ruleset always-run) -(rule ((Switch pred branches) (ContextOf pred ctx)) - ((ContextOf (Switch pred branches) ctx)) :ruleset always-run) +(rule ((Switch pred inputs branches) (ContextOf pred ctx)) + ((ContextOf (Switch pred inputs branches) ctx)) :ruleset always-run) -(rule ((If pred then else) (ContextOf pred ctx)) - ((ContextOf (If pred then else) ctx)) :ruleset always-run) +(rule ((If pred inputs then else) (ContextOf pred ctx)) + ((ContextOf (If pred inputs then else) ctx)) :ruleset always-run) + +(rule ((If pred inputs then else) (ContextOf inputs ctx)) + ((ContextOf (If pred inputs then else) ctx)) :ruleset always-run) (rule ((DoWhile in pred-and-output) (ContextOf in ctx)) ((ContextOf (DoWhile in pred-and-output) ctx)) :ruleset always-run) diff --git a/dag_in_context/src/utility/context_of.rs b/dag_in_context/src/utility/context_of.rs index beee0fa7e..6cc56597a 100644 --- a/dag_in_context/src/utility/context_of.rs +++ b/dag_in_context/src/utility/context_of.rs @@ -4,22 +4,22 @@ fn test_context_of() -> crate::Result { // fn main(x): if x = 5 then x else 4 let pred = eq(arg(), int(5)); - let body = tif(pred, arg(), int(4)).with_arg_types(base(intt()), base(intt())); + let body = tif(pred, arg(), arg(), int(4)) + .with_arg_types(base(intt()), base(intt())) + .with_arg_types(base(intt()), base(intt())); + let body_with_context = body.clone().add_ctx(infunc("main")); let build = function("main", base(intt()), base(intt()), body.clone()) .func_with_arg_types() - .func_add_context(); + .func_add_ctx(); // If statement should have the context of its predicate - let check = " + let check = format!(" (let pred-ctx (InFunc \"main\")) (let pred (Bop (Eq) (InContext (InFunc \"main\") (Arg (Base (IntT)))) (InContext (InFunc \"main\") (Const (Int 5) (Base (IntT)))))) (check (ContextOf pred pred-ctx)) - (let if - (If pred - (InContext (InIf true (Bop (Eq) (InContext (InFunc \"main\") (Arg (Base (IntT)))) (InContext (InFunc \"main\") (Const (Int 5) (Base (IntT)))))) (Arg (Base (IntT)))) - (InContext (InIf false (Bop (Eq) (InContext (InFunc \"main\") (Arg (Base (IntT)))) (InContext (InFunc \"main\") (Const (Int 5) (Base (IntT)))))) (Const (Int 4) (Base (IntT)))))) + (let if {body_with_context}) (check (ContextOf if pred-ctx)) - ".to_string(); + "); crate::egglog_test( &format!("(let build {build})"), @@ -86,7 +86,7 @@ fn test_context_of_no_func_context() -> crate::Result { ), ) .func_with_arg_types() - .func_add_context(); + .func_add_ctx(); let check = format!("(fail (check (ContextOf {} ctx)))", build.clone()); diff --git a/dag_in_context/src/utility/in_context.egg b/dag_in_context/src/utility/in_context.egg index b22ca6ffd..97edc5bdd 100644 --- a/dag_in_context/src/utility/in_context.egg +++ b/dag_in_context/src/utility/in_context.egg @@ -160,21 +160,23 @@ ;; ########################################## Control flow -(rewrite (DoAddContext seen ctx scope (Switch pred branches)) +;; TODO when scope is full, add more context to the switch +(rewrite (DoAddContext seen ctx scope (Switch pred inputs branches)) (Switch (DoAddContext seen ctx scope pred) - (DoAddContextList seen ctx scope branches)) + (DoAddContext seen ctx scope inputs) + branches) :ruleset context-helpers) -(rule ((= lhs (DoAddContext seen ctx scope (If pred c1 c2)))) +;; TODO when scope is full, add more context to if +(rule ((= lhs (DoAddContext seen ctx scope (If pred inputs c1 c2)))) ((let newpred (DoAddContext seen ctx scope pred)) - (let newpath - (PathCons (If pred c1 c2) lhs seen)) (union lhs (If newpred - (DoAddContext newpath (InIf true newpred) scope c1) - (DoAddContext newpath (InIf false newpred) scope c2)))) + (DoAddContext seen ctx scope inputs) + c1 + c2))) :ruleset context) diff --git a/dag_in_context/src/utility/in_context.rs b/dag_in_context/src/utility/in_context.rs index abcaa8a60..1bd6c9a6f 100644 --- a/dag_in_context/src/utility/in_context.rs +++ b/dag_in_context/src/utility/in_context.rs @@ -24,7 +24,7 @@ fn test_in_context_two_loops() -> crate::Result { ) .func_with_arg_types(); - let with_context = expr.clone().func_add_context(); + let with_context = expr.clone().func_add_ctx(); egglog_test( &format!("(AddFuncContext {expr})"), diff --git a/dag_in_context/src/utility/subst.egg b/dag_in_context/src/utility/subst.egg index e3675a0b8..6288aafe4 100644 --- a/dag_in_context/src/utility/subst.egg +++ b/dag_in_context/src/utility/subst.egg @@ -33,11 +33,11 @@ ((union lhs (InContext assum (SubstLeaf to leaf)))) :ruleset subst) ;; modify inif context -(rule ((= lhs (Subst assum to (InContext (InIf branch pred) leaf))) +(rule ((= lhs (Subst assum to (InContext (InIf branch pred inputs) leaf))) (IsLeaf leaf)) ((union lhs (InContext - (InIf branch (Subst assum to pred)) + (InIf branch (Subst assum to pred) (Subst assum to inputs)) (SubstLeaf to leaf)))) :ruleset subst) @@ -84,14 +84,16 @@ ;; Control flow -(rewrite (Subst assum to (Switch pred branches)) +(rewrite (Subst assum to (Switch pred inputs branches)) (Switch (Subst assum to pred) - (SubstList assum to branches)) + (Subst assum to inputs) + branches) :ruleset subst) -(rewrite (Subst assum to (If pred c1 c2)) +(rewrite (Subst assum to (If pred inputs c1 c2)) (If (Subst assum to pred) - (Subst assum to c1) - (Subst assum to c2)) + (Subst assum to inputs) + c1 + c2) :ruleset subst) (rewrite (Subst assum to (DoWhile in out)) (DoWhile (Subst assum to in) diff --git a/dag_in_context/src/utility/subst.rs b/dag_in_context/src/utility/subst.rs index f55c7a6dc..5456ee09c 100644 --- a/dag_in_context/src/utility/subst.rs +++ b/dag_in_context/src/utility/subst.rs @@ -178,7 +178,7 @@ fn test_subst_preserves_context() -> crate::Result { let outer_if = tif(less_than(arg(), int(5)), arg(), int(1)); let expression = function("main", base(intt()), base(intt()), outer_if) .func_with_arg_types() - .func_add_context(); + .func_add_ctx(); let replace_with = int(5).with_arg_types(base(intt()), base(intt())); @@ -189,7 +189,7 @@ fn test_subst_preserves_context() -> crate::Result { tif(less_than(int(5), int(5)), int(5), int(1)), ) .func_with_arg_types() - .func_add_context(); + .func_add_ctx(); let build = format!( " diff --git a/dag_in_context/src/utility/util.egg b/dag_in_context/src/utility/util.egg index 0bf258e1b..20958b506 100644 --- a/dag_in_context/src/utility/util.egg +++ b/dag_in_context/src/utility/util.egg @@ -3,7 +3,7 @@ (function ListExpr-suffix (ListExpr i64) ListExpr :unextractable) (function Append (ListExpr Expr) ListExpr :unextractable) -(rule ((Switch pred branch)) ((union (ListExpr-suffix branch 0) branch)) :ruleset always-run) +(rule ((Switch pred inputs branch)) ((union (ListExpr-suffix branch 0) branch)) :ruleset always-run) (rule ((= (ListExpr-suffix top n) (Cons hd tl))) ((union (ListExpr-ith top n) hd) diff --git a/dag_in_context/src/utility/util.rs b/dag_in_context/src/utility/util.rs index c96f35db6..3edeb99ed 100644 --- a/dag_in_context/src/utility/util.rs +++ b/dag_in_context/src/utility/util.rs @@ -12,7 +12,7 @@ fn test_list_util() -> crate::Result { (Cons (Const (Int 2) {emptyt}) (Cons (Const (Int 3) {emptyt}) (Cons (Const (Int 4) {emptyt}) (Nil))))))) - (let expr (Switch (Const (Int 1) {emptyt}) list)) + (let expr (Switch (Const (Int 1) {emptyt}) (Empty {emptyt}) list)) " ); let check = format!( diff --git a/src/rvsdg/from_dag.rs b/src/rvsdg/from_dag.rs index 63e3cdd1a..39fadc724 100644 --- a/src/rvsdg/from_dag.rs +++ b/src/rvsdg/from_dag.rs @@ -186,14 +186,14 @@ impl<'a> TreeToRvsdg<'a> { /// initial_translation_cache is a cache of already evaluated expressions. /// For branch subregions, the initial translation cache maps branch input expressions /// to the Operand::Arg corresponding to them. - fn translate_subregion(&mut self, expr: RcExpr, current_args: Vec) -> Vec { - // TODO fix bug here, region graph needs to take the whole region as input + fn translate_subregion(&mut self, expr: RcExpr, num_args: usize) -> Vec { + let args = (0..num_args).map(Operand::Arg).collect(); let mut translator = TreeToRvsdg { program: self.program, nodes: self.nodes, type_cache: self.type_cache, translation_cache: HashMap::new(), - current_args, + current_args: args, }; translator.convert_expr(expr) } @@ -340,8 +340,8 @@ impl<'a> TreeToRvsdg<'a> { // then convert the inputs let input = self.convert_expr(input.clone()); - let then_region = self.translate_subregion(then_branch.clone(), input.clone()); - let else_region = self.translate_subregion(else_branch.clone(), input.clone()); + let then_region = self.translate_subregion(then_branch.clone(), input.len()); + let else_region = self.translate_subregion(else_branch.clone(), input.len()); let new_id = self.nodes.len(); assert_eq!( @@ -370,8 +370,7 @@ impl<'a> TreeToRvsdg<'a> { let mut case_regions = vec![]; for case_expr in cases { - let case_region = - self.translate_subregion(case_expr.clone(), new_inputs.clone()); + let case_region = self.translate_subregion(case_expr.clone(), new_inputs.len()); case_regions.push(case_region); } @@ -389,8 +388,7 @@ impl<'a> TreeToRvsdg<'a> { } Expr::DoWhile(inputs, body) => { let inputs_converted = self.convert_expr(inputs.clone()); - let new_args = (0..inputs_converted.len()).map(Operand::Arg).collect(); - let pred_and_body = self.translate_subregion(body.clone(), new_args); + let pred_and_body = self.translate_subregion(body.clone(), inputs_converted.len()); assert_eq!( inputs_converted.len(), pred_and_body.len() - 1, diff --git a/src/rvsdg/to_dag.rs b/src/rvsdg/to_dag.rs index 532ae493f..14004faec 100644 --- a/src/rvsdg/to_dag.rs +++ b/src/rvsdg/to_dag.rs @@ -89,8 +89,6 @@ struct DagTranslator<'a> { nodes: &'a [RvsdgBody], /// The next id to assign to an alloc. next_alloc_id: i64, - /// cache loop body for use in contexts - loop_body_cache: HashMap, } impl<'a> DagTranslator<'a> { @@ -215,10 +213,6 @@ impl<'a> DagTranslator<'a> { let loop_body_translated = self.translate_subregion(iter::once(pred).chain(outputs.iter()).copied()); - // cache the loop body for use in contexts - self.loop_body_cache - .insert(id, loop_body_translated.clone()); - let loop_expr = dowhile(inputs_translated, loop_body_translated); self.tuple_res(loop_expr, id) @@ -333,7 +327,6 @@ impl RvsdgFunction { fn to_dag_encoding(&self) -> RcExpr { let mut translator = DagTranslator { stored_node: HashMap::new(), - loop_body_cache: HashMap::new(), nodes: &self.nodes, next_alloc_id: 0, }; diff --git a/src/util.rs b/src/util.rs index 08e71ea1e..066fd617d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -569,8 +569,8 @@ impl Run { } RunType::OptimizedRvsdg => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(&tree).map_err(EggCCError::EggLog)?; + let dag = rvsdg.to_dag_encoding(true); + let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?; let rvsdg = dag_to_rvsdg(&optimized); ( vec![Visualization { @@ -583,8 +583,8 @@ impl Run { } RunType::Egglog => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); - let egglog = build_program(&tree); + let dag = rvsdg.to_dag_encoding(true); + let egglog = build_program(&dag); ( vec![Visualization { result: egglog,