Skip to content

Commit

Permalink
regions passing on some tests but strangely some are slow
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Apr 11, 2024
1 parent ab2aa6f commit 02671f4
Show file tree
Hide file tree
Showing 30 changed files with 255 additions and 176 deletions.
101 changes: 68 additions & 33 deletions dag_in_context/src/add_context.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
//! Adds context to the tree program.
//! The `add_context` method recursively adds context to all of the nodes in the tree program
//! by remembering the most recent context (ex. DoWhile or If).
//! Mantains the sharing invariant (see restore_sharing_invariant) by using a cache.
use std::collections::HashMap;

use crate::{
ast::{in_context, infunc},
schema::{Assumption, Expr, RcExpr, TreeProgram},
schema_helpers::AssumptionRef,
};

struct ContextCache {
with_ctx: HashMap<(*const Expr, AssumptionRef), RcExpr>,
}

impl TreeProgram {
pub fn add_context(&self) -> TreeProgram {
TreeProgram {
functions: self
.functions
.iter()
.map(|f| f.clone().func_add_context())
.map(|f| f.clone().func_add_ctx())
.collect(),
entry: self.entry.clone().func_add_context(),
entry: self.entry.clone().func_add_ctx(),
}
}
}

impl Expr {
pub(crate) fn func_add_context(self: RcExpr) -> RcExpr {
pub(crate) fn func_add_ctx(self: RcExpr) -> RcExpr {
let Expr::Function(name, arg_ty, ret_ty, body) = &self.as_ref() else {
panic!("Expected Function, got {:?}", self);
};
Expand All @@ -30,78 +38,105 @@ impl Expr {
name.clone(),
arg_ty.clone(),
ret_ty.clone(),
body.add_context(current_ctx),
body.add_ctx(current_ctx),
))
}

fn add_context(self: &RcExpr, current_ctx: Assumption) -> RcExpr {
match self.as_ref() {
pub(crate) fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> RcExpr {
let mut cache = ContextCache {
with_ctx: HashMap::new(),
};
self.add_ctx_with_cache(current_ctx, &mut cache)
}

fn add_ctx_with_cache(
self: &RcExpr,
current_ctx: Assumption,
cache: &mut ContextCache,
) -> RcExpr {
let ctx_ref = current_ctx.to_ref();
if let Some(expr) = cache
.with_ctx
.get(&((*self).as_ref() as *const Expr, ctx_ref.clone()))
{
return expr.clone();
}
let res = match self.as_ref() {
// leaf nodes are constant, empty, and arg
Expr::Const(..) | Expr::Empty(..) | Expr::Arg(..) => {
in_context(current_ctx, self.clone())
}
// create new contexts for let, loop, and if
Expr::DoWhile(inputs, pred_and_body) => {
let new_inputs = inputs.add_context(current_ctx.clone());
let new_inputs = inputs.add_ctx_with_cache(current_ctx.clone(), cache);
let new_ctx = Assumption::InLoop(new_inputs.clone(), pred_and_body.clone());
RcExpr::new(Expr::DoWhile(
new_inputs,
pred_and_body.add_context(new_ctx),
pred_and_body.add_ctx_with_cache(new_ctx, cache),
))
}
Expr::If(pred, input, then_case, else_calse) => {
let new_pred = pred.add_context(current_ctx.clone());
let new_input = input.add_context(current_ctx.clone());
let then_ctx = Assumption::InIf(true, new_pred.clone());
let else_ctx = Assumption::InIf(false, new_pred.clone());
let new_pred = pred.add_ctx_with_cache(current_ctx.clone(), cache);
let new_input = input.add_ctx_with_cache(current_ctx.clone(), cache);
let then_ctx = Assumption::InIf(true, new_pred.clone(), new_input.clone());
let else_ctx = Assumption::InIf(false, new_pred.clone(), new_input.clone());
RcExpr::new(Expr::If(
new_pred,
new_input,
then_case.add_context(then_ctx),
else_calse.add_context(else_ctx),
then_case.add_ctx_with_cache(then_ctx, cache),
else_calse.add_ctx_with_cache(else_ctx, cache),
))
}
Expr::Switch(case_num, input, branches) => {
let new_case_num = case_num.add_context(current_ctx.clone());
let new_input = input.add_context(current_ctx.clone());
let new_case_num = case_num.add_ctx_with_cache(current_ctx.clone(), cache);
let new_input = input.add_ctx_with_cache(current_ctx.clone(), cache);
let new_branches = branches
.iter()
.map(|b| b.add_context(current_ctx.clone()))
.map(|b| b.add_ctx_with_cache(current_ctx.clone(), cache))
.collect();
RcExpr::new(Expr::Switch(new_case_num, new_input, new_branches))
}
// for all other nodes, just add the context to the children
Expr::Bop(op, x, y) => RcExpr::new(Expr::Bop(
op.clone(),
x.add_context(current_ctx.clone()),
y.add_context(current_ctx),
x.add_ctx_with_cache(current_ctx.clone(), cache),
y.add_ctx_with_cache(current_ctx, cache),
)),
Expr::Top(op, x, y, z) => RcExpr::new(Expr::Top(
op.clone(),
x.add_context(current_ctx.clone()),
y.add_context(current_ctx.clone()),
z.add_context(current_ctx),
x.add_ctx_with_cache(current_ctx.clone(), cache),
y.add_ctx_with_cache(current_ctx.clone(), cache),
z.add_ctx_with_cache(current_ctx, cache),
)),
Expr::Uop(op, x) => RcExpr::new(Expr::Uop(op.clone(), x.add_context(current_ctx))),
Expr::Get(e, i) => RcExpr::new(Expr::Get(e.add_context(current_ctx), *i)),
Expr::Uop(op, x) => RcExpr::new(Expr::Uop(
op.clone(),
x.add_ctx_with_cache(current_ctx, cache),
)),
Expr::Get(e, i) => RcExpr::new(Expr::Get(e.add_ctx_with_cache(current_ctx, cache), *i)),
Expr::Alloc(id, e, state, ty) => RcExpr::new(Expr::Alloc(
*id,
e.add_context(current_ctx.clone()),
state.add_context(current_ctx),
e.add_ctx_with_cache(current_ctx.clone(), cache),
state.add_ctx_with_cache(current_ctx, cache),
ty.clone(),
)),
Expr::Call(f, arg) => {
RcExpr::new(Expr::Call(f.clone(), arg.add_context(current_ctx.clone())))
}
Expr::Single(e) => RcExpr::new(Expr::Single(e.add_context(current_ctx))),
Expr::Call(f, arg) => RcExpr::new(Expr::Call(
f.clone(),
arg.add_ctx_with_cache(current_ctx.clone(), cache),
)),
Expr::Single(e) => RcExpr::new(Expr::Single(e.add_ctx_with_cache(current_ctx, cache))),
Expr::Concat(x, y) => RcExpr::new(Expr::Concat(
x.add_context(current_ctx.clone()),
y.add_context(current_ctx),
x.add_ctx_with_cache(current_ctx.clone(), cache),
y.add_ctx_with_cache(current_ctx, cache),
)),
Expr::InContext(..) => {
panic!("add_context expects a term without context")
}
Expr::Function(..) => panic!("Function should have been handled in func_add_context"),
}
};
cache.with_ctx.insert(
(self.as_ref() as *const Expr, ctx_ref),
res.clone(),
);
res
}
}
4 changes: 2 additions & 2 deletions dag_in_context/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ pub fn inloop(e1: RcExpr, e2: RcExpr) -> Assumption {
Assumption::InLoop(e1, e2)
}

pub fn inif(is_then: bool, pred: RcExpr) -> Assumption {
Assumption::InIf(is_then, pred)
pub fn inif(is_then: bool, pred: RcExpr, input: RcExpr) -> Assumption {
Assumption::InIf(is_then, pred, input)
}

pub fn infunc(name: &str) -> Assumption {
Expand Down
6 changes: 3 additions & 3 deletions dag_in_context/src/from_egglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ impl FromEgglog {
};
Assumption::InFunc(string.to_string())
}
("InIf", [is_then, expr]) => {
("InIf", [is_then, pred_expr, input_expr]) => {
let Term::Lit(Literal::Bool(boolean)) = self.termdag.get(*is_then)
else {
panic!("Invalid boolean: {:?}", is_then)
};
Assumption::InIf(boolean, self.expr_from_egglog(self.termdag.get(*expr)))
Assumption::InIf(boolean, self.expr_from_egglog(self.termdag.get(*pred_expr)), self.expr_from_egglog(self.termdag.get(*input_expr)))
}
_ => panic!("Invalid assumption: {:?}", assumption),
})
Expand Down Expand Up @@ -354,7 +354,7 @@ impl FromEgglog {
self.program_from_egglog_preserve_ctx_nodes(new_term)
}

fn without_ctx_nodes(&mut self, expr: Term) -> Term {
pub fn without_ctx_nodes(&mut self, expr: Term) -> Term {
match expr {
Term::App(head, children) => match (head.to_string().as_str(), children.as_slice()) {
("InContext", [_assumption, expr]) => {
Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/interval_analysis.egg
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
; Conditionals
; =================================
(rule (
(= lhs (If cond thn els))
(= lhs (If cond inputs thn els))
(= thn-ival (ival thn))
(= els-ival (ival els))
)
Expand Down
3 changes: 2 additions & 1 deletion dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ pub fn prologue() -> String {
&optimizations::is_valid::rules().join("\n"),
&optimizations::body_contains::rules().join("\n"),
&optimizations::purity_analysis::rules().join("\n"),
&optimizations::conditional_invariant_code_motion::rules().join("\n"),
// TODO cond inv code motion with regions
//&optimizations::conditional_invariant_code_motion::rules().join("\n"),
include_str!("utility/in_context.egg"),
include_str!("utility/context-prop.egg"),
include_str!("utility/subst.egg"),
Expand Down
16 changes: 9 additions & 7 deletions dag_in_context/src/optimizations/body_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ fn captured_expr_rule_for_ctor(ctor: Constructor) -> Option<String> {
fn subexpr_rule_for_ctor(ctor: Constructor) -> Option<String> {
let pat = ctor.construct(|field| field.var());
let actions = ctor.filter_map_fields(|field| {
(field.purpose == Purpose::SubExpr || field.purpose == Purpose::SubListExpr).then(|| {
format!(
"(BodyContains{sort} body {e})",
sort = field.sort().name(),
e = field.var()
)
})
(field.purpose == Purpose::SubExpr || field.purpose == Purpose::CapturedSubListExpr).then(
|| {
format!(
"(BodyContains{sort} body {e})",
sort = field.sort().name(),
e = field.var()
)
},
)
});
(!actions.is_empty()).then(|| {
format!(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::schema_helpers::{Constructor, ESort, Purpose};
/*use crate::schema_helpers::{Constructor, ESort, Purpose};
use std::iter;
use strum::IntoEnumIterator;
// TODO implement now that we have if regions
fn rules_for_ctor(ctor: Constructor) -> Option<String> {
use Constructor::*;
if [DoWhile, InContext].contains(&ctor) || ctor.sort() != ESort::Expr {
Expand Down Expand Up @@ -53,7 +55,7 @@ fn rules_for_ctor(ctor: Constructor) -> Option<String> {
((union (Switch pred exprs)
{resulting_switch}))
:ruleset conditional-invariant-code-motion)
(rewrite (If c {ctor_pattern1} {ctor_pattern2})
{resulting_if}
:when ((ExprIsValid (If c {ctor_pattern1} {ctor_pattern2})))
Expand Down Expand Up @@ -147,3 +149,4 @@ fn test_lift_if() -> crate::Result {
vec![],
)
}
*/
2 changes: 1 addition & 1 deletion dag_in_context/src/optimizations/is_valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use strum::IntoEnumIterator;
fn rule_for_ctor(ctor: Constructor) -> Option<String> {
let actions = ctor.filter_map_fields(|field| match field.purpose {
Purpose::Static(_) => None,
Purpose::CapturedExpr | Purpose::SubExpr | Purpose::SubListExpr => Some(format!(
Purpose::CapturedExpr | Purpose::SubExpr | Purpose::CapturedSubListExpr => Some(format!(
"({sort}IsValid {var})",
sort = field.sort().name(),
var = field.var()
Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/optimizations/loop_invariant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn is_invariant_rule_for_ctor(ctor: Constructor) -> Option<String> {
let is_inv_ctor = ctor
.filter_map_fields(|field| match field.purpose {
Purpose::Static(_) | Purpose::CapturedExpr => None,
Purpose::SubExpr | Purpose::SubListExpr => {
Purpose::SubExpr | Purpose::CapturedSubListExpr => {
let var = field.var();
let sort = field.sort().name();
Some(format!("(= true (is-inv-{sort} loop {var}))"))
Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/optimizations/purity_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn purity_rules_for_ctor(ctor: Constructor) -> String {
Purpose::Static(Sort::BinaryOp)
| Purpose::Static(Sort::UnaryOp)
| Purpose::SubExpr
| Purpose::SubListExpr
| Purpose::CapturedSubListExpr
| Purpose::CapturedExpr => Some(format!(
"({sort}IsPure {var})",
sort = field.sort().name(),
Expand Down
34 changes: 18 additions & 16 deletions dag_in_context/src/optimizations/switch_rewrites.egg
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
(ruleset switch_rewrite)

(rewrite (If (Bop (And) a b) X Y)
(If a (If b X Y) Y)
:when ((ExprIsPure b))
:ruleset switch_rewrite)

(rewrite (If (Bop (Or) a b) X Y)
(If a X (If b X Y))
:when ((ExprIsPure b))
:ruleset switch_rewrite)

(rewrite (If (Const (Bool true) ty) thn els)
thn
:ruleset switch_rewrite)

(rewrite (If (Const (Bool false) ty) thn els)
els
:ruleset switch_rewrite)
;; TODO rewrite for if regions
;;(rewrite (If (Bop (And) a b) X Y)
;; (If a (If b X Y) Y)
;; :when ((ExprIsPure b))
;; :ruleset switch_rewrite)
;;
;;(rewrite (If (Bop (Or) a b) X Y)
;; (If a X (If b X Y))
;; :when ((ExprIsPure b))
;; :ruleset switch_rewrite)
;;
;;(rewrite (If (Const (Bool true) ty) thn els)
;; thn
;; :ruleset switch_rewrite)
;;
;;(rewrite (If (Const (Bool false) ty) thn els)
;; els
;; :ruleset switch_rewrite)
3 changes: 2 additions & 1 deletion dag_in_context/src/schedule.egg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
(saturate context-prop)
(saturate context-helpers)
context))
conditional-invariant-code-motion
;; TODO enable when conditional inv code motion works again
;;conditional-invariant-code-motion
switch_rewrite
loop-simplify
))
4 changes: 2 additions & 2 deletions dag_in_context/src/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@
(InFunc String)
; Branch of the switch and what the predicate is
(InSwitch i64 Expr)
; If the predicate was true, and what the predicate is
(InIf bool Expr)
; If the predicate was true, and what the predicate is, and what the input is
(InIf bool Expr Expr)
; Other assumptions are possible, but not supported yet.
; For example:
; A boolean predicate is true.
Expand Down
3 changes: 2 additions & 1 deletion dag_in_context/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ pub type RcExpr = Rc<Expr>;
pub enum Assumption {
InLoop(RcExpr, RcExpr),
InFunc(String),
InIf(bool, RcExpr),
InIf(bool, RcExpr, RcExpr),
}


#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Expr {
Const(Constant, Type),
Expand Down
Loading

0 comments on commit 02671f4

Please sign in to comment.