diff --git a/prusti-tests/tests/verify/fail/fold-unfold/old-local.rs b/prusti-tests/tests/verify/fail/fold-unfold/old-local.rs new file mode 100644 index 00000000000..3681ffc2ba6 --- /dev/null +++ b/prusti-tests/tests/verify/fail/fold-unfold/old-local.rs @@ -0,0 +1,10 @@ +extern crate prusti_contracts; +use prusti_contracts::*; + +fn blah(y: i32){ + let x = 1; + prusti_assert!(old(x) == 1); //~ ERROR old expressions should not contain local variables +} + +fn main(){ +} diff --git a/prusti-tests/tests/verify/pass/quick/old-in-loop.rs b/prusti-tests/tests/verify/pass/quick/old-in-loop.rs new file mode 100644 index 00000000000..c1f59e54b80 --- /dev/null +++ b/prusti-tests/tests/verify/pass/quick/old-in-loop.rs @@ -0,0 +1,20 @@ +use prusti_contracts::*; + +pub struct Wrapper(u32); + +impl Wrapper { + #[pure] + pub fn get(&self) -> u32 { + self.0 + } + +} + +fn capitalize(vec: &mut Wrapper) { + while true { + body_invariant!(vec.get() == old(vec.get())); + } +} + +fn main(){ +} diff --git a/prusti-viper/src/encoder/mir/pure/specifications/interface.rs b/prusti-viper/src/encoder/mir/pure/specifications/interface.rs index f927a37cf12..192dea0f2a4 100644 --- a/prusti-viper/src/encoder/mir/pure/specifications/interface.rs +++ b/prusti-viper/src/encoder/mir/pure/specifications/interface.rs @@ -5,7 +5,7 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. use crate::encoder::{ - errors::{SpannedEncodingResult, WithSpan}, + errors::{SpannedEncodingError, SpannedEncodingResult, WithSpan}, mir::{ places::PlacesEncoderInterface, pure::{ @@ -28,13 +28,14 @@ use crate::encoder::{ snapshot::interface::SnapshotEncoderInterface, }; use prusti_rustc_interface::{ + data_structures::fx::FxHashSet, hir::def_id::DefId, middle::{mir, ty::GenericArgsRef}, span::Span, }; use vir_crate::{ high::{self as vir_high, operations::ty::Typed}, - polymorphic as vir_poly, + polymorphic::{self as vir_poly, ExprWalker}, }; pub(crate) trait SpecificationEncoderInterface<'tcx> { @@ -93,6 +94,7 @@ pub(crate) trait SpecificationEncoderInterface<'tcx> { invariant_block: mir::BasicBlock, // in which the invariant is defined parent_def_id: DefId, substs: GenericArgsRef<'tcx>, + is_loop_invariant: bool, // Because this is also used for assert/assume/refute as well ) -> SpannedEncodingResult; } @@ -303,6 +305,7 @@ impl<'v, 'tcx: 'v> SpecificationEncoderInterface<'tcx> for crate::encoder::Encod invariant_block: mir::BasicBlock, // in which the invariant is defined parent_def_id: DefId, substs: GenericArgsRef<'tcx>, + is_loop_invariant: bool, // because this function is also used for encoding assert/assume ) -> SpannedEncodingResult { // identify closure aggregate assign (the invariant body) let closure_assigns = mir.basic_blocks[invariant_block] @@ -368,6 +371,67 @@ impl<'v, 'tcx: 'v> SpecificationEncoderInterface<'tcx> for crate::encoder::Encod // TODO: deal with old(...) ? let final_invariant = invariant.unwrap().into_expr().unwrap(); + + if !is_loop_invariant { + ensure_no_old_local_vars( + &final_invariant, + mir.args_iter() + .map(|local| self.encode_local_name(mir, local).unwrap()) + .collect(), + span, + )?; + } + Ok(final_invariant) } } + +fn ensure_no_old_local_vars( + expr: &vir_poly::Expr, + args: FxHashSet, + span: Span, +) -> Result<(), SpannedEncodingError> { + struct InvalidInOldFinder<'a> { + invalid: &'a mut bool, + args: &'a FxHashSet, + } + + impl<'a> vir_poly::ExprWalker for InvalidInOldFinder<'a> { + fn walk_local(&mut self, expr: &vir_poly::Local) { + if !self.args.contains(&expr.variable.name) { + *self.invalid = true; + } + } + } + + struct OldWalker<'a> { + invalid: bool, + args: &'a FxHashSet, + } + + impl<'a> vir_poly::ExprWalker for OldWalker<'a> { + fn walk_labelled_old(&mut self, expr: &vir_poly::LabelledOld) { + let mut invalid_finder = InvalidInOldFinder { + invalid: &mut self.invalid, + args: self.args, + }; + if expr.label == PRECONDITION_LABEL { + invalid_finder.walk(&expr.base); + } + } + } + + let mut walker = OldWalker { + invalid: false, + args: &args, + }; + + walker.walk(expr); + if walker.invalid { + return Err(SpannedEncodingError::incorrect( + "old expressions should not contain local variables", + span, + )); + } + Ok(()) +} diff --git a/prusti-viper/src/encoder/procedure_encoder.rs b/prusti-viper/src/encoder/procedure_encoder.rs index d76a13f40cc..0abc6dfb96c 100644 --- a/prusti-viper/src/encoder/procedure_encoder.rs +++ b/prusti-viper/src/encoder/procedure_encoder.rs @@ -246,9 +246,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { if self.encoder.get_prusti_assumption(cl_def_id).is_none() { return Ok(false); } - let assume_expr = - self.encoder - .encode_invariant(self.mir, bb, self.proc_def_id, cl_substs)?; + let assume_expr = self.encoder.encode_invariant( + self.mir, + bb, + self.proc_def_id, + cl_substs, + false, + )?; let assume_stmt = vir::Stmt::Inhale(vir::Inhale { expr: assume_expr }); @@ -281,9 +285,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { .encoder .get_definition_span(assertion.assertion.to_def_id()); - let assert_expr = - self.encoder - .encode_invariant(self.mir, bb, self.proc_def_id, cl_substs)?; + let assert_expr = self.encoder.encode_invariant( + self.mir, + bb, + self.proc_def_id, + cl_substs, + false, + )?; let assert_stmt = vir::Stmt::Assert(vir::Assert { expr: assert_expr, @@ -319,9 +327,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { .encoder .get_definition_span(refutation.refutation.to_def_id()); - let refute_expr = - self.encoder - .encode_invariant(self.mir, bb, self.proc_def_id, cl_substs)?; + let refute_expr = self.encoder.encode_invariant( + self.mir, + bb, + self.proc_def_id, + cl_substs, + false, + )?; let refute_stmt = vir::Stmt::Refute(vir::Refute { expr: refute_expr, @@ -5566,6 +5578,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { bbi, self.proc_def_id, cl_substs, + true, )?); let invariant = match spec { prusti_interface::specs::typed::LoopSpecification::Invariant(inv) => {