Skip to content

Commit

Permalink
refactor: extract out MemoryEffects structure (#222)
Browse files Browse the repository at this point in the history
### Description:

Extracted from #179, stacked on #220.

We extract out memory-effects related code from AxEffects into a new
MemoryEffects structure. This PR is purely a non-functional change, but
will serve as the starting point of integrating simp_mem with sym_n.

The current simplification is effectively a no-op, since the proof state
is not massaged to the way `simp_mem` wants it to be. Subsequent PRs
will focus on massaging the goal state to be as `simp_mem` likes, and
then trying to symbolically simplify the memory expression we see.

### Testing:

What tests have been run? Did `make all` succeed for your changes? Was
conformance testing successful on an Aarch64 machine? yes

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

Co-authored-by @bollu<[email protected]>

---------

Co-authored-by: Shilpi Goel <[email protected]>
  • Loading branch information
alexkeizer and shigoel authored Oct 10, 2024
1 parent ae0d779 commit 0ac14c8
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 60 deletions.
36 changes: 36 additions & 0 deletions Tactics/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,42 @@ def Lean.Expr.eqReadField? (e : Expr) : Option (Expr × Expr × Expr) := do
| none
some (field, state, value)

/-- Return the expression for `Memory` -/
def mkMemory : Expr := mkConst ``Memory

/-! ## Expr Helpers -/

/-- Throw an error if `e` is not of type `expectedType` -/
def assertHasType (e expectedType : Expr) : MetaM Unit := do
let eType ← inferType e
if !(←isDefEq eType expectedType) then
throwError "{e} {← mkHasTypeButIsExpectedMsg eType expectedType}"

/-- Throw an error if `e` is not def-eq to `expected` -/
def assertIsDefEq (e expected : Expr) : MetaM Unit := do
if !(←isDefEq e expected) then
throwError "expected:\n {expected}\nbut found:\n {e}"

/--
Rewrites `e` via some `eq`, producing a proof `e = e'` for some `e'`.
Rewrites with a fresh metavariable as the ambient goal.
Fails if the rewrite produces any subgoals.
-/
-- source: https://github.com/leanprover-community/mathlib4/blob/b35703fe5a80f1fa74b82a2adc22f3631316a5b3/Mathlib/Lean/Expr/Basic.lean#L476-L477
def rewrite (e eq : Expr) : MetaM Expr := do
let ⟨_, eq', []⟩ ← (← mkFreshExprMVar none).mvarId!.rewrite e eq
| throwError "Expr.rewrite may not produce subgoals."
return eq'

/--
Rewrites the type of `e` via some `eq`, then moves `e` into the new type via `Eq.mp`.
Rewrites with a fresh metavariable as the ambient goal.
Fails if the rewrite produces any subgoals.
-/
-- source: https://github.com/leanprover-community/mathlib4/blob/b35703fe5a80f1fa74b82a2adc22f3631316a5b3/Mathlib/Lean/Expr/Basic.lean#L476-L477
def rewriteType (e eq : Expr) : MetaM Expr := do
mkEqMP (← rewrite (← inferType e) eq) e

/-! ## Tracing helpers -/

def traceHeartbeats (cls : Name) (header : Option String := none) :
Expand Down
96 changes: 36 additions & 60 deletions Tactics/Sym/AxEffects.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Tactics.Common
import Tactics.Attr
import Tactics.Simp
import Tactics.Sym.Common
import Tactics.Sym.MemoryEffects

import Std.Data.HashMap

Expand Down Expand Up @@ -59,17 +60,8 @@ structure AxEffects where
where `f₁, ⋯, fₙ` are the keys of `fields`
-/
nonEffectProof : Expr
/-- An expression of a (potentially empty) sequence of `write_mem`s
to the initial state, which describes the effects on memory.
See `memoryEffectProof` for more detail -/
memoryEffect : Expr
/-- An expression that contains the proof of:
```lean
∀ n addr,
read_mem_bytes n addr <currentState>
= read_mem_bytes n addr <memoryEffect>
``` -/
memoryEffectProof : Expr
/-- The memory effects -/
memoryEffects : MemoryEffects
/-- A proof that `<currentState>.program = <initialState>.program` -/
programProof : Expr
/-- An optional proof of `CheckSPAlignment <currentState>`.
Expand Down Expand Up @@ -100,8 +92,8 @@ variable {m} [Monad m] [MonadReaderOf AxEffects m]
def getCurrentState : m Expr := do return (← read).currentState
def getInitialState : m Expr := do return (← read).initialState
def getNonEffectProof : m Expr := do return (← read).nonEffectProof
def getMemoryEffect : m Expr := do return (← read).memoryEffect
def getMemoryEffectProof : m Expr := do return (← read).memoryEffectProof
def getMemoryEffect : m Expr := do return (← read).memoryEffects.effects
def getMemoryEffectProof : m Expr := do return (← read).memoryEffects.proof
def getProgramProof : m Expr := do return (← read).programProof

def getStackAlignmentProof? : m (Option Expr) := do
Expand Down Expand Up @@ -136,15 +128,7 @@ def initial (state : Expr) : AxEffects where
-- `fun f => rfl`
mkLambda `f .default (mkConst ``StateField) <|
mkEqReflArmState <| mkApp2 (mkConst ``r) (.bvar 0) state
memoryEffect := state
memoryEffectProof :=
-- `fun n addr => rfl`
mkLambda `n .default (mkConst ``Nat) <|
let bv64 := mkApp (mkConst ``BitVec) (toExpr 64)
mkLambda `addr .default bv64 <|
mkApp2 (.const ``Eq.refl [1])
(mkApp (mkConst ``BitVec) <| mkNatMul (.bvar 1) (toExpr 8))
(mkApp3 (mkConst ``read_mem_bytes) (.bvar 1) (.bvar 0) state)
memoryEffects := .initial state
programProof :=
-- `rfl`
mkAppN (.const ``Eq.refl [1]) #[
Expand All @@ -170,8 +154,7 @@ instance : ToMessageData AxEffects where
currentState := {eff.currentState},
fields := {eff.fields},
nonEffectProof := {eff.nonEffectProof},
memoryEffect := {eff.memoryEffect},
memoryEffectProof := {eff.memoryEffectProof},
memoryEffects := {eff.memoryEffects},
programProof := {eff.programProof}
}"

Expand Down Expand Up @@ -280,7 +263,7 @@ Note that no effort is made to preserve `currentStateEq`; it is set to `none`!
-/
private def update_write_mem (eff : AxEffects) (n addr val : Expr) :
MetaM AxEffects :=
withTraceNode m!"processing: write_mem {n} {addr} {val} …" (tag := "updateWriteMem") <| do
Sym.withTraceNode m!"processing: write_mem {n} {addr} {val} …" (tag := "updateWriteMem") <| do

-- Update each field
let fields ← eff.fields.toList.mapM fun ⟨fld, {value, proof}⟩ => do
Expand All @@ -298,11 +281,10 @@ private def update_write_mem (eff : AxEffects) (n addr val : Expr) :
mkLambdaFVars args proof
-- ^^ `fun f ... => Eq.trans (@r_of_write_mem_bytes f ...) <proof>`

-- Update the memory effects proof
let memoryEffectProof :=
-- `read_mem_bytes_write_mem_bytes_of_read_mem_eq <memoryEffectProof> ...`
mkAppN (mkConst ``read_mem_bytes_write_mem_bytes_of_read_mem_eq)
#[eff.currentState, eff.memoryEffect, eff.memoryEffectProof, n, addr, val]
-- Update the memory effects
let memoryEffects ←
eff.memoryEffects.updateWriteMem eff.currentState n addr val


-- Update the program proof
let programProof ←
Expand All @@ -318,15 +300,13 @@ private def update_write_mem (eff : AxEffects) (n addr val : Expr) :
#[eff.currentState, n, addr, val, proof]

-- Assemble the result
let addWrite (e : Expr) :=
-- `@write_mem_bytes <n> <addr> <val> <e>`
mkApp4 (mkConst ``write_mem_bytes) n addr val e
let currentState := -- `@write_mem_bytes <n> <addr> <val> <currentState>`
mkApp4 (mkConst ``write_mem_bytes) n addr val eff.currentState
let eff := { eff with
currentState := addWrite eff.currentState
currentState
fields := .ofList fields
nonEffectProof
memoryEffect := addWrite eff.memoryEffect
memoryEffectProof
memoryEffects
programProof
stackAlignmentProof?
}
Expand All @@ -341,7 +321,7 @@ Note that no effort is made to preserve `currentStateEq`; it is set to `none`!
-/
private def update_w (eff : AxEffects) (fld val : Expr) :
MetaM AxEffects := do
withTraceNode m!"processing: w {fld} {val} …" (tag := "updateWrite") <| do
Sym.withTraceNode m!"processing: w {fld} {val} …" (tag := "updateWrite") <| do
let rField ← reflectStateField fld

-- Update all other fields
Expand Down Expand Up @@ -398,11 +378,8 @@ private def update_w (eff : AxEffects) (fld val : Expr) :
withLocalDeclD name h_neq_type fun h_neq =>
k (args.push h_neq) h_neq

-- Update the memory effect proof
let memoryEffectProof :=
-- `read_mem_bytes_w_of_read_mem_eq ...`
mkAppN (mkConst ``read_mem_bytes_w_of_read_mem_eq)
#[eff.currentState, eff.memoryEffect, eff.memoryEffectProof, fld, val]
-- Update the memory effects
let memoryEffects ← eff.memoryEffects.updateWrite eff.currentState fld val

-- Update the program proof
let programProof ←
Expand Down Expand Up @@ -434,8 +411,7 @@ private def update_w (eff : AxEffects) (fld val : Expr) :
currentState := mkApp3 (mkConst ``w) fld val eff.currentState
fields := Std.HashMap.ofList fields
nonEffectProof
-- memory effects are unchanged
memoryEffectProof
memoryEffects
programProof
stackAlignmentProof?
sideConditions
Expand Down Expand Up @@ -498,7 +474,7 @@ def fromExpr (e : Expr) : MetaM AxEffects := do
set `s` to be the new `currentState`, and update all proofs accordingly -/
def adjustCurrentStateWithEq (eff : AxEffects) (s eq : Expr) :
MetaM AxEffects := do
withTraceNode m!"adjustCurrentStateWithEq" (tag := "adjustCurrentStateWithEq") do
Sym.withTraceNode m!"adjustCurrentStateWithEq" (tag := "adjustCurrentStateWithEq") do
trace[Tactic.sym] "rewriting along {eq}"
eff.traceCurrentState

Expand All @@ -515,17 +491,15 @@ def adjustCurrentStateWithEq (eff : AxEffects) (s eq : Expr) :
pure (field, {fieldEff with proof})
let fields := .ofList fields

withTraceNode m!"rewriting other proofs" (tag := "rewriteMisc") <| do
Sym.withTraceNode m!"rewriting other proofs" (tag := "rewriteMisc") <| do
let nonEffectProof ← rewriteType eff.nonEffectProof eq
let memoryEffectProof ← rewriteType eff.memoryEffectProof eq
-- ^^ TODO: what happens if `memoryEffect` is the same as `currentState`?
-- Presumably, we would *not* want to encapsulate `memoryEffect` here
let memoryEffects ← eff.memoryEffects.adjustCurrentStateWithEq eq
let programProof ← rewriteType eff.programProof eq
let stackAlignmentProof? ← eff.stackAlignmentProof?.mapM
(rewriteType · eq)

return { eff with
currentState, fields, nonEffectProof, memoryEffectProof, programProof,
currentState, fields, nonEffectProof, memoryEffects, programProof,
stackAlignmentProof?
}

Expand Down Expand Up @@ -642,7 +616,7 @@ NOTE: does not necessarily validate *which* type an expression has,
validation will still pass if types are different to those we claim in the
docstrings -/
def validate (eff : AxEffects) : MetaM Unit := do
withTraceNode "validating that the axiomatic effects are well-formed"
Sym.withTraceNode "validating that the axiomatic effects are well-formed"
(tag := "validate") <| do
eff.traceCurrentState

Expand All @@ -653,13 +627,13 @@ def validate (eff : AxEffects) : MetaM Unit := do
check fieldEff.value
check fieldEff.proof

eff.memoryEffects.validate
check eff.nonEffectProof
check eff.memoryEffect
check eff.memoryEffectProof
check eff.programProof
if let some h := eff.stackAlignmentProof? then
check h


/-! ## Tactic Environment -/
section Tactic
open Elab.Tactic
Expand All @@ -678,7 +652,7 @@ that was just added to the local context -/
def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_")
(mvar : Option MVarId := none) :
TacticM AxEffects :=
withTraceNode m!"adding hypotheses to local context"
Sym.withTraceNode m!"adding hypotheses to local context"
(tag := "addHypothesesToLContext") do
eff.traceCurrentState
let mut goal ← mvar.getDM getMainGoal
Expand All @@ -704,12 +678,14 @@ def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_")
let nonEffectProof := Expr.fvar nonEffectProof
goal := goal'

trace[Tactic.sym] "adding memory effects with {eff.memoryEffectProof}"
trace[Tactic.sym] "adding memory effects with {eff.memoryEffects.proof}"
let ⟨memoryEffectProof, goal'⟩ ← goal.withContext do
let name := .mkSimple s!"{hypPrefix}memory_effects"
let proof := eff.memoryEffectProof
let proof := eff.memoryEffects.proof
replaceOrNote goal name proof
let memoryEffectProof := Expr.fvar memoryEffectProof
let memoryEffects := { eff.memoryEffects with
proof := Expr.fvar memoryEffectProof
}
goal := goal'

trace[Tactic.sym] "adding program hypothesis with {eff.programProof}"
Expand All @@ -735,7 +711,7 @@ def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_")

replaceMainGoal [goal]
return {eff with
fields, nonEffectProof, memoryEffectProof, programProof,
fields, nonEffectProof, memoryEffects, programProof,
stackAlignmentProof?
}
where
Expand All @@ -755,7 +731,7 @@ where
/-- Return an array of `SimpTheorem`s of the proofs contained in
the given `AxEffects` -/
def toSimpTheorems (eff : AxEffects) : MetaM (Array SimpTheorem) := do
withTraceNode m!"computing SimpTheorems for (non-)effect hypotheses"
Sym.withTraceNode m!"computing SimpTheorems for (non-)effect hypotheses"
(tag := "toSimpTheorems") <| do
let lctx ← getLCtx
let baseName? :=
Expand Down Expand Up @@ -789,7 +765,7 @@ def toSimpTheorems (eff : AxEffects) : MetaM (Array SimpTheorem) := do
thms ← add thms proof s!"field_{field}" (prio := 1500)

thms ← add thms eff.nonEffectProof "nonEffectProof"
thms ← add thms eff.memoryEffectProof "memoryEffectProof"
thms ← add thms eff.memoryEffects.proof "memoryEffectProof"
thms ← add thms eff.programProof "programProof"
if let some stackAlignmentProof := eff.stackAlignmentProof? then
thms ← add thms stackAlignmentProof "stackAlignmentProof"
Expand Down
7 changes: 7 additions & 0 deletions Tactics/Sym/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def withVerboseTraceNode (msg : MessageData) (k : m α)
: m α := do
Lean.withTraceNode `Tactic.sym.verbose (fun _ => pure msg) k collapsed tag

/-- Create a trace note that folds `header` with `(NOTE: can be large)`,
and prints `msg` under such a trace node.
-/
def traceLargeMsg (header : MessageData) (msg : MessageData) : MetaM Unit :=
withTraceNode m!"{header} (NOTE: can be large)" do
trace[Tactic.sym] msg

end Tracing

end Sym
Loading

0 comments on commit 0ac14c8

Please sign in to comment.