diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 48076bfc02b..f51be965b6c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -110,7 +110,7 @@ jobs: # Run a subset of the tests with the purification optimization enabled # to ensure that we do not introduce regressions. purification-tests: - needs: [fmt-check, clippy-check, check-deps, smir-check, quick-tests] + #needs: [fmt-check, clippy-check, check-deps, smir-check, quick-tests] runs-on: ubuntu-latest env: PRUSTI_ENABLE_PURIFICATION_OPTIMIZATION: true @@ -162,6 +162,36 @@ jobs: # python x.py test --all pass/pure-fn/ref-mut-arg.rs # python x.py test --all pass/rosetta/Ackermann_function.rs # python x.py test --all pass/rosetta/Heapsort.rs + - name: custom_heap_encoding + env: + PRUSTI_VIPER_BACKEND: carbon + PRUSTI_CUSTOM_HEAP_ENCODING: true + PRUSTI_TRACE_WITH_SYMBOLIC_EXECUTION: false + PRUSTI_PURIFY_WITH_SYMBOLIC_EXECUTION: false + run: | + python x.py test custom_heap_encoding + - name: purify_with_symbolic_execution + env: + PRUSTI_VIPER_BACKEND: carbon + PRUSTI_CUSTOM_HEAP_ENCODING: false + PRUSTI_PURIFY_WITH_SYMBOLIC_EXECUTION: true + run: | + python x.py test custom_heap_encoding + - name: custom_heap_encoding and purify_with_symbolic_execution + env: + PRUSTI_VIPER_BACKEND: carbon + PRUSTI_CUSTOM_HEAP_ENCODING: true + PRUSTI_PURIFY_WITH_SYMBOLIC_EXECUTION: true + run: | + python x.py test custom_heap_encoding + - name: trace_with_symbolic_execution + env: + PRUSTI_VIPER_BACKEND: silicon + PRUSTI_CUSTOM_HEAP_ENCODING: false + PRUSTI_TRACE_WITH_SYMBOLIC_EXECUTION: false + PRUSTI_PURIFY_WITH_SYMBOLIC_EXECUTION: false + run: | + python x.py test custom_heap_encoding - name: Run with purification. env: PRUSTI_VIPER_BACKEND: silicon diff --git a/Cargo.lock b/Cargo.lock index a5660da1c8a..aba50d6c414 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,7 +43,7 @@ version = "0.1.0" dependencies = [ "compiletest_rs", "derive_more", - "env_logger", + "env_logger 0.10.0", "glob", "log", "prusti-rustc-interface", @@ -876,6 +876,24 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bd4b30a6560bbd9b4620f4de34c3f14f60848e58a9b7216801afcb4c7b31c3c" +[[package]] +name = "egg" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6e969a475908119d4603393dfe05a17f676d9570c493c90763321aa950de2c" +dependencies = [ + "env_logger 0.9.3", + "fxhash", + "hashbrown", + "indexmap", + "instant", + "log", + "smallvec", + "symbol_table", + "symbolic_expressions", + "thiserror", +] + [[package]] name = "either" version = "1.8.0" @@ -891,6 +909,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "log", +] + [[package]] name = "env_logger" version = "0.10.0" @@ -2089,7 +2116,7 @@ name = "prusti" version = "0.2.1" dependencies = [ "chrono", - "env_logger", + "env_logger 0.10.0", "lazy_static", "log", "prusti-common", @@ -2163,7 +2190,7 @@ version = "0.1.0" dependencies = [ "bincode", "clap", - "env_logger", + "env_logger 0.10.0", "lazy_static", "log", "num_cpus", @@ -2204,7 +2231,7 @@ version = "0.2.0" dependencies = [ "cargo-test-support", "compiletest_rs", - "env_logger", + "env_logger 0.10.0", "log", "prusti", "prusti-launch", @@ -2235,6 +2262,7 @@ dependencies = [ "backtrace", "derive_more", "diffy", + "egg", "itertools", "lazy_static", "log", @@ -2756,6 +2784,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "smallvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + [[package]] name = "smt-log-analyzer" version = "0.1.0" @@ -2815,6 +2849,22 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "symbol_table" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32bf088d1d7df2b2b6711b06da3471bc86677383c57b27251e18c56df8deac14" +dependencies = [ + "ahash", + "hashbrown", +] + +[[package]] +name = "symbolic_expressions" +version = "5.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71" + [[package]] name = "syn" version = "1.0.107" @@ -2842,7 +2892,7 @@ dependencies = [ name = "systest" version = "0.1.0" dependencies = [ - "env_logger", + "env_logger 0.10.0", "error-chain", "jni", "jni-gen", @@ -2903,7 +2953,7 @@ dependencies = [ "clap", "color-backtrace", "csv", - "env_logger", + "env_logger 0.10.0", "failure", "glob", "log", @@ -3247,7 +3297,7 @@ version = "0.1.0" dependencies = [ "bencher", "bincode", - "env_logger", + "env_logger 0.10.0", "error-chain", "futures", "jni", @@ -3265,7 +3315,7 @@ dependencies = [ name = "viper-sys" version = "0.1.0" dependencies = [ - "env_logger", + "env_logger 0.10.0", "error-chain", "jni", "jni-gen", diff --git a/prusti-common/src/vir/low_to_viper/ast.rs b/prusti-common/src/vir/low_to_viper/ast.rs index 6e56e887b01..c570c3c4bda 100644 --- a/prusti-common/src/vir/low_to_viper/ast.rs +++ b/prusti-common/src/vir/low_to_viper/ast.rs @@ -46,6 +46,7 @@ impl<'v> ToViper<'v, viper::Stmt<'v>> for Statement { fn to_viper(&self, context: Context, ast: &AstFactory<'v>) -> viper::Stmt<'v> { match self { Statement::Comment(statement) => statement.to_viper(context, ast), + Statement::Label(statement) => statement.to_viper(context, ast), Statement::LogEvent(statement) => statement.to_viper(context, ast), Statement::Assume(statement) => statement.to_viper(context, ast), Statement::Assert(statement) => statement.to_viper(context, ast), @@ -67,6 +68,12 @@ impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::Comment { } } +impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::Label { + fn to_viper(&self, _context: Context, ast: &AstFactory<'v>) -> viper::Stmt<'v> { + ast.label(&self.label, &[]) + } +} + impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::LogEvent { fn to_viper(&self, context: Context, ast: &AstFactory<'v>) -> viper::Stmt<'v> { assert!( @@ -135,6 +142,11 @@ impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::Fold { !self.position.is_default(), "Statement with default position: {self}" ); + assert!( + self.expression.is_predicate_access_predicate(), + "fold {}", + self.expression + ); ast.fold_with_pos( self.expression.to_viper(context, ast), self.position.to_viper(context, ast), @@ -148,6 +160,11 @@ impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::Unfold { !self.position.is_default(), "Statement with default position: {self}" ); + assert!( + self.expression.is_predicate_access_predicate(), + "unfold {}", + self.expression + ); ast.unfold_with_pos( self.expression.to_viper(context, ast), self.position.to_viper(context, ast), @@ -199,6 +216,10 @@ impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::MethodCall { impl<'v> ToViper<'v, viper::Stmt<'v>> for statement::Assign { fn to_viper(&self, context: Context, ast: &AstFactory<'v>) -> viper::Stmt<'v> { + assert!( + !self.position.is_default(), + "Statement with default position: {self}" + ); let target_expression = Expression::local(self.target.clone(), self.position); ast.abstract_assign( target_expression.to_viper(context, ast), @@ -225,7 +246,7 @@ impl<'v> ToViper<'v, viper::Expr<'v>> for Expression { Expression::MagicWand(expression) => expression.to_viper(context, ast), Expression::PredicateAccessPredicate(expression) => expression.to_viper(context, ast), // Expression::FieldAccessPredicate(expression) => expression.to_viper(context, ast), - // Expression::Unfolding(expression) => expression.to_viper(context, ast), + Expression::Unfolding(expression) => expression.to_viper(context, ast), Expression::UnaryOp(expression) => expression.to_viper(context, ast), Expression::BinaryOp(expression) => expression.to_viper(context, ast), Expression::PermBinaryOp(expression) => expression.to_viper(context, ast), @@ -351,6 +372,16 @@ impl<'v> ToViper<'v, viper::Expr<'v>> for expression::PredicateAccessPredicate { } } +impl<'v> ToViper<'v, viper::Expr<'v>> for expression::Unfolding { + fn to_viper(&self, context: Context, ast: &AstFactory<'v>) -> viper::Expr<'v> { + ast.unfolding_with_pos( + self.predicate.to_viper(context, ast), + self.base.to_viper(context, ast), + self.position.to_viper(context, ast), + ) + } +} + impl<'v> ToViper<'v, viper::Expr<'v>> for expression::UnaryOp { fn to_viper(&self, context: Context, ast: &AstFactory<'v>) -> viper::Expr<'v> { match self.op_kind { diff --git a/prusti-common/src/vir/low_to_viper/cfg.rs b/prusti-common/src/vir/low_to_viper/cfg.rs index 95a937d337c..fa3318c5426 100644 --- a/prusti-common/src/vir/low_to_viper/cfg.rs +++ b/prusti-common/src/vir/low_to_viper/cfg.rs @@ -15,14 +15,19 @@ impl<'a, 'v> ToViper<'v, viper::Method<'v>> for &'a ProcedureDecl { for local in &self.locals { declarations.push(local.to_viper_decl(context, ast).into()); } - for block in &self.basic_blocks { - declarations.push(block.label.to_viper_decl(context, ast).into()); - statements.push(block.label.to_viper(context, ast)); + let traversal_order = self.get_topological_sort(); + for label in &traversal_order { + let block = self.basic_blocks.get(label).unwrap(); + declarations.push(label.to_viper_decl(context, ast).into()); + statements.push(label.to_viper(context, ast)); statements.extend(block.statements.to_viper(context, ast)); statements.push(block.successor.to_viper(context, ast)); } statements.push(ast.label(RETURN_LABEL, &[])); declarations.push(ast.label(RETURN_LABEL, &[]).into()); + for label in &self.custom_labels { + declarations.push(label.to_viper_decl(context, ast).into()); + } let body = Some(ast.seqn(&statements, &declarations)); ast.method(&self.name, &[], &[], &[], &[], body) } diff --git a/prusti-common/src/vir/program.rs b/prusti-common/src/vir/program.rs index c7bc8765128..ba369aca7bc 100644 --- a/prusti-common/src/vir/program.rs +++ b/prusti-common/src/vir/program.rs @@ -21,12 +21,14 @@ impl Program { } } pub fn get_check_mode(&self) -> vir::common::check_mode::CheckMode { + // FIXME: Remove because this is not needed anymore. match self { - Program::Legacy(_) => vir::common::check_mode::CheckMode::Both, + Program::Legacy(_) => vir::common::check_mode::CheckMode::MemorySafetyWithFunctional, Program::Low(program) => program.check_mode, } } pub fn get_name_with_check_mode(&self) -> String { + // FIXME: Remove because this is not needed anymore. format!("{}-{}", self.get_name(), self.get_check_mode()) } } diff --git a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs index 2c56de6d0a2..38f4fd89763 100644 --- a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs +++ b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs @@ -16,6 +16,12 @@ pub fn invariant(_attr: TokenStream, tokens: TokenStream) -> TokenStream { tokens } +#[cfg(not(feature = "prusti"))] +#[proc_macro_attribute] +pub fn structural_invariant(_attr: TokenStream, tokens: TokenStream) -> TokenStream { + tokens +} + #[cfg(not(feature = "prusti"))] #[proc_macro_attribute] pub fn ensures(_attr: TokenStream, tokens: TokenStream) -> TokenStream { @@ -118,6 +124,132 @@ pub fn body_variant(_tokens: TokenStream) -> TokenStream { TokenStream::new() } +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn manually_manage(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn pack(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn unpack(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn pack_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn unpack_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn pack_mut_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn unpack_mut_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn take_lifetime(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn join(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn join_range(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn split(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn split_range(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn stash_range(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn restore_stash_range(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn close_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn open_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn close_mut_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn open_mut_ref(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn forget_initialization(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn restore(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn set_union_active_field(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + // ---------------------- // --- PRUSTI ENABLED --- @@ -204,7 +336,13 @@ pub fn extern_spec(attr: TokenStream, tokens: TokenStream) -> TokenStream { #[cfg(feature = "prusti")] #[proc_macro_attribute] pub fn invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream { - prusti_specs::invariant(attr.into(), tokens.into()).into() + prusti_specs::invariant(attr.into(), tokens.into(), false).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro_attribute] +pub fn structural_invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream { + prusti_specs::invariant(attr.into(), tokens.into(), true).into() } #[cfg(feature = "prusti")] @@ -249,5 +387,131 @@ pub fn body_variant(tokens: TokenStream) -> TokenStream { prusti_specs::body_variant(tokens.into()).into() } +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn manually_manage(tokens: TokenStream) -> TokenStream { + prusti_specs::manually_manage(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn pack(tokens: TokenStream) -> TokenStream { + prusti_specs::pack(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn unpack(tokens: TokenStream) -> TokenStream { + prusti_specs::unpack(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn pack_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::pack_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn unpack_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::unpack_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn pack_mut_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::pack_mut_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn unpack_mut_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::unpack_mut_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn take_lifetime(tokens: TokenStream) -> TokenStream { + prusti_specs::take_lifetime(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn join(tokens: TokenStream) -> TokenStream { + prusti_specs::join(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn join_range(tokens: TokenStream) -> TokenStream { + prusti_specs::join_range(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn split(tokens: TokenStream) -> TokenStream { + prusti_specs::split(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn split_range(tokens: TokenStream) -> TokenStream { + prusti_specs::split_range(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn stash_range(tokens: TokenStream) -> TokenStream { + prusti_specs::stash_range(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn restore_stash_range(tokens: TokenStream) -> TokenStream { + prusti_specs::restore_stash_range(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn close_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::close_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn open_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::open_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn close_mut_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::close_mut_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn open_mut_ref(tokens: TokenStream) -> TokenStream { + prusti_specs::open_mut_ref(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn forget_initialization(tokens: TokenStream) -> TokenStream { + prusti_specs::forget_initialization(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn restore(tokens: TokenStream) -> TokenStream { + prusti_specs::restore(tokens.into()).into() +} + +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn set_union_active_field(tokens: TokenStream) -> TokenStream { + prusti_specs::set_union_active_field(tokens.into()).into() +} + // Ensure that you've also crated a transparent `#[cfg(not(feature = "prusti"))]` // version of your new macro above! diff --git a/prusti-contracts/prusti-contracts/src/core_spec.rs b/prusti-contracts/prusti-contracts/src/core_spec.rs index 61fa73e42b8..3853b87477a 100644 --- a/prusti-contracts/prusti-contracts/src/core_spec.rs +++ b/prusti-contracts/prusti-contracts/src/core_spec.rs @@ -16,3 +16,48 @@ impl ::core::result::Result { #[requires(matches!(self, Ok(_)))] fn unwrap(self) -> T; } + +// Crashes ☹ +type Pointer = *const T; +#[extern_spec] +impl Pointer { + #[trusted] + #[terminates] + #[pure] + // FIXME: This is needed because this function is special cased only in the + // pure encoder and not in the impure one. + #[ensures(result == self.is_null())] + fn is_null(self) -> bool; +} + +type MutPointer = *mut T; +#[extern_spec] +impl MutPointer { + #[trusted] + #[terminates] + #[pure] + // FIXME: This is needed because this function is special cased only in the + // pure encoder and not in the impure one. + #[ensures(result == self.is_null())] + fn is_null(self) -> bool; +} + +#[extern_spec] +mod core { + mod mem { + #[pure] + // FIXME: This is needed because this function is special cased only in the + // pure encoder and not in the impure one. + #[ensures(result == core::mem::size_of::())] + pub fn size_of() -> usize; + + #[pure] + // FIXME: What are the guarantees? + // https://doc.rust-lang.org/std/mem/fn.align_of.html says nothing… + #[ensures(result > 0)] + // FIXME: This is needed because this function is special cased only in the + // pure encoder and not in the impure one. + #[ensures(result == core::mem::align_of::())] + pub fn align_of() -> usize; + } +} diff --git a/prusti-contracts/prusti-contracts/src/lib.rs b/prusti-contracts/prusti-contracts/src/lib.rs index 9c299eeb7ef..5f9bccffd28 100644 --- a/prusti-contracts/prusti-contracts/src/lib.rs +++ b/prusti-contracts/prusti-contracts/src/lib.rs @@ -21,6 +21,10 @@ pub use prusti_contracts_proc_macros::trusted; /// A macro for type invariants. pub use prusti_contracts_proc_macros::invariant; +/// A macro for structural type invariants. A type with a structural +/// invariant needs to be managed manually by the user. +pub use prusti_contracts_proc_macros::structural_invariant; + /// A macro for writing a loop body invariant. pub use prusti_contracts_proc_macros::body_invariant; @@ -60,6 +64,70 @@ pub use prusti_contracts_proc_macros::terminates; /// A macro to annotate body variant of a loop to prove termination pub use prusti_contracts_proc_macros::body_variant; +/// A macro to mark the place as manually managed. +pub use prusti_contracts_proc_macros::manually_manage; + +/// A macro to manually pack a place capability. +pub use prusti_contracts_proc_macros::pack; + +/// A macro to manually unpack a place capability. +pub use prusti_contracts_proc_macros::unpack; + +/// A macro to manually pack a place capability. +pub use prusti_contracts_proc_macros::pack_ref; + +/// A macro to manually unpack a place capability. +pub use prusti_contracts_proc_macros::unpack_ref; + +/// A macro to manually pack a place capability. +pub use prusti_contracts_proc_macros::pack_mut_ref; + +/// A macro to manually unpack a place capability. +pub use prusti_contracts_proc_macros::unpack_mut_ref; + +/// A macro to obtain a lifetime of a variable. +pub use prusti_contracts_proc_macros::take_lifetime; + +/// A macro to manually join a place capability. +pub use prusti_contracts_proc_macros::join; + +/// A macro to manually join a range of memory blocks into one. +pub use prusti_contracts_proc_macros::join_range; + +/// A macro to manually split a place capability. +pub use prusti_contracts_proc_macros::split; + +/// A macro to manually split a memory block into a range of memory blocks. +pub use prusti_contracts_proc_macros::split_range; + +/// A macro to stash away a range of own capabilities to get access to +/// underlying raw memory. +pub use prusti_contracts_proc_macros::stash_range; + +/// A macro to restore the stash away a range of own capabilities. +pub use prusti_contracts_proc_macros::restore_stash_range; + +/// A macro to manually close a reference. +pub use prusti_contracts_proc_macros::close_ref; + +/// A macro to manually open a reference. +pub use prusti_contracts_proc_macros::open_ref; + +/// A macro to manually close a reference. +pub use prusti_contracts_proc_macros::close_mut_ref; + +/// A macro to manually open a reference. +pub use prusti_contracts_proc_macros::open_mut_ref; + +/// A macro to forget that a place is initialized. +pub use prusti_contracts_proc_macros::forget_initialization; + +/// A macro to restore a place capability. +pub use prusti_contracts_proc_macros::restore; + +/// A macro to set a specific field of the union as active. +pub use prusti_contracts_proc_macros::set_union_active_field; + #[cfg(not(feature = "prusti"))] mod private { use core::marker::PhantomData; @@ -103,6 +171,11 @@ mod private { pub struct Ghost { _phantom: PhantomData, } + + /// A type allowing to refer to a lifetime in places where Rust syntax does + /// not allow it. It should not be possible to construct from Rust code, + /// hence the private unit inside. + pub struct Lifetime(()); } #[cfg(feature = "prusti")] @@ -115,15 +188,21 @@ mod private { /// A macro for defining a closure with a specification. pub use prusti_contracts_proc_macros::{closure, pure, trusted}; - pub fn prusti_set_union_active_field(_arg: T) { - unreachable!(); - } + // pub fn prusti_set_union_active_field(_arg: T) { + // unreachable!(); + // } #[pure] pub fn prusti_terminates_trusted() -> Int { Int::new(1) } + /// A type allowing to refer to a lifetime in places where Rust syntax does + /// not allow it. It should not be possible to construct from Rust code, + /// hence the private unit inside. + #[derive(Copy, Clone)] + pub struct Lifetime(()); + /// a mathematical (unbounded) integer type /// it should not be constructed from running rust code, hence the private unit inside #[derive(Copy, Clone, PartialEq, Eq)] @@ -338,4 +417,224 @@ pub fn snapshot_equality(_l: T, _r: T) -> bool { true } +#[doc(hidden)] +#[trusted] +pub fn prusti_manually_manage(_arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_pack_place(_arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_unpack_place(_arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_pack_ref_place(_lifetime_name: &'static str, _arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_unpack_ref_place(_lifetime_name: &'static str, _arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_pack__mut_ref_place(_lifetime_name: &'static str, _arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_unpack_mut_ref_place(_lifetime_name: &'static str, _arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_take_lifetime(_arg: T, _lifetime_name: &'static str) -> Lifetime { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_join_place(_arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_join_range(_arg: T, _start_index: usize, _end_index: usize) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_split_place(_arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_split_range(_arg: T, _start_index: usize, _end_index: usize) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_stash_range( + _arg: T, + _start_index: usize, + _end_index: usize, + _witness: &'static str, +) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_restore_stash_range(_arg: T, _new_start_index: usize, _witness: &'static str) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_close_ref_place(_witness: &'static str) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_open_ref_place(_lifetime: &'static str, _arg: T, _witness: &'static str) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_close_mut_ref_place(_witness: &'static str) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_open_mut_ref_place(_lifetime: &'static str, _arg: T, _witness: &'static str) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_forget_initialization(_arg: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_restore_place(_arg1: T, _arg2: T) { + unreachable!(); +} + +#[doc(hidden)] +#[trusted] +pub fn prusti_set_union_active_field(_arg: T) { + unreachable!(); +} + +/// Indicates that we have the `own` capability to the specified place. +#[doc(hidden)] +#[trusted] +pub fn prusti_own(_place: T) -> bool { + unreachable!(); +} + +#[macro_export] +macro_rules! own { + ($place:expr) => { + $crate::prusti_own(unsafe { core::ptr::addr_of!($place) }) + }; +} + +/// Indicates that we have the `own` capability to the specified range. +#[doc(hidden)] +#[trusted] +pub fn prusti_own_range(_address: T, _start: usize, _end: usize) -> bool { + unreachable!(); +} + +#[macro_export] +macro_rules! own_range { + ($address:expr, $start:expr, $end:expr) => { + $crate::prusti_own_range(unsafe { core::ptr::addr_of!($address) }, $start, $end) + }; +} + +/// Indicates that we have the `raw` capability to the specified address. +#[doc(hidden)] +#[trusted] +pub fn prusti_raw(_address: T, _size: usize) -> bool { + true +} + +#[macro_export] +macro_rules! raw { + ($place:expr, $size: expr) => { + $crate::prusti_raw(unsafe { core::ptr::addr_of!($place) }, $size) + }; +} + +/// Indicates that we have the `raw` capability to the specified range. +#[doc(hidden)] +#[trusted] +pub fn prusti_raw_range(_address: T, _size: usize, _start: usize, _end: usize) -> bool { + unreachable!(); +} + +#[macro_export] +macro_rules! raw_range { + ($address:expr, $size:expr, $start:expr, $end:expr) => { + $crate::prusti_raw_range( + unsafe { core::ptr::addr_of!($address) }, + $size, + $start, + $end, + ) + }; +} + +/// Indicates that we have the capability to deallocate. +#[doc(hidden)] +#[trusted] +pub fn prusti_raw_dealloc(_address: T, _size: usize) -> bool { + true +} + +#[macro_export] +macro_rules! raw_dealloc { + ($place:expr, $size: expr, $align: expr) => { + $crate::prusti_raw_dealloc(unsafe { core::ptr::addr_of!($place) }, $size) + }; +} + +/// Temporarily unpacks the owned predicate at the given location. +#[doc(hidden)] +#[trusted] +pub fn prusti_unpacking(_place: T, _body: U) -> U { + unimplemented!() +} + +#[macro_export] +macro_rules! unpacking { + ($place:expr, $body: expr) => { + $crate::prusti_unpacking(unsafe { core::ptr::addr_of!($place) }, $body) + }; +} + pub use private::*; diff --git a/prusti-contracts/prusti-specs/src/lib.rs b/prusti-contracts/prusti-specs/src/lib.rs index 80178b82247..5d44e96ce55 100644 --- a/prusti-contracts/prusti-specs/src/lib.rs +++ b/prusti-contracts/prusti-specs/src/lib.rs @@ -683,7 +683,7 @@ pub fn trusted(attr: TokenStream, tokens: TokenStream) -> TokenStream { } } -pub fn invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream { +pub fn invariant(attr: TokenStream, tokens: TokenStream, is_structural: bool) -> TokenStream { let mut rewriter = rewriter::AstRewriter::new(); let spec_id = rewriter.generate_spec_id(); let spec_id_str = spec_id.to_string(); @@ -691,19 +691,33 @@ pub fn invariant(attr: TokenStream, tokens: TokenStream) -> TokenStream { let item: syn::DeriveInput = handle_result!(syn::parse2(tokens)); let item_span = item.span(); let item_ident = item.ident.clone(); + let item_name_structural = if is_structural { + "structural" + } else { + "non_structural" + }; let item_name = syn::Ident::new( - &format!("prusti_invariant_item_{item_ident}_{spec_id}"), + &format!( + "prusti_invariant_item_{}_{}_{}", + item_name_structural, item_ident, spec_id + ), item_span, ); let attr = handle_result!(parse_prusti(attr)); + let is_structural_tokens = if is_structural { + quote_spanned!(item_span => #[prusti::type_invariant_structural]) + } else { + quote_spanned!(item_span => #[prusti::type_invariant_non_structural]) + }; // TODO: move some of this to AstRewriter? // see AstRewriter::generate_spec_item_fn for explanation of syntax below let spec_item: syn::ItemFn = parse_quote_spanned! {item_span=> #[allow(unused_must_use, unused_parens, unused_variables, dead_code, non_snake_case)] #[prusti::spec_only] #[prusti::type_invariant_spec] + #is_structural_tokens #[prusti::spec_id = #spec_id_str] fn #item_name(self) -> bool { !!((#attr) : bool) @@ -1090,3 +1104,354 @@ pub fn ghost(tokens: TokenStream) -> TokenStream { syn_errors } } + +pub fn manually_manage(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_manually_manage}) +} + +pub fn pack(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_pack_place}) +} + +pub fn unpack(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_unpack_place}) +} + +pub fn pack_ref(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_pack_ref_place}) +} + +pub fn unpack_ref(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_unpack_ref_place}) +} + +pub fn pack_mut_ref(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_pack_mut_ref_place}) +} + +pub fn unpack_mut_ref(tokens: TokenStream) -> TokenStream { + // generate_place_function(tokens, quote!{prusti_unpack_mut_ref_place}) + let (lifetime_name, reference) = + handle_result!(parse_two_expressions::(tokens)); + let lifetime_name_str = handle_result!(expression_to_string(&lifetime_name)); + unsafe_spec_function_call(quote! { + prusti_unpack_mut_ref_place(#lifetime_name_str, std::ptr::addr_of!(#reference)) + }) +} + +fn parse_two_expressions( + tokens: TokenStream, +) -> syn::Result<(syn::Expr, syn::Expr)> { + let parser = syn::punctuated::Punctuated::::parse_terminated; + let mut expressions = syn::parse::Parser::parse2(parser, tokens)?; + let second = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected two expressions"))?; + let first = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected two expressions"))?; + Ok((first.into_value(), second.into_value())) +} + +fn parse_three_expressions( + tokens: TokenStream, +) -> syn::Result<(syn::Expr, syn::Expr, syn::Expr)> { + let parser = syn::punctuated::Punctuated::::parse_terminated; + let mut expressions = syn::parse::Parser::parse2(parser, tokens)?; + let third = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected three expressions"))?; + let second = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected three expressions"))?; + let first = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected three expressions"))?; + Ok((first.into_value(), second.into_value(), third.into_value())) +} + +fn parse_four_expressions( + tokens: TokenStream, +) -> syn::Result<(syn::Expr, syn::Expr, syn::Expr, syn::Expr)> { + let parser = syn::punctuated::Punctuated::::parse_terminated; + let mut expressions = syn::parse::Parser::parse2(parser, tokens)?; + let fourth = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected four expressions"))?; + let third = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected four expressions"))?; + let second = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected four expressions"))?; + let first = expressions + .pop() + .ok_or_else(|| syn::Error::new(Span::call_site(), "Expected four expressions"))?; + Ok(( + first.into_value(), + second.into_value(), + third.into_value(), + fourth.into_value(), + )) +} + +fn expression_to_string(expr: &syn::Expr) -> syn::Result { + if let syn::Expr::Path(syn::ExprPath { + qself: None, path, .. + }) = expr + { + if let Some(ident) = path.get_ident() { + return Ok(ident.to_string()); + } + } + Err(syn::Error::new(expr.span(), "needs to be an identifier")) +} + +pub fn unsafe_spec_function_call(call: TokenStream) -> TokenStream { + let callsite_span = Span::call_site(); + quote_spanned! { callsite_span => + #[allow(unused_must_use, unused_variables)] + #[prusti::specs_version = #SPECS_VERSION] + if false { + #[prusti::spec_only] + || -> bool { true }; + unsafe { #call }; + } + } +} + +pub fn take_lifetime(tokens: TokenStream) -> TokenStream { + let (reference, lifetime_name) = + handle_result!(parse_two_expressions::(tokens)); + let lifetime_name_str = handle_result!(expression_to_string(&lifetime_name)); + unsafe_spec_function_call(quote! { + prusti_take_lifetime(std::ptr::addr_of!(#reference), #lifetime_name_str) + }) + // let parser = syn::punctuated::Punctuated::]>::parse_terminated; + // let mut args = handle_result!(syn::parse::Parser::parse2(parser, tokens)); + // let lifetime = if let Some(lifetime) = args.pop() { + // lifetime.into_value() + // } else { + // return syn::Error::new( + // args.span(), + // "`take_lifetime!` needs to contain two arguments `` and ``" + // ).to_compile_error(); + // }; + // let lifetime_str = if let syn::Expr::Path(syn::ExprPath { qself: None, path, ..}) = lifetime { + // if let Some(ident) = path.get_ident() { + // ident.to_string() + // } else { + // return syn::Error::new( + // path.span(), + // "lifetime name needs to be an identifier" + // ).to_compile_error(); + // } + // } else { + // return syn::Error::new( + // lifetime.span(), + // "lifetime name needs to be an identifier" + // ).to_compile_error(); + // }; + // let reference = if let Some(reference) = args.pop() { + // reference.into_value() + // } else { + // return syn::Error::new( + // args.span(), + // "`take_lifetime!` needs to contain two arguments `` and ``" + // ).to_compile_error(); + // }; + // let callsite_span = Span::call_site(); + // quote_spanned! { callsite_span => + // #[allow(unused_must_use, unused_variables)] + // #[prusti::specs_version = #SPECS_VERSION] + // if false { + // #[prusti::spec_only] + // || -> bool { true }; + // unsafe { prusti_take_lifetime(std::ptr::addr_of!(#reference), #lifetime_str) }; + // } + // } +} + +pub fn join(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_join_place}) +} + +pub fn join_range(tokens: TokenStream) -> TokenStream { + let (pointer, start_index, end_index) = + handle_result!(parse_three_expressions::(tokens)); + unsafe_spec_function_call(quote! { + prusti_join_range(std::ptr::addr_of!(#pointer), {#start_index}, #end_index) + }) +} + +pub fn split(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_split_place}) +} + +pub fn split_range(tokens: TokenStream) -> TokenStream { + let (pointer, start_index, end_index) = + handle_result!(parse_three_expressions::(tokens)); + unsafe_spec_function_call(quote! { + prusti_split_range(std::ptr::addr_of!(#pointer), {#start_index}, #end_index) + }) +} + +/// FIXME: For `start_index` and `end_index`, we should do the same as for +/// `body_invariant!`. +pub fn stash_range(tokens: TokenStream) -> TokenStream { + let (pointer, start_index, end_index, witness) = + handle_result!(parse_four_expressions::(tokens)); + let witness_str = handle_result!(expression_to_string(&witness)); + unsafe_spec_function_call(quote! { + prusti_stash_range(std::ptr::addr_of!(#pointer), {#start_index}, {#end_index}, #witness_str) + }) +} + +/// FIXME: For `new_start_index`, we should do the same as for +/// `body_invariant!`. +pub fn restore_stash_range(tokens: TokenStream) -> TokenStream { + let (pointer, new_start_index, witness) = + handle_result!(parse_three_expressions::(tokens)); + let witness_str = handle_result!(expression_to_string(&witness)); + unsafe_spec_function_call(quote! { + prusti_restore_stash_range(std::ptr::addr_of!(#pointer), {#new_start_index}, #witness_str) + }) +} + +pub fn close_ref(tokens: TokenStream) -> TokenStream { + // generate_place_function(tokens, quote!{prusti_close_ref_place}); + let witness: syn::Ident = handle_result!(syn::parse2(tokens)); + let witness_str = witness.to_string(); + // let callsite_span = Span::call_site(); + unsafe_spec_function_call(quote! { prusti_close_ref_place(#witness_str) }) + // quote_spanned! { callsite_span => + // #[allow(unused_must_use, unused_variables)] + // #[prusti::specs_version = #SPECS_VERSION] + // if false { + // #[prusti::spec_only] + // || -> bool { true }; + // unsafe { prusti_close_ref_place(#witness_str) }; + // } + // } +} + +pub fn open_ref(tokens: TokenStream) -> TokenStream { + let (lifetime_name, reference, witness) = + handle_result!(parse_three_expressions::(tokens)); + let lifetime_name_str = handle_result!(expression_to_string(&lifetime_name)); + let witness_str = handle_result!(expression_to_string(&witness)); + unsafe_spec_function_call(quote! { + prusti_open_ref_place(#lifetime_name_str, std::ptr::addr_of!(#reference), #witness_str) + }) + // let parser = syn::punctuated::Punctuated::]>::parse_terminated; + // let mut args = handle_result!(syn::parse::Parser::parse2(parser, tokens)); + // let witness = if let Some(witness) = args.pop() { + // witness.into_value() + // } else { + // return syn::Error::new( + // args.span(), + // "`open_ref!` needs to contain two arguments `` and ``" + // ).to_compile_error(); + // }; + // let witness_str = if let syn::Expr::Path(syn::ExprPath { qself: None, path, ..}) = witness { + // if let Some(ident) = path.get_ident() { + // ident.to_string() + // } else { + // return syn::Error::new( + // path.span(), + // "witness needs to be an identifier" + // ).to_compile_error(); + // } + // } else { + // return syn::Error::new( + // witness.span(), + // "witness needs to be an identifier" + // ).to_compile_error(); + // }; + // let reference = if let Some(reference) = args.pop() { + // reference.into_value() + // } else { + // return syn::Error::new( + // args.span(), + // "`open_ref!` needs to contain two arguments `` and ``" + // ).to_compile_error(); + // }; + // let callsite_span = Span::call_site(); + // quote_spanned! { callsite_span => + // #[allow(unused_must_use, unused_variables)] + // #[prusti::specs_version = #SPECS_VERSION] + // if false { + // #[prusti::spec_only] + // || -> bool { true }; + // unsafe { prusti_open_ref_place(std::ptr::addr_of!(#reference), #witness_str) }; + // } + // } +} + +pub fn close_mut_ref(tokens: TokenStream) -> TokenStream { + let witness: syn::Ident = handle_result!(syn::parse2(tokens)); + let witness_str = witness.to_string(); + unsafe_spec_function_call(quote! { prusti_close_mut_ref_place(#witness_str) }) +} + +pub fn open_mut_ref(tokens: TokenStream) -> TokenStream { + let (lifetime_name, reference, witness) = + handle_result!(parse_three_expressions::(tokens)); + let lifetime_name_str = handle_result!(expression_to_string(&lifetime_name)); + let witness_str = handle_result!(expression_to_string(&witness)); + unsafe_spec_function_call(quote! { + prusti_open_mut_ref_place(#lifetime_name_str, std::ptr::addr_of!(#reference), #witness_str) + }) +} + +pub fn set_union_active_field(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_set_union_active_field}) +} + +pub fn forget_initialization(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_forget_initialization}) +} + +fn generate_place_function(tokens: TokenStream, function: TokenStream) -> TokenStream { + let callsite_span = Span::call_site(); + quote_spanned! { callsite_span => + #[allow(unused_must_use, unused_variables)] + #[prusti::specs_version = #SPECS_VERSION] + if false { + #[prusti::spec_only] + || -> bool { true }; + unsafe { #function(std::ptr::addr_of!(#tokens)) }; + } + } +} + +pub fn restore(tokens: TokenStream) -> TokenStream { + let parser = syn::punctuated::Punctuated::::parse_terminated; + let mut args = handle_result!(syn::parse::Parser::parse2(parser, tokens)); + let restored_place = if let Some(restored_place) = args.pop() { + restored_place.into_value() + } else { + return syn::Error::new( + args.span(), + "`restore!` needs to contain two arguments `` and ``" + ).to_compile_error(); + }; + let borrowing_place = if let Some(borrowing_place) = args.pop() { + borrowing_place.into_value() + } else { + return syn::Error::new( + args.span(), + "`restore!` needs to contain two arguments `` and ``" + ).to_compile_error(); + }; + let callsite_span = Span::call_site(); + quote_spanned! { callsite_span => + #[allow(unused_must_use, unused_variables)] + #[prusti::specs_version = #SPECS_VERSION] + if false { + #[prusti::spec_only] + || -> bool { true }; + unsafe { prusti_restore_place(std::ptr::addr_of!(#borrowing_place), std::ptr::addr_of!(#restored_place)) }; + } + } +} diff --git a/prusti-interface/src/environment/body.rs b/prusti-interface/src/environment/body.rs index 43c24eaea55..73f8b0fef00 100644 --- a/prusti-interface/src/environment/body.rs +++ b/prusti-interface/src/environment/body.rs @@ -182,12 +182,27 @@ impl<'tcx> EnvBody<'tcx> { /// Get the MIR body of a local impure function, without any substitutions. pub fn get_impure_fn_body_identity(&self, def_id: LocalDefId) -> MirBody<'tcx> { - let mut impure = self.local_impure_fns.borrow_mut(); - impure - .entry(def_id) - .or_insert_with(|| Self::load_local_mir_with_facts(self.tcx, def_id)) - .body - .clone() + // let mut impure = self.local_impure_fns.borrow_mut(); + // impure + // .entry(def_id) + // .or_insert_with(|| Self::load_local_mir_with_facts(self.tcx, def_id)) + // .body + // .clone() + self.borrow_impure_fn_body_identity(def_id).clone() + } + + /// Borrow the MIR body of a local impure function, without any substitutions. + pub fn borrow_impure_fn_body_identity( + &self, + def_id: LocalDefId, + ) -> std::cell::RefMut> { + let impure = self.local_impure_fns.borrow_mut(); + std::cell::RefMut::map(impure, |impure| { + &mut impure + .entry(def_id) + .or_insert_with(|| Self::load_local_mir_with_facts(self.tcx, def_id)) + .body + }) } /// Get the MIR body of a local impure function, monomorphised diff --git a/prusti-interface/src/environment/mod.rs b/prusti-interface/src/environment/mod.rs index eedb5a0130d..d6495a05efa 100644 --- a/prusti-interface/src/environment/mod.rs +++ b/prusti-interface/src/environment/mod.rs @@ -140,21 +140,19 @@ impl<'tcx> Environment<'tcx> { called_def_id: ProcedureDefId, call_substs: SubstsRef<'tcx>, ) -> bool { - if called_def_id == caller_def_id { - true - } else { + if called_def_id != caller_def_id && called_def_id.is_local() { let param_env = self.tcx().param_env(caller_def_id); if let Some(instance) = self .tcx() .resolve_instance(param_env.and((called_def_id, call_substs))) .unwrap() { - self.tcx() - .mir_callgraph_reachable((instance, caller_def_id.expect_local())) - } else { - true + return self + .tcx() + .mir_callgraph_reachable((instance, caller_def_id.expect_local())); } } + true } /// Get the current version of the `prusti` crate diff --git a/prusti-interface/src/lib.rs b/prusti-interface/src/lib.rs index d2f8fd5c803..b8fdbdeec58 100644 --- a/prusti-interface/src/lib.rs +++ b/prusti-interface/src/lib.rs @@ -8,6 +8,7 @@ #![deny(unused_must_use)] #![deny(unsafe_op_in_unsafe_fn)] +#![allow(clippy::nonminimal_bool)] #![feature(rustc_private)] #![feature(box_syntax)] #![feature(box_patterns)] diff --git a/prusti-interface/src/specs/mod.rs b/prusti-interface/src/specs/mod.rs index 4362b448a3a..48a0a7406a6 100644 --- a/prusti-interface/src/specs/mod.rs +++ b/prusti-interface/src/specs/mod.rs @@ -57,6 +57,7 @@ impl From<&ProcedureSpecRefs> for ProcedureSpecificationKind { #[derive(Debug, Default)] struct TypeSpecRefs { invariants: Vec, + structural_invariants: Vec, trusted: bool, model: Option<(String, LocalDefId)>, countexample_print: Vec<(Option, LocalDefId)>, @@ -155,7 +156,7 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { ))); } SpecIdRef::Terminates(spec_id) => { - spec.set_terminates(*self.spec_functions.get(spec_id).unwrap()); + spec.set_terminates(self.spec_functions.get(spec_id).unwrap().to_def_id()); } } } @@ -218,7 +219,9 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { fn determine_type_specs(&self, def_spec: &mut typed::DefSpecificationMap) { for (type_id, refs) in self.type_specs.iter() { - if !refs.invariants.is_empty() && !prusti_common::config::enable_type_invariants() { + if !(refs.invariants.is_empty() && refs.structural_invariants.is_empty()) + && !prusti_common::config::enable_type_invariants() + { let span = self.env.query.get_def_span(*type_id); PrustiError::unsupported( "Type invariants need to be enabled with the feature flag `enable_type_invariants`", @@ -238,6 +241,13 @@ impl<'a, 'tcx> SpecCollector<'a, 'tcx> { .map(LocalDefId::to_def_id) .collect(), ), + structural_invariant: SpecificationItem::Inherent( + refs.structural_invariants + .clone() + .into_iter() + .map(LocalDefId::to_def_id) + .collect(), + ), trusted: SpecificationItem::Inherent(refs.trusted), model: refs.model.clone(), counterexample_print: refs.countexample_print.clone(), @@ -437,11 +447,20 @@ impl<'a, 'tcx> intravisit::Visitor<'tcx> for SpecCollector<'a, 'tcx> { let hir = self.env.query.hir(); let impl_id = hir.parent_id(hir.parent_id(self_id)); let type_id = get_type_id_from_impl_node(hir.get(impl_id)).unwrap(); - self.type_specs - .entry(type_id.as_local().unwrap()) - .or_default() - .invariants - .push(local_id); + if has_prusti_attr(attrs, "type_invariant_structural") { + self.type_specs + .entry(type_id.as_local().unwrap()) + .or_default() + .structural_invariants + .push(local_id); + } else { + assert!(has_prusti_attr(attrs, "type_invariant_non_structural")); + self.type_specs + .entry(type_id.as_local().unwrap()) + .or_default() + .invariants + .push(local_id); + } } // Collect trusted type flag diff --git a/prusti-interface/src/specs/typed.rs b/prusti-interface/src/specs/typed.rs index e12290812db..01ecb690ab7 100644 --- a/prusti-interface/src/specs/typed.rs +++ b/prusti-interface/src/specs/typed.rs @@ -79,7 +79,7 @@ impl DefSpecificationMap { specs.extend(posts); } if let Some(Some(term)) = spec.terminates.extract_with_selective_replacement() { - specs.push(term.to_def_id()); + specs.push(*term); } if let Some(pledges) = spec.pledges.extract_with_selective_replacement() { specs.extend(pledges.iter().filter_map(|pledge| pledge.lhs)); @@ -102,6 +102,12 @@ impl DefSpecificationMap { if let Some(invariants) = spec.invariant.extract_with_selective_replacement() { specs.extend(invariants); } + if let Some(invariants) = spec + .structural_invariant + .extract_with_selective_replacement() + { + specs.extend(invariants); + } } (specs, pure_fns, predicates) } @@ -202,7 +208,7 @@ pub struct ProcedureSpecification { pub posts: SpecificationItem>, pub pledges: SpecificationItem>, pub trusted: SpecificationItem, - pub terminates: SpecificationItem>, + pub terminates: SpecificationItem>, pub purity: SpecificationItem>, // for type-conditional spec refinements } @@ -266,6 +272,7 @@ pub struct TypeSpecification { // `extern_spec` for type invs is supported it could differ. pub source: DefId, pub invariant: SpecificationItem>, + pub structural_invariant: SpecificationItem>, pub trusted: SpecificationItem, pub model: Option<(String, LocalDefId)>, pub counterexample_print: Vec<(Option, LocalDefId)>, @@ -276,6 +283,7 @@ impl TypeSpecification { TypeSpecification { source, invariant: SpecificationItem::Empty, + structural_invariant: SpecificationItem::Empty, trusted: SpecificationItem::Inherent(false), model: None, counterexample_print: vec![], @@ -472,7 +480,7 @@ impl SpecGraph { } /// Sets the termination flag for the base spec and all constrained specs. - pub fn set_terminates(&mut self, terminates: LocalDefId) { + pub fn set_terminates(&mut self, terminates: DefId) { self.base_spec.terminates.set(Some(terminates)); self.specs_with_constraints .values_mut() diff --git a/prusti-server/src/process_verification.rs b/prusti-server/src/process_verification.rs index bc227525505..31bf6cbc0da 100644 --- a/prusti-server/src/process_verification.rs +++ b/prusti-server/src/process_verification.rs @@ -99,13 +99,13 @@ pub fn process_verification_request<'v, 't: 'v>( ast_utils.with_local_frame(16, || { let viper_program = build_or_dump_viper_program(); - let program_name = request.program.get_name(); + let program_name = request.program.get_name_with_check_mode(); // Create a new verifier each time. // Workaround for https://github.com/viperproject/prusti-dev/issues/744 let mut stopwatch = Stopwatch::start("prusti-server", "verifier startup"); let mut verifier = - new_viper_verifier(program_name, verification_context, request.backend_config); + new_viper_verifier(&program_name, verification_context, request.backend_config); stopwatch.start_next("verification"); let mut result = verifier.verify(viper_program); diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/custom_heap_encoding/simple.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/custom_heap_encoding/simple.rs new file mode 100644 index 00000000000..0afc4a2d794 --- /dev/null +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/custom_heap_encoding/simple.rs @@ -0,0 +1,17 @@ +// compile-flags: -Punsafe_core_proof=true + +use prusti_contracts::*; + +unsafe fn test_assert1() { + let a = 5; + assert!(a == 5); +} + +unsafe fn test_assert2() { + let a = 5; + assert!(a == 6); //~ ERROR: the asserted expression might not hold +} + +#[trusted] +fn main() {} + diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/framing.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/framing.rs new file mode 100644 index 00000000000..83c1d7dfd57 --- /dev/null +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/framing.rs @@ -0,0 +1,408 @@ +// compile-flags: -Punsafe_core_proof=true -Penable_type_invariants=true + +#![deny(unsafe_op_in_unsafe_fn)] + +use prusti_contracts::*; + +// TODO: Check only on the definition side. Add tests. + +//#[ensures(!result.is_null() ==> own!((*result).x) && unsafe { (*result).x } == 5)] +//unsafe fn test01() -> *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).x = 5; } + //} + //pair +//} + +//#[ensures(!result.is_null() ==> unsafe { (*result).x } == 5)] //~ ERROR: the place must be framed by permissions +//unsafe fn test01_non_framed() -> *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).x = 5; } + //} + //pair +//} + +//#[ensures(!result.is_null() ==> own!(*result) && unsafe { (*result).x } == 5)] +//unsafe fn test02() -> *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //pair +//} + +//#[ensures(!result.is_null() ==> own!(*result) && unsafe { (*result).x } == 5)] +//unsafe fn test02_missing_pack() -> *mut Pair { //~ ERROR: there might be insufficient permission to dereference a raw pointer + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //} + //pair +//} + +//#[ensures(!result.is_null() ==> unsafe { (*result).x } == 5)] //~ ERROR: there might be insufficient permission to dereference a raw pointer + ////^ ERROR: the postcondition might not be self-framing +//unsafe fn test02_non_framed() -> *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //pair +//} + +//#[ensures(!result.is_null() ==> own!(*result) && unsafe { (*result).x } == 5)] //~ ERROR: only unsafe functions can use permissions in their contracts +//fn test02_safe() -> *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //pair +//} + +//#[ensures(!result.is_null() ==> //~ ERROR: permission predicates can be only in positive positions + //own!(*result) && unsafe { !(*result).is_null() } ==> + //own!(**result) && unsafe { (**result).x } == 5)] +//unsafe fn test03() -> *mut *mut Pair { + //let pp = unsafe { + //alloc(std::mem::size_of::<*mut Pair>(), std::mem::align_of::<*mut Pair>()) + //}; + //let ppair = (pp as *mut *mut Pair); + //ppair +//} + +//#[ensures(!result.is_null() ==> + //own!(*result) && ( + //unsafe { !(*result).is_null() } ==> + //own!(**result) && unsafe { (**result).x } == 5))] +//unsafe fn test04() -> *mut *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pp = unsafe { + //alloc(std::mem::size_of::<*mut Pair>(), std::mem::align_of::<*mut Pair>()) + //}; + //let pair = (p as *mut Pair); + //let ppair = (pp as *mut *mut Pair); + //let mut v = 0; + //if !ppair.is_null() { + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //unsafe { *ppair = pair; } + //if !pair.is_null() { + //unpack!(**ppair); + //unsafe { v = (**ppair).x; } + //pack!(**ppair); + //} + //} + //ppair +//} + +//#[ensures(!result.is_null() ==> //~ ERROR: postcondition might not hold + //own!(*result) && ( + //unsafe { !(*result).is_null() } ==> + //own!(**result) && unsafe { (**result).x } == 6))] +//unsafe fn test04_wrong_value() -> *mut *mut Pair { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pp = unsafe { + //alloc(std::mem::size_of::<*mut Pair>(), std::mem::align_of::<*mut Pair>()) + //}; + //let pair = (p as *mut Pair); + //let ppair = (pp as *mut *mut Pair); + //let mut v = 0; + //if !ppair.is_null() { + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //unsafe { *ppair = pair; } + //if !pair.is_null() { + //unpack!(**ppair); + //unsafe { v = (**ppair).x; } + //pack!(**ppair); + //} + //} + //ppair +//} + +//#[ensures(!result.1.is_null() ==> + //own!(*result.1) && ( + //unsafe { !(*result.1).is_null() } ==> + //own!(**result.1) && unsafe { (**result.1).x } == 5 && + //result.0 == 5))] +//unsafe fn test05() -> (i32, *mut *mut Pair) { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pp = unsafe { + //alloc(std::mem::size_of::<*mut Pair>(), std::mem::align_of::<*mut Pair>()) + //}; + //let pair = (p as *mut Pair); + //let ppair = (pp as *mut *mut Pair); + //let mut v = 0; + //if !ppair.is_null() { + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //unsafe { *ppair = pair; } + //if !pair.is_null() { + //unpack!(**ppair); + //unsafe { v = (**ppair).x; } + //pack!(**ppair); + //} + //} + //(v, ppair) +//} + +//#[ensures(!result.1.is_null() ==> //~ ERROR: postcondition might not hold + //own!(*result.1) && ( + //unsafe { !(*result.1).is_null() } ==> + //own!(**result.1) && unsafe { (**result.1).x } == 6 && + //result.0 == 5))] +//unsafe fn test05_wrong_value_1() -> (i32, *mut *mut Pair) { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pp = unsafe { + //alloc(std::mem::size_of::<*mut Pair>(), std::mem::align_of::<*mut Pair>()) + //}; + //let pair = (p as *mut Pair); + //let ppair = (pp as *mut *mut Pair); + //let mut v = 0; + //if !ppair.is_null() { + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //unsafe { *ppair = pair; } + //if !pair.is_null() { + //unpack!(**ppair); + //unsafe { v = (**ppair).x; } + //pack!(**ppair); + //} + //} + //(v, ppair) +//} + +//#[ensures(!result.1.is_null() ==> //~ ERROR: postcondition might not hold + //own!(*result.1) && ( + //unsafe { !(*result.1).is_null() } ==> + //own!(**result.1) && unsafe { (**result.1).x } == 5 && + //result.0 == 6))] +//unsafe fn test05_wrong_value_2() -> (i32, *mut *mut Pair) { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pp = unsafe { + //alloc(std::mem::size_of::<*mut Pair>(), std::mem::align_of::<*mut Pair>()) + //}; + //let pair = (p as *mut Pair); + //let ppair = (pp as *mut *mut Pair); + //let mut v = 0; + //if !ppair.is_null() { + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //unsafe { *ppair = pair; } + //if !pair.is_null() { + //unpack!(**ppair); + //unsafe { v = (**ppair).x; } + //pack!(**ppair); + //} + //} + //(v, ppair) +//} + +//#[structural_invariant(!self.p.is_null() ==> own!(*self.p) && unsafe { (*self.p).x } == 5)] +//struct T6 { + //p: *mut Pair, +//} + +//fn test06(_: T6) {} + +//#[structural_invariant(!self.p.is_null() ==> unsafe { (*self.p).x } == 5)] +//struct T6MissingOwn { //~ ERROR: there might be insufficient permission to dereference a raw pointer + //p: *mut Pair, +//} + +//fn test06_missing_own(_: T6MissingOwn) {} + +//#[structural_invariant(!self.p.is_null() ==> own!(*self.p))] +#[structural_invariant(!self.p.is_null() ==> own!(*self.p) && unsafe {(*self.p).x} == 5)] +struct T4 { + p: *mut Pair, +} + +//#[ensures(!result.p.is_null() ==> unsafe { (*result.p).x } == 5)] +//unsafe fn test04() -> T4 { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //T4 { p: pair } +//} + +//#[ensures(unsafe { (*result.p).x } == 5)] +//unsafe fn test04_not_framed() -> T4 { //~ ERROR: there might be insufficient permission to dereference a raw pointer + ////^ ERROR: the postcondition might not be self-framing. + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //T4 { p: pair } +//} + +#[structural_invariant(!self.p.is_null() ==> own!((*self.p).x))] +struct T5 { + p: *mut Pair, +} + +#[ensures(!result.p.is_null() ==> unsafe { (*result.p).x } == 5)] +fn test05_safe() -> T5 { + let p = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let pair = (p as *mut Pair); + if !pair.is_null() { + split!(*pair); + //unsafe { (*pair).y = 4; } + unsafe { (*pair).x = 5; } + //pack!(*pair); + } + T5 { p: pair } +} + +//#[ensures(unsafe { (*result.p).x } == 5)] //~ ERROR: postcondition might not hold +//fn test04_safe_not_framed() -> T4 { //~ ERROR: there might be insufficient permission to dereference a raw pointer + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //T4 { p: pair } +//} + +//#[structural_invariant(!self.p.is_null() ==> own!((*self.p).x))] +//struct T2 { + //p: *mut Pair, +//} + +//#[ensures(!result.p.is_null() ==> framed!((*result.p).x, unsafe { (*result.p).x }) == 5)] +//fn test03() -> T1 { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).y = 4; } + //unsafe { (*pair).x = 5; } + //pack!(*pair); + //} + //T1 { p: pair } +//} + +//#[ensures(!result.p.is_null() ==> unsafe { (*result.p).x } == 5)] //~ ERROR: Permissions +//fn test03_non_framed() -> T1 { + //let p = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let pair = (p as *mut Pair); + //if !pair.is_null() { + //split!(*pair); + //unsafe { (*pair).x = 5; } + //} + //T1 { p: pair } +//} + +#[trusted] +#[requires(align > 0)] +#[ensures(!result.is_null() ==> ( + raw!(*result, size) && + raw_dealloc!(*result, size, align) +))] +// https://doc.rust-lang.org/alloc/alloc/fn.alloc.html +unsafe fn alloc(size: usize, align: usize) -> *mut u8 { + unimplemented!(); +} + +#[trusted] +#[requires( + raw!(*ptr, size) && + raw_dealloc!(*ptr, size, align) +)] +unsafe fn dealloc(ptr: *mut u8, size: usize, align: usize) { + unimplemented!(); +} + +struct Pair { + x: i32, + y: i32, +} + +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/framing/functions.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/framing/functions.rs new file mode 100644 index 00000000000..25deada0649 --- /dev/null +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/framing/functions.rs @@ -0,0 +1,47 @@ +// compile-flags: -Punsafe_core_proof=true -Penable_type_invariants=true + +#![deny(unsafe_op_in_unsafe_fn)] + +use prusti_contracts::*; + +#[structural_invariant(!self.p.is_null() ==> own!(*self.p))] +struct T1 { + p: *mut i32, +} + +#[pure] +fn test1_get_p(x: &T1) -> i32 { + if self.p.is_null() { + 0 + } else { + unsafe { *self.p } + } +} + + +#[trusted] +#[requires(align > 0)] +#[ensures(!result.is_null() ==> ( + raw!(*result, size) && + raw_dealloc!(*result, size, align) +))] +// https://doc.rust-lang.org/alloc/alloc/fn.alloc.html +unsafe fn alloc(size: usize, align: usize) -> *mut u8 { + unimplemented!(); +} + +#[trusted] +#[requires( + raw!(*ptr, size) && + raw_dealloc!(*ptr, size, align) +)] +unsafe fn dealloc(ptr: *mut u8, size: usize, align: usize) { + unimplemented!(); +} + +struct Pair { + x: i32, + y: i32, +} + +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/framing/simple.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/framing/simple.rs new file mode 100644 index 00000000000..cd8b3c7991e --- /dev/null +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/framing/simple.rs @@ -0,0 +1,244 @@ +// compile-flags: -Punsafe_core_proof=true -Penable_type_invariants=true + +#![deny(unsafe_op_in_unsafe_fn)] + +use prusti_contracts::*; + +#[ensures(!result.is_null() ==> unsafe { *result } == 5)] //~ ERROR: the place must be framed by permissions +fn test01_safe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + p +} + +#[ensures(!result.is_null() ==> own!(*result) && unsafe { *result } == 5)] //~ ERROR: only unsafe functions can use permissions in their contracts +fn test02_safe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + p +} + +#[ensures(!result.is_null() ==> own!(*result) && unsafe { *result } == 5)] +unsafe fn test03_unsafe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + p +} + +#[ensures(!result.is_null() ==> own!(*result) && unsafe { *result } == 6)] //~ ERROR: postcondition might not hold. +unsafe fn test04_unsafe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + p +} + +unsafe fn test05_unsafe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + unsafe { *p = 5; } //~ ERROR: the accessed memory location must be allocated and uninitialized + p +} + +#[ensures(own!(*result))] +unsafe fn test06_unsafe() -> *mut i32 { //~ ERROR: there might be insufficient permission to dereference a raw pointer + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + p +} + +unsafe fn test07_unsafe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + assert!(unsafe { *p } == 5); + } + p +} + +fn test07_safe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + assert!(unsafe { *p } == 5); + } + p +} + +fn callee() {} + +unsafe fn test08_unsafe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + callee(); + assert!(unsafe { *p } == 5); + } + p +} + +fn test08_safe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + callee(); + // Calling non-pure functions havoc the heap when in permissions + // are disabled: + assert!(unsafe { *p } == 5); //~ ERROR: the type invariant of the constructed object might not hold + //^ ERROR: the type invariant of the constructed object might not hold + } + p +} + +#[pure] +#[terminates] +fn pure_callee() {} + +unsafe fn test09_unsafe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + pure_callee(); + assert!(unsafe { *p } == 5); + } + p +} + +fn test09_safe() -> *mut i32 { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + pure_callee(); + assert!(unsafe { *p } == 5); + } + p +} + +#[ensures(!result.0.is_null() ==> unsafe { *result.0 } == 5)] //~ ERROR: the place must be framed by permissions +fn test11_safe() -> (*mut i32, *mut i32) { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + (p, p) +} + +#[ensures(!result.0.is_null() ==> own!(*result.0) && unsafe { *result.0 } == 5)] //~ ERROR: only unsafe functions can use permissions in their contracts +fn test12_safe() -> (*mut i32, *mut i32) { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + (p, p) +} + +#[ensures(!result.0.is_null() ==> own!(*result.0) && unsafe { *result.0 } == 5)] +unsafe fn test13_unsafe() -> (*mut i32, *mut i32) { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + (p, p) +} + +// Note: This works and `test14_unsafe_semantic_aliasing` fails because +// framing of unsafe function postconditions is done by Viper. +#[ensures(result.0 == result.1)] +#[ensures(!result.0.is_null() ==> own!(*result.0) && unsafe { *result.1 } == 5)] +unsafe fn test13_unsafe_semantic_aliasing() -> (*mut i32, *mut i32) { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + (p, p) +} + +#[ensures(!result.0.is_null() ==> own!(*result.0) && unsafe { *result.1 } == 5)] //~ ERROR: the postcondition might not be self-framing. +unsafe fn test14_unsafe_semantic_aliasing() -> (*mut i32, *mut i32) { + let p_alloc = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p_alloc as *mut i32); + if !p.is_null() { + unsafe { *p = 5; } + } + (p, p) +} + +#[trusted] +#[requires(align > 0)] +#[ensures(!result.is_null() ==> ( + raw!(*result, size) && + raw_dealloc!(*result, size, align) +))] +// https://doc.rust-lang.org/alloc/alloc/fn.alloc.html +unsafe fn alloc(size: usize, align: usize) -> *mut u8 { + unimplemented!(); +} + +#[trusted] +#[requires( + raw!(*ptr, size) && + raw_dealloc!(*ptr, size, align) +)] +unsafe fn dealloc(ptr: *mut u8, size: usize, align: usize) { + unimplemented!(); +} + +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/invariants.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/invariants.rs new file mode 100644 index 00000000000..8467b7b605f --- /dev/null +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/invariants.rs @@ -0,0 +1,556 @@ +// compile-flags: -Punsafe_core_proof=true -Penable_type_invariants=true +// -Pverify_specifications_with_core_proof=true +// -Puse_snapshot_parameters_in_predicates=true + +use prusti_contracts::*; + +// struct T1 { +// f: i32, +// } + +// fn test1(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// unpack!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// assert!(a.f == 4); +// a +// } + +// fn test2(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// forget_initialization!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// assert!(a.f == 4); +// a +// } + +// fn test3(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// unpack!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// assert!(a.f == 5); //~ ERROR: the asserted expression might not hold +// a +// } + +// fn test4(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// forget_initialization!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// assert!(a.f == 5); //~ ERROR: the asserted expression might not hold +// a +// } + +// fn test5(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// forget_initialization!((*b).f); +// unsafe { (*b).f = 4; } +// assert!( unsafe { (*b).f } == 4); +// pack!(*b); +// restore!(*b, a); +// a +// } + +// fn test6(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// forget_initialization!((*b).f); +// assert!( unsafe { (*b).f } == 4); //~ ERROR: the asserted expression might not hold +// //^ ERROR: the accessed place may not be allocated or initialized +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// a +// } + +// fn test7(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// assert!( unsafe { (*b).f } == 4); //~ ERROR: the asserted expression might not hold +// forget_initialization!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// a +// } + +// #[requires(b ==> own!(*p))] +// #[ensures(b ==> ((own!(*p)) && unsafe { *p } == 4))] +// unsafe fn test8(p: *mut i32, b: bool) { +// if b { +// forget_initialization!(*p); +// unsafe { *p = 4 }; +// } +// } + +// #[ensures(result.f == 4)] +// fn test9(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a.f); +// unsafe { test8(b, true); } +// restore!(*b, a.f); +// a +// } + +// #[requires(b ==> own!(*p))] +// #[ensures(b ==> ((own!(*p)) && unsafe { *p } == 5))] //~ ERROR: postcondition might not hold. +// unsafe fn test10(p: *mut i32, b: bool) { +// if b { +// forget_initialization!(*p); +// unsafe { *p = 4 }; +// } +// } + +// #[ensures(result.f == 5)] //~ ERROR: postcondition might not hold. +// fn test11(mut a: T1) -> T1 { +// let b = std::ptr::addr_of_mut!(a.f); +// unsafe { test8(b, true); } +// restore!(*b, a.f); +// a +// } + +// struct T2 { +// f: i32, +// g: i32, +// } + +// #[ensures(result.f == 4 && result.g == a.g)] +// fn test12(mut a: T2) -> T2 { +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// unpack!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// assert!(a.f == 4); +// a +// } + +// #[ensures(result.f == 5 && result.g == a.g)] +// fn test13(mut a: T2) -> T2 { //~ ERROR: postcondition might not hold. +// let b = std::ptr::addr_of_mut!(a); +// unpack!(*b); +// unpack!((*b).f); +// unsafe { (*b).f = 4; } +// pack!(*b); +// restore!(*b, a); +// assert!(a.f == 4); +// a +// } + +// #[requires(b ==> (own!(*p) && unsafe { *p } < 20))] +// #[ensures(b ==> (own!(*p) && unsafe { *p } == old(unsafe { *p }) + 1))] +// unsafe fn test14(p: *mut i32, b: bool) { +// if b { +// // FIXME: unsafe { *p += 1 }; +// let tmp = unsafe { *p }; +// forget_initialization!(*p); +// unsafe { *p = tmp + 1 }; +// } +// } + +// #[ensures(result.f == 7)] +// fn test15(mut a: T1) -> T1 { +// a.f = 6; +// let b = std::ptr::addr_of_mut!(a.f); +// unsafe { test14(b, true); } +// restore!(*b, a.f); +// a +// } + +// #[ensures(result.f == 8)] +// fn test16(mut a: T1) -> T1 { +// a.f = 6; +// let b = std::ptr::addr_of_mut!(a.f); +// unsafe { test14(b, true); } +// restore!(*b, a.f); +// a +// } + +// fn alloc_client() { +// let size = std::mem::size_of::(); +// let align = std::mem::align_of::(); +// let ptr = unsafe { alloc(size, align) }; +// if !ptr.is_null() { +// unsafe { *(ptr as *mut u16) = 42; } +// assert!(unsafe { *(ptr as *mut u16) } == 42); +// let ptr_u16 = (ptr as *mut u16); +// forget_initialization!(*ptr_u16); // FIXME: We should support (ptr as *mut u16). +// unsafe { dealloc(ptr, size, align) }; +// } +// } + +// fn alloc_client2() { +// let size = std::mem::size_of::(); +// let align = std::mem::align_of::(); +// let ptr = unsafe { alloc(size, align) }; +// if !ptr.is_null() { +// unsafe { *(ptr as *mut u16) = 42; } +// assert!(unsafe { *(ptr as *mut u16) } == 43); //~ ERROR: the asserted expression might not hold +// let ptr_u16 = (ptr as *mut u16); +// forget_initialization!(*ptr_u16); // FIXME: We should support (ptr as *mut u16). +// unsafe { dealloc(ptr, size, align) }; +// } +// } + +// fn alloc_client3() { +// let size = std::mem::size_of::(); +// let align = std::mem::align_of::(); +// let ptr = unsafe { alloc(size, align) }; +// unsafe { *(ptr as *mut u16) = 42; } //~ ERROR: the accessed memory location must be allocated and uninitialized +// } + +// #[requires(x < 5)] +// unsafe fn check_x(x: u32) {} + +// #[structural_invariant(self.x < 5)] +// struct T3 { +// x: u32, +// } + +// fn test17(a: T3) { +// unpack!(a); +// unsafe { check_x(a.x) } +// pack!(a); +// forget_initialization!(a); +// } + +// #[structural_invariant( +// !self.p1.is_null() ==> ( +// raw!(*self.p1, std::mem::size_of::()) && +// raw_dealloc!(*self.p1, std::mem::size_of::(), std::mem::align_of::()) +// ) +// )] +// #[structural_invariant( +// !self.p2.is_null() ==> ( +// own!(*self.p2) && +// raw_dealloc!(*self.p2, std::mem::size_of::(), std::mem::align_of::()) +// ) +// )] +// struct T4 { +// p1: *mut i32, +// p2: *mut i32, +// } + +// impl T4 { +// fn new() -> Self { +// let p1 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// if !p2.is_null() { +// unsafe { *(p2 as *mut i32) = 42; } +// } +// Self { p1: (p1 as *mut i32), p2: (p2 as *mut i32) } +// } +// } + +// #[structural_invariant( +// !self.p2.is_null() ==> ( +// own!(*self.p2) && +// raw_dealloc!(*self.p2, std::mem::size_of::(), std::mem::align_of::()) && +// unsafe { *self.p2 == 42 } && +// 1 == 1 && +// 2 == 2 && +// 3 == 3 && +// 4 == 4 && +// 5 == 5 && +// 6 == 6 +// ) +// )] +// struct T5 { +// p2: *mut i32, +// } + +// impl T5 { +// fn new() -> Self { +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// if !p2.is_null() { +// unsafe { *(p2 as *mut i32) = 42; } +// } +// Self { p2: (p2 as *mut i32) } +// } +// fn new_fail() -> Self { +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// if !p2.is_null() { +// unsafe { *(p2 as *mut i32) = 43; } +// } +// Self { p2: (p2 as *mut i32) } //~ ERROR: The type invariant of the constructed object might not hold +// } +// } + +#[structural_invariant( + !self.p.is_null() ==> ( + raw_dealloc!(*self.p, std::mem::size_of::(), std::mem::align_of::()) && + raw!((*self.p).x, std::mem::size_of::()) && + own!((*self.p).y) && + unsafe { (*self.p).y } == self.v + ) +)] +struct T6 { + v: i32, + p: *mut Pair, +} + +impl T6 { + #[ensures(result.v == 42)] + #[ensures( + unpacking!( + result, + !result.p.is_null() ==> + (unpacking!((*result.p).y, unsafe { (*result.p).y }) == 42) + ) + )] + fn new() -> Self { + let p2 = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p2 as *mut Pair); + if !p2.is_null() { + split!(*p); + unsafe { (*p).y = 42; } + } + Self { p, v: 42 } + } +// #[ensures(result.v == 42)] +// #[ensures( +// unpacking!( //~ ERROR: postcondition might not hold. +// result, +// !result.p.is_null() ==> +// (unpacking!((*result.p).y, unsafe { (*result.p).y }) == 43) +// ) +// )] +// fn new_fail_wrong_value() -> Self { +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// let p = (p2 as *mut Pair); +// if !p2.is_null() { +// split!(*p); +// unsafe { (*p).y = 42; } +// } +// Self { p, v: 42 } +// } +// #[ensures(result.v == 42)] +// #[ensures( +// !result.p.is_null() ==> +// (unpacking!((*result.p).y, unsafe { (*result.p).y }) == 42) +// )] +// fn new_fail_missing_outer_unpacking() -> Self { //~ ERROR: there might be insufficient permission to dereference a raw pointer +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// let p = (p2 as *mut Pair); +// if !p2.is_null() { +// split!(*p); +// unsafe { (*p).y = 42; } +// } +// Self { p, v: 42 } +// } + #[ensures(result.v == 42)] + #[ensures( + unpacking!( + result, + !result.p.is_null() ==> + (unsafe { (*result.p).y } == 42) + ) + )] + fn new_fail_missing_inner_unpacking() -> Self { + let p2 = unsafe { + alloc(std::mem::size_of::(), std::mem::align_of::()) + }; + let p = (p2 as *mut Pair); + if !p2.is_null() { + split!(*p); + unsafe { (*p).y = 42; } + } + Self { p, v: 42 } + } +// #[ensures(result.v == 42)] +// #[ensures( +// unpacking!( +// result, +// !result.p.is_null() ==> +// (unpacking!(*result.p, unsafe { (*result.p).y }) == 42) +// ) +// )] +// fn new_fail_wrong_inner_unpacking() -> Self { //~ ERROR: there might be insufficient permission to dereference a raw pointer +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// let p = (p2 as *mut Pair); +// if !p2.is_null() { +// split!(*p); +// unsafe { (*p).y = 42; } +// } +// Self { p, v: 42 } +// } + + //#[ensures(result.v == 42)] + //#[ensures(!result.p.is_null() ==> (unsafe { (*result.p).y } == 42))] //~ ERROR: there might be insufficient permission to dereference a raw pointer + //fn new_fail1() -> Self { + //let p2 = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let p = (p2 as *mut Pair); + //if !p2.is_null() { + //split!(*p); + //unsafe { (*p).y = 42; } + //} + //Self { p, v: 42 } + //} + //fn new_fail() -> Self { + //let p2 = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let p = (p2 as *mut Pair); + //if !p2.is_null() { + //split!(*p); + //unsafe { (*p).y = 43; } + //} + //Self { p, v: 42 } //~ ERROR: The type invariant of the constructed object might not hold + //} + //#[ensures(result.v == 42)] + //#[ensures((unsafe { (*result.p).y } == 42))] //~ ERROR: there might be insufficient permission to dereference a raw pointer + ////^ ERROR: postcondition might not hold + //fn new_fail2() -> Self { + //let p2 = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let p = (p2 as *mut Pair); + //if !p2.is_null() { + //split!(*p); + //unsafe { (*p).y = 42; } + //} + //Self { p, v: 42 } + //} + + //#[ensures(result.v == 42)] + //// TODO: Make sure to distinguish unpacking!((*result.p), ...) from + //// unpacking!((*result.p).y, ...) + //#[ensures((unpacking!(result.p, unsafe { (*result.p).y }) == 42))] + //fn new_fail3() -> Self { + //let p2 = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let p = (p2 as *mut Pair); + //if !p2.is_null() { + //split!(*p); + //unsafe { (*p).y = 42; } + //} + //Self { p, v: 42 } + //} + + //#[ensures(result.v == 42)] + //// TODO: Make sure to distinguish unpacking!((*result.p), ...) from + //// unpacking!((*result.p).y, ...) + //#[ensures((unpacking!((*result.p).y, unsafe { (*result.p).y }) == 42))] + //fn new_fail4() -> Self { + //let p2 = unsafe { + //alloc(std::mem::size_of::(), std::mem::align_of::()) + //}; + //let p = (p2 as *mut Pair); + //if !p2.is_null() { + //split!(*p); + //unsafe { (*p).y = 42; } + //} + //Self { p, v: 42 } + //} + // #[pure] + // fn value(&self) -> i32 { + // if self.p2.is_null() { + // 0 + // } else { + // unsafe { *self.p2 } + // } + // } +} + +// #[structural_invariant( +// !self.p2.is_null() ==> ( +// own!(*self.p2) && +// raw_dealloc!(*self.p2, std::mem::size_of::(), std::mem::align_of::()) && +// unsafe { *self.p2 } == self.v +// ) +// )] +// struct T6 { +// v: i32, +// p2: *mut i32, +// } + +// impl T6 { +// // #[ensures(result.v == 42)] +// // #[ensures(!result.p2.is_null() ==> (unsafe { *result.p2 } == 42))] +// // fn new() -> Self { +// // let p2 = unsafe { +// // alloc(std::mem::size_of::(), std::mem::align_of::()) +// // }; +// // if !p2.is_null() { +// // unsafe { *(p2 as *mut i32) = 42; } +// // } +// // Self { p2: (p2 as *mut i32), v: 42 } +// // } +// #[ensures(result.v == 42)] +// #[ensures(unsafe { *result.p2 } == 42)] //~ ERROR: Permissions +// fn new_fail() -> Self { +// let p2 = unsafe { +// alloc(std::mem::size_of::(), std::mem::align_of::()) +// }; +// if !p2.is_null() { +// unsafe { *(p2 as *mut i32) = 42; } +// } +// Self { p2: (p2 as *mut i32), v: 42 } +// } +// // #[pure] +// // fn value(&self) -> i32 { +// // if self.p2.is_null() { +// // 0 +// // } else { +// // unsafe { *self.p2 } +// // } +// // } +// } + +#[trusted] +#[requires(align > 0)] +#[ensures(!result.is_null() ==> ( + raw!(*result, size) && + raw_dealloc!(*result, size, align) +))] +// https://doc.rust-lang.org/alloc/alloc/fn.alloc.html +unsafe fn alloc(size: usize, align: usize) -> *mut u8 { + unimplemented!(); +} + +#[trusted] +#[requires( + raw!(*ptr, size) && + raw_dealloc!(*ptr, size, align) +)] +unsafe fn dealloc(ptr: *mut u8, size: usize, align: usize) { + unimplemented!(); +} + +struct Pair { + x: i32, + y: i32, +} + +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/invariants2.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/invariants2.rs new file mode 100644 index 00000000000..e6f09180adf --- /dev/null +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/invariants2.rs @@ -0,0 +1,38 @@ +// compile-flags: -Punsafe_core_proof=true -Penable_type_invariants=true -Pverify_specifications_with_core_proof=true +// +// These tests need core-proof for specs. + +use prusti_contracts::*; + +struct T1 { + f: i32, +} + +fn test01(mut a: T1, mut b: T1) { + let z = b.f; + let x = std::ptr::addr_of_mut!(a); + let y = std::ptr::addr_of_mut!(b); + unpack!(*x); + unpack!((*x).f); + unsafe { (*x).f = 4; } + pack!(*x); + restore!(*x, a); + restore!(*y, b); + assert!(a.f == 4); + assert!(z == b.f); +} + +fn test02(mut a: T1, mut b: T1) { + let z = b.f; + let x = std::ptr::addr_of_mut!(a); + let y = std::ptr::addr_of_mut!(b); + unpack!(*x); + unpack!((*x).f); + unsafe { (*x).f = 4; } + pack!(*x); + restore!(*x, a); + restore!(*y, b); + assert!(a.f == 5); //~ ERROR: the asserted expression might not hold +} + +fn main() {} diff --git a/prusti-tests/tests/verify_overflow/fail/core_proof/pointers.rs b/prusti-tests/tests/verify_overflow/fail/core_proof/pointers.rs index adfb04fd917..014c8fae2a4 100644 --- a/prusti-tests/tests/verify_overflow/fail/core_proof/pointers.rs +++ b/prusti-tests/tests/verify_overflow/fail/core_proof/pointers.rs @@ -10,18 +10,22 @@ use prusti_contracts::*; fn test1() { let a = 4u32; - let _x = std::ptr::addr_of!(a); + let x = std::ptr::addr_of!(a); + restore!(*x, a); } fn test2() { let mut a = 4u32; - let _x = std::ptr::addr_of_mut!(a); + let x = std::ptr::addr_of_mut!(a); + restore!(*x, a); } fn test3() { let a = 4u32; let x = std::ptr::addr_of!(a); + restore!(*x, a); let y = std::ptr::addr_of!(a); + restore!(*y, a); assert!(x == y); } @@ -29,7 +33,9 @@ fn test4() { let a = 4u32; let b = 4u32; let x = std::ptr::addr_of!(a); + restore!(*x, a); let y = std::ptr::addr_of!(b); + restore!(*y, b); assert!(x == y); //~ ERROR } @@ -37,7 +43,9 @@ fn test5() { let a = 4u32; let b = 4u32; let x = std::ptr::addr_of!(a); + restore!(*x, a); let y = std::ptr::addr_of!(b); + restore!(*y, b); assert!(x != y); //~ ERROR } @@ -45,7 +53,9 @@ fn test6() { let a = 4u32; let b = 4u32; let x = std::ptr::addr_of!(a); + restore!(*x, a); let y = std::ptr::addr_of!(b); + restore!(*y, b); assert!(!(x == y)); //~ ERROR } @@ -64,6 +74,8 @@ fn test7() { let x = std::ptr::addr_of!(a); let y = std::ptr::addr_of!(c.f.g); assert!(x != y); //~ ERROR + restore!(*x, a); + restore!(*y, c.f.g); } fn test8() { @@ -73,6 +85,8 @@ fn test8() { let x = std::ptr::addr_of!(a); let y = std::ptr::addr_of!(c.f.g); assert!(!(x == y)); //~ ERROR + restore!(*x, a); + restore!(*y, c.f.g); } fn test9() { @@ -82,6 +96,8 @@ fn test9() { let x = std::ptr::addr_of!(a); let y = std::ptr::addr_of!(c.f.g); assert!(x == y); //~ ERROR + restore!(*x, a); + restore!(*y, c.f.g); } fn main() {} diff --git a/prusti-tests/tests/verify_overflow/pass/core_proof/custom_heap_encoding/performance_test.rs b/prusti-tests/tests/verify_overflow/pass/core_proof/custom_heap_encoding/performance_test.rs new file mode 100644 index 00000000000..30184e1bfbd --- /dev/null +++ b/prusti-tests/tests/verify_overflow/pass/core_proof/custom_heap_encoding/performance_test.rs @@ -0,0 +1,972 @@ +// compile-flags: -Punsafe_core_proof=true -Pverification_deadline=120 + +use prusti_contracts::*; + +struct T {} + +//fn test001() { + //let a = T{}; + //let a = a; +//} + +//fn test002() { + //let a = T{}; + //let a = a; + //let a = a; +//} + +//fn test003() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test004() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test005() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test006() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test007() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test008() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test009() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test010() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test011() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test012() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test013() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test014() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test015() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test016() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test017() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test018() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test019() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test020() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test021() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test022() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test023() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test024() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test025() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} +//fn test030() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test040() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test050() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test060() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test070() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test080() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +//fn test090() { + //let a = T{}; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; + //let a = a; +//} + +fn test100() { + let a = T{}; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; + let a = a; +} + +#[trusted] +fn main() {} + diff --git a/prusti-tests/tests/verify_overflow/pass/core_proof/pointers.rs b/prusti-tests/tests/verify_overflow/pass/core_proof/pointers.rs index 842b551257f..bfcbc5b01e3 100644 --- a/prusti-tests/tests/verify_overflow/pass/core_proof/pointers.rs +++ b/prusti-tests/tests/verify_overflow/pass/core_proof/pointers.rs @@ -4,18 +4,22 @@ use prusti_contracts::*; fn test1() { let a = 4u32; - let _x = std::ptr::addr_of!(a); + let x = std::ptr::addr_of!(a); + restore!(*x, a); } fn test2() { let mut a = 4u32; - let _x = std::ptr::addr_of_mut!(a); + let x = std::ptr::addr_of_mut!(a); + restore!(*x, a); } fn test3() { let a = 4u32; let x = std::ptr::addr_of!(a); + restore!(*x, a); let y = std::ptr::addr_of!(a); + restore!(*y, a); assert!(x == y); } diff --git a/prusti-tests/tests/verify_overflow/pass/core_proof/types.rs b/prusti-tests/tests/verify_overflow/pass/core_proof/types.rs index 180a246a332..2b53cb97d63 100644 --- a/prusti-tests/tests/verify_overflow/pass/core_proof/types.rs +++ b/prusti-tests/tests/verify_overflow/pass/core_proof/types.rs @@ -1,4 +1,5 @@ // compile-flags: -Punsafe_core_proof=true -Pverify_types=true +// -Puse_snapshot_parameters_in_predicates=true use prusti_contracts::*; @@ -16,117 +17,117 @@ struct T3 { struct T4<'a> { f: &'a mut T1, - g: &'a T1, + //g: &'a T1, } struct T5<'a, 'b, 'c> { f: &'a mut T1, g: &'c mut T4<'b>, - h: &'a T1, - i: &'c T4<'b>, -} - -struct T6<'a, 'b, 'c> { - f: &'a mut T1, - g: &'a mut &'b mut T1, - h: &'a mut &'b mut &'c mut T1, - i: &'a T1, - j: &'a &'b T1, - k: &'a &'b &'c T1, -} - -struct T7<'a> { - f: &'a mut T1, - g: &'a mut &'a mut T1, - h: &'a mut &'a mut &'a mut T1, - i: &'a T1, - j: &'a &'a T1, - k: &'a &'a &'a T1, -} - -struct T8 { - f: [i32; 10], - g: [[T2; 10]; 10], - h: [[[T2; 10]; 10]; 10], -} - -enum T9 { - F([i32; 10]), - G([[T2; 10]; 10]), - H([[[T2; 10]; 10]; 10]), -} - -struct T10 { - f: [T8; 10], - g: [[T8; 10]; 10], - h: [[[T9; 10]; 10]; 10], -} - -struct T11<'a, 'b, 'c, 'd> { - f: &'a mut [&'b mut T8; 10], - g: &'a mut [&'b mut [&'c mut T9; 10]; 10], - h: &'a mut [&'b mut [&'c mut [&'d mut T9; 10]; 10]; 10], - i: &'a [&'b T8; 10], - j: &'a [&'b [&'c T9; 10]; 10], - k: &'a [&'b [&'c [&'d T9; 10]; 10]; 10], -} - -struct T12<'a> { - f: &'a mut [&'a mut T9; 10], - g: &'a mut [&'a mut [&'a mut T9; 10]; 10], - h: &'a mut [&'a mut [&'a mut [&'a mut T9; 10]; 10]; 10], - i: &'a [&'a T9; 10], - j: &'a [&'a [&'a T9; 10]; 10], - k: &'a [&'a [&'a [&'a T9; 10]; 10]; 10], -} - -struct T13 { - f: (), - g: (T1, T2, T3<(T2, T1)>), -} - -struct T14<'a, 'b, 'c> { - f: &'a mut (), - g: &'a mut (&'b mut T1, &'c mut T2), - i: &'a (), - j: &'a (&'b T1, &'c T2), -} + //h: &'a T1, + //i: &'c T4<'b>, +} + +//struct T6<'a, 'b, 'c> { + //f: &'a mut T1, + //g: &'a mut &'b mut T1, + //h: &'a mut &'b mut &'c mut T1, + //i: &'a T1, + //j: &'a &'b T1, + //k: &'a &'b &'c T1, +//} + +//struct T7<'a> { + //f: &'a mut T1, + //g: &'a mut &'a mut T1, + //h: &'a mut &'a mut &'a mut T1, + //i: &'a T1, + //j: &'a &'a T1, + //k: &'a &'a &'a T1, +//} + +//struct T8 { + //f: [i32; 10], + //g: [[T2; 10]; 10], + //h: [[[T2; 10]; 10]; 10], +//} + +//enum T9 { + //F([i32; 10]), + //G([[T2; 10]; 10]), + //H([[[T2; 10]; 10]; 10]), +//} + +//struct T10 { + //f: [T8; 10], + //g: [[T8; 10]; 10], + //h: [[[T9; 10]; 10]; 10], +//} + +//struct T11<'a, 'b, 'c, 'd> { + //f: &'a mut [&'b mut T8; 10], + //g: &'a mut [&'b mut [&'c mut T9; 10]; 10], + //h: &'a mut [&'b mut [&'c mut [&'d mut T9; 10]; 10]; 10], + //i: &'a [&'b T8; 10], + //j: &'a [&'b [&'c T9; 10]; 10], + //k: &'a [&'b [&'c [&'d T9; 10]; 10]; 10], +//} + +//struct T12<'a> { + //f: &'a mut [&'a mut T9; 10], + //g: &'a mut [&'a mut [&'a mut T9; 10]; 10], + //h: &'a mut [&'a mut [&'a mut [&'a mut T9; 10]; 10]; 10], + //i: &'a [&'a T9; 10], + //j: &'a [&'a [&'a T9; 10]; 10], + //k: &'a [&'a [&'a [&'a T9; 10]; 10]; 10], +//} + +//struct T13 { + //f: (), + //g: (T1, T2, T3<(T2, T1)>), +//} + +//struct T14<'a, 'b, 'c> { + //f: &'a mut (), + //g: &'a mut (&'b mut T1, &'c mut T2), + //i: &'a (), + //j: &'a (&'b T1, &'c T2), +//} struct T15<'a> { f: &'a mut [T1], g: &'a mut [T1; 10], - i: &'a [T1], - j: &'a [T1; 10], -} - -struct T16<'a, 'b> { - f: &'a mut [&'b mut T1], - g: &'a mut [&'b mut T1; 10], - i: &'a [&'b T1], - j: &'a [&'b T1; 10], -} - -enum T17<'a, 'b> { - Left (&'a mut [T1; 10]), - Right (&'b mut [T2; 10]), - SharedLeft (&'a [T1; 10]), - SharedRight (&'b [T2; 10]), -} - -enum T18<'a> { - Left(&'a mut [T1]), - Right([T2; 10]), -} - -union T19<'a, 'b> { - f: &'a mut [&'b mut T1], - g: &'a mut [&'b mut T1; 10], -} - -struct T20 { - f: *mut u8, - g: *mut [u8], -} + //i: &'a [T1], + //j: &'a [T1; 10], +} + +//struct T16<'a, 'b> { + //f: &'a mut [&'b mut T1], + //g: &'a mut [&'b mut T1; 10], + //i: &'a [&'b T1], + //j: &'a [&'b T1; 10], +//} + +//enum T17<'a, 'b> { + //Left (&'a mut [T1; 10]), + //Right (&'b mut [T2; 10]), + //SharedLeft (&'a [T1; 10]), + //SharedRight (&'b [T2; 10]), +//} + +//enum T18<'a> { + //Left(&'a mut [T1]), + //Right([T2; 10]), +//} + +//union T19<'a, 'b> { + //f: &'a mut [&'b mut T1], + //g: &'a mut [&'b mut T1; 10], +//} + +//struct T20 { + //f: *mut u8, + //g: *mut [u8], +//} #[trusted] fn main() {} diff --git a/prusti-utils/src/config.rs b/prusti-utils/src/config.rs index 4abf70ba6cc..84b009efa0d 100644 --- a/prusti-utils/src/config.rs +++ b/prusti-utils/src/config.rs @@ -114,6 +114,9 @@ lazy_static::lazy_static! { settings.set_default("enable_purification_optimization", false).unwrap(); // settings.set_default("enable_manual_axiomatization", false).unwrap(); settings.set_default("unsafe_core_proof", false).unwrap(); + settings.set_default("custom_heap_encoding", false).unwrap(); + settings.set_default("trace_with_symbolic_execution", false).unwrap(); + settings.set_default("purify_with_symbolic_execution", false).unwrap(); settings.set_default("verify_core_proof", true).unwrap(); settings.set_default("verify_specifications", true).unwrap(); settings.set_default("verify_types", false).unwrap(); @@ -121,6 +124,7 @@ lazy_static::lazy_static! { settings.set_default("verify_specifications_backend", "Silicon").unwrap(); settings.set_default("use_eval_axioms", true).unwrap(); settings.set_default("inline_caller_for", false).unwrap(); + settings.set_default("use_snapshot_parameters_in_predicates", false).unwrap(); settings.set_default("check_no_drops", false).unwrap(); settings.set_default("enable_type_invariants", false).unwrap(); settings.set_default("use_new_encoder", true).unwrap(); @@ -862,6 +866,34 @@ pub fn unsafe_core_proof() -> bool { read_setting("unsafe_core_proof") } +/// Use symbolic execution to split the procedure into traces that are verified +/// separately. +/// +/// **Note:** This option is taken into account only when `unsafe_core_proof` is +/// true. +pub fn trace_with_symbolic_execution() -> bool { + read_setting("trace_with_symbolic_execution") || purify_with_symbolic_execution() +} + +/// Use symbolic execution based purification. +/// +/// **Note:** This option is taken into account only when `unsafe_core_proof` is +/// true. +/// +/// **Note:** This option automatically enables +/// `trace_with_symbolic_execution`. +pub fn purify_with_symbolic_execution() -> bool { + read_setting("purify_with_symbolic_execution") +} + +/// Use custom heap encoding. +/// +/// **Note:** This option is taken into account only when `unsafe_core_proof` is +/// true. +pub fn custom_heap_encoding() -> bool { + read_setting("custom_heap_encoding") +} + /// Whether the core proof (memory safety) should be verified. /// /// **Note:** This option is taken into account only when `unsafe_core_proof` is @@ -916,6 +948,14 @@ pub fn inline_caller_for() -> bool { read_setting("inline_caller_for") } +/// Whether to make the snapshot, an explicit parameter of the predicate. +/// +/// **Note:** This option is taken into account only when `unsafe_core_proof` is +/// true. +pub fn use_snapshot_parameters_in_predicates() -> bool { + read_setting("use_snapshot_parameters_in_predicates") +} + /// When enabled, replaces calls to the drop function with `assert false`. /// /// **Note:** This option is used only for testing. diff --git a/prusti-viper/Cargo.toml b/prusti-viper/Cargo.toml index d35bca947a3..15aaca2ffea 100644 --- a/prusti-viper/Cargo.toml +++ b/prusti-viper/Cargo.toml @@ -26,6 +26,7 @@ backtrace = "0.3" rustc-hash = "1.1.0" derive_more = "0.99.16" itertools = "0.10.3" +egg = "0.9.2" [dev-dependencies] lazy_static = "1.4" diff --git a/prusti-viper/src/encoder/counterexamples/interface.rs b/prusti-viper/src/encoder/counterexamples/interface.rs index 183cd20cc55..a752c739bd1 100644 --- a/prusti-viper/src/encoder/counterexamples/interface.rs +++ b/prusti-viper/src/encoder/counterexamples/interface.rs @@ -24,7 +24,7 @@ impl MirProcedureMapping { procedure .basic_blocks .iter() - .map(|basic_block| { + .map(|(label, basic_block)| { let mut stmts = Vec::new(); for statement in &basic_block.statements { @@ -51,7 +51,7 @@ impl MirProcedureMapping { } }; BasicBlock { - label: basic_block.label.name.clone(), + label: label.name.clone(), successor, stmts, } @@ -92,8 +92,9 @@ impl<'v, 'tcx: 'v> MirProcedureMappingInterface for super::super::Encoder<'v, 't fn add_mapping(&mut self, program: &vir_low::Program) { if let Some(vir_low_procedure) = program.procedures.first() { //at the moment a counterexample is only produced for the specifications-poof - if matches!(program.check_mode, CheckMode::Specifications) - || matches!(program.check_mode, CheckMode::Both) + if program.check_mode.check_specifications() + // matches!(program.check_mode, CheckMode::Specifications) + // || matches!(program.check_mode, CheckMode::Both) { let procedure_new = self .mir_procedure_mapping diff --git a/prusti-viper/src/encoder/encoder.rs b/prusti-viper/src/encoder/encoder.rs index c1f0c4fcbed..638282b074a 100644 --- a/prusti-viper/src/encoder/encoder.rs +++ b/prusti-viper/src/encoder/encoder.rs @@ -698,6 +698,19 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { } } + /// Returns true iff `def_id` is a function that uses raw pointers. + fn is_internally_unsafe_function(&self, def_id: ProcedureDefId) -> bool { + let mir = self.env.body.borrow_impure_fn_body_identity(def_id.expect_local()); + for local_decl in &mir.local_decls { + if let prusti_rustc_interface::middle::ty::TyKind::RawPtr(_) = + local_decl.ty.kind() + { + return true; + } + } + return false; + } + pub fn process_encoding_queue(&mut self) { if let Err(error) = self.initialize() { panic!("The initialization of the encoder failed with the error: {error:?}"); @@ -715,26 +728,36 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { if config::unsafe_core_proof() { if self.env.query.is_unsafe_function(proc_def_id) { - if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::Both) { + if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::UnsafeSafety) { self.register_encoding_error(error); - debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::Both); + debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::UnsafeSafety); } } else { - if config::verify_core_proof() { - if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::CoreProof) { - self.register_encoding_error(error); - debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::CoreProof); + if config::verify_specifications_with_core_proof() || self.is_internally_unsafe_function(proc_def_id) { + if config::verify_core_proof() { + if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::MemorySafety) { + self.register_encoding_error(error); + debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::MemorySafety); + } } - } - if config::verify_specifications() { - let check_mode = if config::verify_specifications_with_core_proof() { - CheckMode::Both - } else { - CheckMode::Specifications - }; - if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, check_mode) { - self.register_encoding_error(error); - debug!("Error encoding function: {:?} {}", proc_def_id, check_mode); + if config::verify_specifications() { + if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::MemorySafetyWithFunctional) { + self.register_encoding_error(error); + debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::MemorySafetyWithFunctional); + } + } + } else { + if config::verify_core_proof() { + if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::PurificationSoudness) { + self.register_encoding_error(error); + debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::PurificationSoudness); + } + } + if config::verify_specifications() { + if let Err(error) = self.encode_lifetimes_core_proof(proc_def_id, CheckMode::PurificationFunctional) { + self.register_encoding_error(error); + debug!("Error encoding function: {:?} {}", proc_def_id, CheckMode::PurificationFunctional); + } } } } @@ -792,9 +815,9 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { } EncodingTask::Type { ty } => { if config::unsafe_core_proof() && config::verify_core_proof() && config::verify_types() { - if let Err(error) = self.encode_core_proof_for_type(ty, CheckMode::CoreProof) { + if let Err(error) = self.encode_core_proof_for_type(ty, CheckMode::MemorySafety) { self.register_encoding_error(error); - debug!("Error encoding type: {:?} {}", ty, CheckMode::CoreProof); + debug!("Error encoding type: {:?} {}", ty, CheckMode::MemorySafety); } } } diff --git a/prusti-viper/src/encoder/errors/error_manager.rs b/prusti-viper/src/encoder/errors/error_manager.rs index 9a0a6e3804e..32c7447dda6 100644 --- a/prusti-viper/src/encoder/errors/error_manager.rs +++ b/prusti-viper/src/encoder/errors/error_manager.rs @@ -37,6 +37,7 @@ pub enum PanicCause { pub enum BuiltinMethodKind { WriteConstant, MovePlace, + CopyPlace, IntoMemoryBlock, SplitMemoryBlock, JoinMemoryBlock, @@ -44,6 +45,7 @@ pub enum BuiltinMethodKind { ChangeUniqueRefPlace, DuplicateFracRef, Assign, + RestoreRawBorrowed, } /// In case of verification error, this enum will contain additional information @@ -173,6 +175,10 @@ pub enum ErrorCtxt { CloseMutRef, /// Failed to encode CloseFracRef CloseFracRef, + /// Closing a reference failed. + CloseRef, + /// Opening a reference failed. + OpenRef, /// Failed to set an active variant of an union. SetEnumVariant, /// A user assumption raised an error @@ -180,6 +186,30 @@ pub enum ErrorCtxt { /// The state that fold-unfold algorithm deduced as unreachable, is actually /// reachable. UnreachableFoldingState, + /// A user-specified pack operation failed. + Pack, + /// A user-specified unpack operation failed. + Unpack, + /// A user-specified forget-initialization operation failed. + ForgetInitialization, + /// Restore a place borrowed via raw pointer. + RestoreRawBorrowed, + /// An error in the definition of the type invariant. + TypeInvariantDefinition, + /// Pointer dereference in the postcondition is not framed by permissions. + /// + /// Note: This can also be reported when the underlying solver failing to + /// prove that the postcondition implies itself. + MethodPostconditionFraming, + /// An unexpected error when assuming false to end method postcondition + /// framing check. + UnexpectedAssumeEndMethodPostconditionFraming, + StashRange, + RestoreStashRange, + JoinRange, + SplitRange, + // /// Permission error when dereferencing a raw pointer. + // EnsureOwnedPredicate, } /// The error manager @@ -402,6 +432,7 @@ impl<'tcx> ErrorManager<'tcx> { .set_help("This might be a bug in the Rust compiler.") } + ("exhale.failed:assertion.false", ErrorCtxt::ExhaleMethodPrecondition) | ("assert.failed:assertion.false", ErrorCtxt::ExhaleMethodPrecondition) => { PrustiError::verification("precondition might not hold.", error_span) .set_failing_assertion(opt_cause_span) @@ -557,11 +588,19 @@ impl<'tcx> ErrorManager<'tcx> { .set_failing_assertion(opt_cause_span) } - ("assert.failed:assertion.false", ErrorCtxt::AssertMethodPostcondition) => { + ("assert.failed:assertion.false", ErrorCtxt::AssertMethodPostcondition) + |("exhale.failed:assertion.false", ErrorCtxt::AssertMethodPostcondition)=> { PrustiError::verification("postcondition might not hold.".to_string(), error_span) .push_primary_span(opt_cause_span) } + ("inhale.failed:insufficient.permission", ErrorCtxt::MethodPostconditionFraming) + | ("application.precondition:insufficient.permission", ErrorCtxt::MethodPostconditionFraming) => { + PrustiError::verification("the postcondition might not be self-framing.".to_string(), error_span) + .push_primary_span(opt_cause_span) + .set_help("This error might be also caused by prover failing to prove that the postcondition implies itself") + } + ( "assert.failed:assertion.false", ErrorCtxt::AssertMethodPostconditionTypeInvariants, @@ -697,6 +736,44 @@ impl<'tcx> ErrorManager<'tcx> { ) } + ("call.precondition:insufficient.permission", ErrorCtxt::CopyPlace) => { + PrustiError::verification( + "the accessed place may not be allocated or initialized".to_string(), + error_span + ).set_failing_assertion(opt_cause_span) + } + + ("call.precondition:insufficient.permission", ErrorCtxt::WritePlace) => { + PrustiError::verification( + "the accessed memory location must be allocated and uninitialized".to_string(), + error_span + ).set_failing_assertion(opt_cause_span) + } + + ("exhale.failed:insufficient.permission", ErrorCtxt::AssertMethodPostcondition) | + ("application.precondition:insufficient.permission", ErrorCtxt::AssertMethodPostcondition) | + ("application.precondition:insufficient.permission", ErrorCtxt::TypeInvariantDefinition) => { + PrustiError::verification( + "there might be insufficient permission to dereference a raw pointer".to_string(), + error_span + ).set_failing_assertion(opt_cause_span) + } + + ("call.precondition:assertion.false", ErrorCtxt::Assign | ErrorCtxt::CopyPlace) => { + PrustiError::verification( + "the type invariant of the constructed object might not hold".to_string(), + error_span + ).set_failing_assertion(opt_cause_span) + } + + ("call.precondition:insufficient.permission", ErrorCtxt::LifetimeEncoding) => { + PrustiError::verification( + "there might be insufficient permission to a lifetime token".to_string(), + error_span + ).set_failing_assertion(opt_cause_span) + .set_help("This could be caused by an unclosed reference.") + } + (full_err_id, ErrorCtxt::Unexpected) => { PrustiError::internal( format!( diff --git a/prusti-viper/src/encoder/high/lower/expression.rs b/prusti-viper/src/encoder/high/lower/expression.rs index 3f53a8d5354..3b67a118136 100644 --- a/prusti-viper/src/encoder/high/lower/expression.rs +++ b/prusti-viper/src/encoder/high/lower/expression.rs @@ -75,6 +75,12 @@ impl IntoPolymorphic for vir_high::Expression { vir_high::Expression::Downcast(expression) => { vir_poly::Expr::Downcast(expression.lower(encoder)) } + vir_high::Expression::AccPredicate(_expression) => { + todo!() + } + vir_high::Expression::Unfolding(_expression) => { + todo!() + } } } } diff --git a/prusti-viper/src/encoder/high/procedures/inference/mod.rs b/prusti-viper/src/encoder/high/procedures/inference/mod.rs index c624ab4e426..867a5fac10a 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/mod.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/mod.rs @@ -25,6 +25,7 @@ mod permission; mod semantics; mod state; mod visitor; +mod unfolding_expressions; pub(super) fn infer_shape_operations<'v, 'tcx: 'v>( encoder: &mut Encoder<'v, 'tcx>, diff --git a/prusti-viper/src/encoder/high/procedures/inference/permission.rs b/prusti-viper/src/encoder/high/procedures/inference/permission.rs index 31c88f8379f..b499a26e2f7 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/permission.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/permission.rs @@ -44,3 +44,13 @@ pub(in super::super) enum PermissionKind { MemoryBlock, Owned, } + +impl Permission { + pub(in super::super) fn place(&self) -> &vir_typed::Expression { + match self { + Permission::MemoryBlock(place) => place, + Permission::Owned(place) => place, + Permission::MutBorrowed(MutBorrowed { place, .. }) => place, + } + } +} diff --git a/prusti-viper/src/encoder/high/procedures/inference/semantics.rs b/prusti-viper/src/encoder/high/procedures/inference/semantics.rs index 2f800993b0e..45b7390f482 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/semantics.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/semantics.rs @@ -18,9 +18,30 @@ pub(in super::super) fn collect_permission_changes<'v, 'tcx>( &mut consumed_permissions, &mut produced_permissions, )?; + consumed_permissions.retain(|permission| !permission.place().is_behind_pointer_dereference()); + produced_permissions.retain(|permission| !permission.place().is_behind_pointer_dereference()); + // remove_after_pointer_deref(&mut consumed_permissions); + // remove_after_pointer_deref(&mut produced_permissions); Ok((consumed_permissions, produced_permissions)) } +// fn remove_after_pointer_deref(permissions: &mut Vec) { +// permissions.retain_mut(|permission| { +// match permission { +// Permission::MemoryBlock(place) => { +// !place.is_behind_pointer_dereference() +// } +// Permission::Owned(place) => { +// if let Some(pointer_place) = place.get_first_dereferenced_pointer() { +// *place = pointer_place.clone(); +// } +// true +// } +// Permission::MutBorrowed(_) => unreachable!(), +// } +// }); +// } + trait CollectPermissionChanges { #[allow(clippy::ptr_arg)] // Clippy false positive. fn collect<'v, 'tcx>( @@ -45,10 +66,16 @@ impl CollectPermissionChanges for vir_typed::Statement { vir_typed::Statement::OldLabel(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } - vir_typed::Statement::Inhale(statement) => { + vir_typed::Statement::InhalePredicate(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::ExhalePredicate(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } - vir_typed::Statement::Exhale(statement) => { + vir_typed::Statement::InhaleExpression(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::ExhaleExpression(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } vir_typed::Statement::Consume(statement) => { @@ -60,6 +87,9 @@ impl CollectPermissionChanges for vir_typed::Statement { vir_typed::Statement::GhostHavoc(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } + vir_typed::Statement::HeapHavoc(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } vir_typed::Statement::GhostAssign(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } @@ -93,6 +123,36 @@ impl CollectPermissionChanges for vir_typed::Statement { vir_typed::Statement::SetUnionVariant(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } + vir_typed::Statement::Pack(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::Unpack(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::Join(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::JoinRange(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::Split(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::SplitRange(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::StashRange(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::StashRangeRestore(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::ForgetInitialization(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } + vir_typed::Statement::RestoreRawBorrowed(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } vir_typed::Statement::NewLft(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } @@ -133,6 +193,17 @@ impl CollectPermissionChanges for vir_typed::Statement { } } +impl CollectPermissionChanges for vir_typed::HeapHavoc { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + impl CollectPermissionChanges for vir_typed::GhostHavoc { fn collect<'v, 'tcx>( &self, @@ -193,14 +264,17 @@ fn extract_managed_predicate_place( vir_typed::Predicate::MemoryBlockStackDrop(_) | vir_typed::Predicate::LifetimeToken(_) | vir_typed::Predicate::MemoryBlockHeap(_) - | vir_typed::Predicate::MemoryBlockHeapDrop(_) => { + | vir_typed::Predicate::MemoryBlockHeapRange(_) + | vir_typed::Predicate::MemoryBlockHeapDrop(_) + | vir_typed::Predicate::OwnedRange(_) + | vir_typed::Predicate::OwnedSet(_) => { // Unmanaged predicates. Ok(None) } } } -impl CollectPermissionChanges for vir_typed::Inhale { +impl CollectPermissionChanges for vir_typed::InhalePredicate { fn collect<'v, 'tcx>( &self, _encoder: &mut Encoder<'v, 'tcx>, @@ -212,7 +286,7 @@ impl CollectPermissionChanges for vir_typed::Inhale { } } -impl CollectPermissionChanges for vir_typed::Exhale { +impl CollectPermissionChanges for vir_typed::ExhalePredicate { fn collect<'v, 'tcx>( &self, _encoder: &mut Encoder<'v, 'tcx>, @@ -224,6 +298,28 @@ impl CollectPermissionChanges for vir_typed::Exhale { } } +impl CollectPermissionChanges for vir_typed::InhaleExpression { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::ExhaleExpression { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + impl CollectPermissionChanges for vir_typed::Consume { fn collect<'v, 'tcx>( &self, @@ -294,6 +390,9 @@ impl CollectPermissionChanges for vir_typed::CopyPlace { consumed_permissions: &mut Vec, produced_permissions: &mut Vec, ) -> SpannedEncodingResult<()> { + // if let Some(source_pointer_place) = self.source.get_first_dereferenced_pointer() { + + // } consumed_permissions.push(Permission::MemoryBlock(self.target.clone())); consumed_permissions.push(Permission::Owned(self.source.clone())); produced_permissions.push(Permission::Owned(self.target.clone())); @@ -383,6 +482,9 @@ impl CollectPermissionChanges for vir_typed::Rvalue { Self::Len(rvalue) => { rvalue.collect(encoder, consumed_permissions, produced_permissions) } + Self::Cast(rvalue) => { + rvalue.collect(encoder, consumed_permissions, produced_permissions) + } Self::UnaryOp(rvalue) => { rvalue.collect(encoder, consumed_permissions, produced_permissions) } @@ -456,7 +558,7 @@ impl CollectPermissionChanges for vir_typed::ast::rvalue::AddressOf { &self, _encoder: &mut Encoder<'v, 'tcx>, consumed_permissions: &mut Vec, - produced_permissions: &mut Vec, + _produced_permissions: &mut Vec, ) -> SpannedEncodingResult<()> { // To take an address of a place on a stack, it must not be moved out. // The following fails to compile: @@ -476,7 +578,6 @@ impl CollectPermissionChanges for vir_typed::ast::rvalue::AddressOf { // } // ``` consumed_permissions.push(Permission::Owned(self.place.clone())); - produced_permissions.push(Permission::Owned(self.place.clone())); Ok(()) } } @@ -494,6 +595,19 @@ impl CollectPermissionChanges for vir_typed::ast::rvalue::Len { } } +impl CollectPermissionChanges for vir_typed::ast::rvalue::Cast { + fn collect<'v, 'tcx>( + &self, + encoder: &mut Encoder<'v, 'tcx>, + consumed_permissions: &mut Vec, + produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + self.operand + .collect(encoder, consumed_permissions, produced_permissions)?; + Ok(()) + } +} + impl CollectPermissionChanges for vir_typed::ast::rvalue::UnaryOp { fn collect<'v, 'tcx>( &self, @@ -617,6 +731,211 @@ impl CollectPermissionChanges for vir_typed::SetUnionVariant { } } +fn add_struct_expansion( + place: &vir_typed::Expression, + fields: Vec, + permissions: &mut Vec, +) { + let position = place.position(); + for field in fields { + permissions.push(Permission::Owned(vir_typed::Expression::field( + place.clone(), + field, + position, + ))); + } +} + +impl CollectPermissionChanges for vir_typed::Pack { + fn collect<'v, 'tcx>( + &self, + encoder: &mut Encoder<'v, 'tcx>, + consumed_permissions: &mut Vec, + produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + if self.place.is_behind_pointer_dereference() { + produced_permissions.push(Permission::Owned(self.place.clone())); + } else { + let type_decl = encoder.encode_type_def_typed(self.place.get_type())?; + if let vir_typed::TypeDecl::Struct(decl) = type_decl { + if decl.is_manually_managed_type() { + produced_permissions.push(Permission::Owned(self.place.clone())); + add_struct_expansion(&self.place, decl.fields, consumed_permissions); + } else { + unimplemented!( + "Unpacking an automatically managed type: {}\n{}", + self.place, + self.place.get_type(), + ); + } + } else { + unimplemented!( + "Report a proper error message that only structs can be unfolded: {:?}", + self.place + ); + } + } + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::Unpack { + fn collect<'v, 'tcx>( + &self, + encoder: &mut Encoder<'v, 'tcx>, + consumed_permissions: &mut Vec, + produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + if self.place.is_behind_pointer_dereference() { + consumed_permissions.push(Permission::Owned(self.place.clone())); + } else { + let type_decl = encoder.encode_type_def_typed(self.place.get_type())?; + if let vir_typed::TypeDecl::Struct(decl) = type_decl { + if decl.is_manually_managed_type() { + consumed_permissions.push(Permission::Owned(self.place.clone())); + add_struct_expansion(&self.place, decl.fields, produced_permissions); + } else { + unimplemented!( + "Unpacking an automatically managed type: {}\n{}", + self.place, + self.place.get_type() + ); + } + } else { + unimplemented!( + "Report a proper error message that only structs can be unfolded: {}", + self.place + ); + } + } + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::Join { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + if !self.place.is_behind_pointer_dereference() { + unimplemented!( + "Report a proper error message that only memory blocks behind \ + a raw pointer could be joined by hand: {}", + self.place + ); + } + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::JoinRange { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::Split { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + if !self.place.is_behind_pointer_dereference() { + unimplemented!( + "Report a proper error message that only memory blocks behind \ + a raw pointer could be split by hand: {}", + self.place + ); + } + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::SplitRange { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::StashRange { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::StashRangeRestore { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::ForgetInitialization { + fn collect<'v, 'tcx>( + &self, + encoder: &mut Encoder<'v, 'tcx>, + consumed_permissions: &mut Vec, + _produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + if self.place.is_behind_pointer_dereference() { + consumed_permissions.push(Permission::Owned(self.place.clone())); + } else { + let type_decl = encoder.encode_type_def_typed(self.place.get_type())?; + if let vir_typed::TypeDecl::Struct(decl) = &type_decl { + if decl.is_manually_managed_type() { + consumed_permissions.push(Permission::Owned(self.place.clone())); + } else { + unimplemented!( + "Forgetting initialization of an automatically managed type: {:?}\n{:?}", + self.place, + type_decl + ); + } + } else { + unimplemented!( + "Report a proper error message that only structs can be unfolded: {:?}", + self.place + ); + } + } + Ok(()) + } +} + +impl CollectPermissionChanges for vir_typed::RestoreRawBorrowed { + fn collect<'v, 'tcx>( + &self, + _encoder: &mut Encoder<'v, 'tcx>, + _consumed_permissions: &mut Vec, + produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + produced_permissions.push(Permission::Owned(self.restored_place.clone())); + Ok(()) + } +} + impl CollectPermissionChanges for vir_typed::NewLft { fn collect<'v, 'tcx>( &self, diff --git a/prusti-viper/src/encoder/high/procedures/inference/state/predicate_state_on_path.rs b/prusti-viper/src/encoder/high/procedures/inference/state/predicate_state_on_path.rs index 11ec253b1ff..acef2b135f5 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/state/predicate_state_on_path.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/state/predicate_state_on_path.rs @@ -62,7 +62,7 @@ impl PredicateStateOnPath { self.memory_block_stack.is_empty() && self.owned_non_aliased.iter().all(|place| { // `UniqueRef` and `FracRef` predicates can be leaked. - place.get_dereference_base().is_some() + place.get_last_dereferenced_reference().is_some() }) } diff --git a/prusti-viper/src/encoder/high/procedures/inference/unfolding_expressions.rs b/prusti-viper/src/encoder/high/procedures/inference/unfolding_expressions.rs new file mode 100644 index 00000000000..6629204e05a --- /dev/null +++ b/prusti-viper/src/encoder/high/procedures/inference/unfolding_expressions.rs @@ -0,0 +1,194 @@ +use crate::encoder::errors::{SpannedEncodingError, SpannedEncodingResult}; +use vir_crate::{ + common::position::Positioned, + typed::{ + self as vir_typed, + operations::ty::Typed, + visitors::{ + default_fallible_fold_binary_op, default_fallible_fold_expression, + ExpressionFallibleFolder, + }, + }, +}; + +pub(super) fn add_unfolding_expressions( + expression: vir_typed::Expression, +) -> SpannedEncodingResult { + let mut ensurer = Ensurer { + syntactically_framed_places: Vec::new(), + }; + ensurer.fallible_fold_expression(expression) +} + +struct Ensurer { + syntactically_framed_places: Vec, +} + +impl Ensurer { + fn add_unfolding( + &self, + place: vir_typed::Expression, + ) -> SpannedEncodingResult { + for framing_place in &self.syntactically_framed_places { + let mut unfolding_stack = Vec::new(); + if self.add_syntactic_unfolding_rec(&place, framing_place, &mut unfolding_stack)? { + let place = self.apply_unfolding_stack(place, unfolding_stack); + return Ok(place); + } + } + let mut unfolding_stack = Vec::new(); + self.add_self_unfolding_rec(&place, &mut unfolding_stack)?; + let place = self.apply_unfolding_stack(place, unfolding_stack); + Ok(place) + } + + fn apply_unfolding_stack( + &self, + mut place: vir_typed::Expression, + unfolding_stack: Vec, + ) -> vir_typed::Expression { + for unfolded_place in unfolding_stack { + let position = place.position(); + place = vir_typed::Expression::unfolding( + vir_typed::Predicate::owned_non_aliased(unfolded_place, position), + place, + position, + ); + } + place + } + + fn add_syntactic_unfolding_rec( + &self, + place: &vir_typed::Expression, + framing_place: &vir_typed::Expression, + unfolding_stack: &mut Vec, + ) -> SpannedEncodingResult { + if place == framing_place { + return Ok(true); + } else if !place.is_deref() { + if let Some(parent) = place.get_parent_ref() { + if self.add_syntactic_unfolding_rec(framing_place, parent, unfolding_stack)? { + unfolding_stack.push(parent.clone()); + return Ok(true); + } + } + }; + Ok(false) + } + + /// Just unfold on all levels except on deref. + /// + /// FIXME: This should take into account what places are actually framed by + /// the structural invariant. For example, if the invariant contains + /// `own!((*self.p).x)` (that is, it frames only one field of the struct), + /// then we currently will generate one unfolding too many (we would + /// generate unfolding of `self.p` even though we should not). + fn add_self_unfolding_rec( + &self, + place: &vir_typed::Expression, + unfolding_stack: &mut Vec, + ) -> SpannedEncodingResult<()> { + if let Some(parent) = place.get_parent_ref() { + if !parent.get_type().is_pointer() { + unfolding_stack.push(parent.clone()); + } + self.add_self_unfolding_rec(parent, unfolding_stack)?; + } + Ok(()) + } + + // fn add_unfolding(&self, place: vir_typed::Expression) -> SpannedEncodingResult { + // for framing_place in &self.syntactically_framed_places { + // let mut unfolding_stack = Vec::new(); + // if let Some(mut new_place) = self.add_unfolding_rec(&place, framing_place, &mut unfolding_stack)? { + // eprintln!("place: {}", place); + // eprintln!("new_place: {}", new_place); + // for unfolded_place in unfolding_stack { + // eprintln!(" unfolded_place: {}", unfolded_place); + // let position =new_place.position(); + // new_place = vir_typed::Expression::unfolding( + // vir_typed::Predicate::owned_non_aliased(unfolded_place, position), + // new_place, position); + // } + // eprintln!("final_place: {}", new_place); + // return Ok(new_place); + // } + // } + // Ok(place) + // } + + // fn add_unfolding_rec(&self, place: &vir_typed::Expression, framing_place: &vir_typed::Expression, + // unfolding_stack: &mut Vec, + // ) -> SpannedEncodingResult> { + // let new_place = if place == framing_place { + // Some(place.clone()) + // } else { + // if place.is_deref() { + // None + // } else { + // if let Some(parent) = place.get_parent_ref() { + // let result = self.add_unfolding_rec(framing_place, parent, unfolding_stack)?; + // if result.is_some() { + // unfolding_stack.push(parent.clone()); + // } + // result + // } else { + // None + // } + // } + // }; + // Ok(new_place) + // } +} + +impl ExpressionFallibleFolder for Ensurer { + type Error = SpannedEncodingError; + + fn fallible_fold_expression( + &mut self, + expression: vir_typed::Expression, + ) -> Result { + if expression.is_place() && expression.get_last_dereferenced_pointer().is_some() { + self.add_unfolding(expression) + } else { + default_fallible_fold_expression(self, expression) + } + } + + fn fallible_fold_binary_op( + &mut self, + mut binary_op: vir_typed::BinaryOp, + ) -> Result { + match binary_op.op_kind { + vir_typed::BinaryOpKind::And => { + if let vir_typed::Expression::AccPredicate(acc_predicate) = &*binary_op.left { + match &*acc_predicate.predicate { + vir_typed::Predicate::LifetimeToken(_) + | vir_typed::Predicate::MemoryBlockStack(_) + | vir_typed::Predicate::MemoryBlockStackDrop(_) + | vir_typed::Predicate::MemoryBlockHeap(_) + | vir_typed::Predicate::MemoryBlockHeapRange(_) + | vir_typed::Predicate::MemoryBlockHeapDrop(_) => { + default_fallible_fold_binary_op(self, binary_op) + } + vir_typed::Predicate::OwnedNonAliased(predicate) => { + let place = predicate.place.clone(); + binary_op.left = self.fallible_fold_expression_boxed(binary_op.left)?; + self.syntactically_framed_places.push(place); + binary_op.right = + self.fallible_fold_expression_boxed(binary_op.right)?; + self.syntactically_framed_places.pop(); + Ok(binary_op) + } + vir_typed::Predicate::OwnedRange(_) => todo!(), + vir_typed::Predicate::OwnedSet(_) => todo!(), + } + } else { + default_fallible_fold_binary_op(self, binary_op) + } + } + _ => default_fallible_fold_binary_op(self, binary_op), + } + } +} diff --git a/prusti-viper/src/encoder/high/procedures/inference/visitor/context.rs b/prusti-viper/src/encoder/high/procedures/inference/visitor/context.rs index 7e3d987ed60..5be8d28f9e7 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/visitor/context.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/visitor/context.rs @@ -1,6 +1,6 @@ use super::{super::ensurer::ExpandedPermissionKind, Visitor}; use crate::encoder::{ - errors::{ErrorCtxt, SpannedEncodingResult}, + errors::{ErrorCtxt, SpannedEncodingError, SpannedEncodingResult}, high::to_typed::types::HighToTypedTypeEncoderInterface, mir::errors::ErrorInterface, }; @@ -47,14 +47,29 @@ impl<'p, 'v, 'tcx> super::super::ensurer::Context for Visitor<'p, 'v, 'tcx> { let expansion = match type_decl { vir_typed::TypeDecl::Bool | vir_typed::TypeDecl::Int(_) - | vir_typed::TypeDecl::Float(_) - | vir_typed::TypeDecl::Pointer(_) => { - // Primitive type. Convert. - vec![(ExpandedPermissionKind::MemoryBlock, place.clone())] + | vir_typed::TypeDecl::Float(_) => { + // Primitive type. + unreachable!(); + } + vir_typed::TypeDecl::Pointer(_) => { + let target_type = ty.clone().unwrap_pointer().target_type; + let deref_place = + vir_typed::Expression::deref(place.clone(), *target_type, place.position()); + vec![(ExpandedPermissionKind::Same, deref_place)] } vir_typed::TypeDecl::Trusted(_) => unimplemented!("ty: {}", ty), vir_typed::TypeDecl::TypeVar(_) => unimplemented!("ty: {}", ty), - vir_typed::TypeDecl::Struct(decl) => expand_fields(place, decl.fields.iter()), + vir_typed::TypeDecl::Struct(decl) => { + // if decl.is_manually_managed_type() { + // let place_span = self.get_span(guiding_place.position()).unwrap(); + // let error = SpannedEncodingError::incorrect( + // "types with structural invariants are required to be managed manually", + // place_span, + // ); + // return Err(error); + // } + expand_fields(place, decl.fields.iter()) + } vir_typed::TypeDecl::Enum(decl) => { let position = place.position(); let variant_name = place.get_variant_name(guiding_place); diff --git a/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs b/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs index 592dbb556c2..02e38c1bbd1 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs @@ -20,10 +20,12 @@ use prusti_rustc_interface::hir::def_id::DefId; use rustc_hash::{FxHashMap, FxHashSet}; use std::collections::{btree_map::Entry, BTreeMap}; use vir_crate::{ - common::{display::cjoin, position::Positioned}, + common::{check_mode::CheckMode, display::cjoin, position::Positioned}, middle::{ self as vir_mid, - operations::{TypedToMiddleExpression, TypedToMiddleStatement, TypedToMiddleType}, + operations::{ + ty::Typed, TypedToMiddleExpression, TypedToMiddleStatement, TypedToMiddleType, + }, }, typed::{self as vir_typed}, }; @@ -34,6 +36,7 @@ mod debugging; pub(super) struct Visitor<'p, 'v, 'tcx> { encoder: &'p mut Encoder<'v, 'tcx>, _proc_def_id: DefId, + check_mode: Option, state_at_entry: BTreeMap, /// Used only for debugging purposes. state_at_exit: BTreeMap, @@ -55,6 +58,7 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { Self { encoder, _proc_def_id: proc_def_id, + check_mode: None, state_at_entry: Default::default(), state_at_exit: Default::default(), procedure_name: None, @@ -74,6 +78,7 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { entry_state: FoldUnfoldState, ) -> SpannedEncodingResult { self.procedure_name = Some(procedure.name.clone()); + self.check_mode = Some(procedure.check_mode); let mut path_disambiguators = BTreeMap::new(); for ((from, to), value) in procedure.get_path_disambiguators() { @@ -117,6 +122,7 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { let new_procedure = vir_mid::ProcedureDecl { name: self.procedure_name.take().unwrap(), check_mode, + position: procedure.position, entry: self.entry_label.take().unwrap(), exit: self.lower_label(&procedure.exit), basic_blocks: std::mem::take(&mut self.basic_blocks), @@ -219,7 +225,7 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { self.process_actions(actions)?; state.remove_permissions(&consumed_permissions)?; state.insert_permissions(produced_permissions)?; - match &statement { + match statement { vir_typed::Statement::ObtainMutRef(_) => { // The requirements already performed the needed changes. } @@ -231,7 +237,7 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { // the end of the exit block). state.clear()?; } - vir_typed::Statement::SetUnionVariant(variant_statement) => { + vir_typed::Statement::SetUnionVariant(ref variant_statement) => { let position = variant_statement.position(); // Split the memory block for the union itself. let parent = variant_statement.variant_place.get_parent_ref().unwrap(); @@ -256,6 +262,119 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { self.current_statements .push(statement.typed_to_middle_statement(self.encoder)?); } + vir_typed::Statement::Pack(pack_statement) => { + // state.remove_manually_managed(&pack_statement.place)?; + let position = pack_statement.position(); + let place = pack_statement + .place + .clone() + .typed_to_middle_expression(self.encoder)?; + // let encoded_statement = vir_mid::Statement::fold_owned(place, None, position); + // FIXME: Code duplication. + let encoded_statement = match pack_statement.predicate_kind { + vir_typed::ast::statement::PredicateKind::Owned => { + vir_mid::Statement::fold_owned(place, None, position) + } + vir_typed::ast::statement::PredicateKind::UniqueRef(predicate_kind) => { + // let first_reference = place + // .get_first_dereferenced_reference() + // .expect("TODO: Report a proper error"); + // let vir_mid::Type::Reference(reference) = first_reference.get_type() else { + // unreachable!() + // }; + // let lifetime = reference.lifetime.clone(); + vir_mid::Statement::fold_ref( + place, + predicate_kind.lifetime.typed_to_middle_type(self.encoder)?, + vir_mid::ty::Uniqueness::Unique, + None, + position, + ) + } + vir_typed::ast::statement::PredicateKind::FracRef(_) => todo!(), + }; + self.current_statements.push(encoded_statement); + } + vir_typed::Statement::Unpack(unpack_statement) => { + // state.insert_manually_managed(unpack_statement.place.clone())?; + let position = unpack_statement.position(); + let place = unpack_statement + .place + .typed_to_middle_expression(self.encoder)?; + // FIXME: Code duplication. + let encoded_statement = match unpack_statement.predicate_kind { + vir_typed::ast::statement::PredicateKind::Owned => { + vir_mid::Statement::unfold_owned(place, None, position) + } + vir_typed::ast::statement::PredicateKind::UniqueRef(predicate_kind) => { + // let first_reference = place + // .get_first_dereferenced_reference() + // .expect("TODO: Report a proper error"); + // let vir_mid::Type::Reference(reference) = first_reference.get_type() else { + // unreachable!() + // }; + // let lifetime = reference.lifetime.clone(); + vir_mid::Statement::unfold_ref( + place, + predicate_kind.lifetime.typed_to_middle_type(self.encoder)?, + vir_mid::ty::Uniqueness::Unique, + None, + position, + ) + } + vir_typed::ast::statement::PredicateKind::FracRef(_) => todo!(), + }; + self.current_statements.push(encoded_statement); + } + vir_typed::Statement::Join(join_statement) => { + let position = join_statement.position(); + let place = join_statement + .place + .typed_to_middle_expression(self.encoder)?; + let encoded_statement = vir_mid::Statement::join_block(place, None, None, position); + self.current_statements.push(encoded_statement); + } + vir_typed::Statement::Split(split_statement) => { + let position = split_statement.position(); + let place = split_statement + .place + .typed_to_middle_expression(self.encoder)?; + let encoded_statement = + vir_mid::Statement::split_block(place, None, None, position); + self.current_statements.push(encoded_statement); + } + vir_typed::Statement::ForgetInitialization(forget_statement) => { + // state.insert_manually_managed(forget_statement.place.clone())?; + let position = forget_statement.position(); + let place = forget_statement + .place + .typed_to_middle_expression(self.encoder)?; + let encoded_statement = + vir_mid::Statement::convert_owned_into_memory_block(place, None, position); + self.current_statements.push(encoded_statement); + } + vir_typed::Statement::InhaleExpression(mut inhale_statement) => { + if self.check_mode.unwrap() != CheckMode::PurificationFunctional { + inhale_statement.expression = + super::unfolding_expressions::add_unfolding_expressions( + inhale_statement.expression, + )?; + } + let inhale_statement = inhale_statement.typed_to_middle_statement(self.encoder)?; + self.current_statements + .push(vir_mid::Statement::InhaleExpression(inhale_statement)); + } + vir_typed::Statement::ExhaleExpression(mut exhale_statement) => { + if self.check_mode.unwrap() != CheckMode::PurificationFunctional { + exhale_statement.expression = + super::unfolding_expressions::add_unfolding_expressions( + exhale_statement.expression, + )?; + } + let exhale_statement = exhale_statement.typed_to_middle_statement(self.encoder)?; + self.current_statements + .push(vir_mid::Statement::ExhaleExpression(exhale_statement)); + } _ => { self.current_statements .push(statement.typed_to_middle_statement(self.encoder)?); diff --git a/prusti-viper/src/encoder/high/procedures/interface.rs b/prusti-viper/src/encoder/high/procedures/interface.rs index 94d76dec50f..1a1149266b7 100644 --- a/prusti-viper/src/encoder/high/procedures/interface.rs +++ b/prusti-viper/src/encoder/high/procedures/interface.rs @@ -85,6 +85,7 @@ impl<'v, 'tcx: 'v> Private for super::super::super::Encoder<'v, 'tcx> { let procedure_typed = vir_typed::ProcedureDecl { name: procedure_high.name, check_mode: procedure_high.check_mode, + position: procedure_high.position, entry: procedure_high.entry.high_to_typed_statement(self)?, exit: procedure_high.exit.high_to_typed_statement(self)?, basic_blocks, @@ -98,7 +99,7 @@ pub(crate) trait HighProcedureEncoderInterface<'tcx> { &mut self, proc_def_id: DefId, check_mode: CheckMode, - ) -> SpannedEncodingResult; + ) -> SpannedEncodingResult>; fn encode_type_core_proof( &mut self, ty: ty::Ty<'tcx>, @@ -111,14 +112,17 @@ impl<'v, 'tcx: 'v> HighProcedureEncoderInterface<'tcx> for super::super::super:: &mut self, proc_def_id: DefId, check_mode: CheckMode, - ) -> SpannedEncodingResult { - let procedure_high = self.encode_procedure_core_proof_high(proc_def_id, check_mode)?; - debug!("procedure_high:\n{}", procedure_high); - let procedure_typed = self.procedure_high_to_typed(procedure_high)?; - debug!("procedure_typed:\n{}", procedure_typed); - let procedure = - super::inference::infer_shape_operations(self, proc_def_id, procedure_typed)?; - Ok(procedure) + ) -> SpannedEncodingResult> { + let mut procedures = Vec::new(); + for procedure_high in self.encode_procedure_core_proof_high(proc_def_id, check_mode)? { + debug!("procedure_high:\n{}", procedure_high); + let procedure_typed = self.procedure_high_to_typed(procedure_high)?; + debug!("procedure_typed:\n{}", procedure_typed); + let procedure = + super::inference::infer_shape_operations(self, proc_def_id, procedure_typed)?; + procedures.push(procedure); + } + Ok(procedures) } fn encode_type_core_proof( @@ -126,7 +130,7 @@ impl<'v, 'tcx: 'v> HighProcedureEncoderInterface<'tcx> for super::super::super:: ty: ty::Ty<'tcx>, check_mode: CheckMode, ) -> SpannedEncodingResult { - assert_eq!(check_mode, CheckMode::CoreProof); + assert_eq!(check_mode, CheckMode::MemorySafety); let ty_high = self.encode_type_high(ty)?; ty_high.high_to_middle(self) } diff --git a/prusti-viper/src/encoder/high/to_typed/expression.rs b/prusti-viper/src/encoder/high/to_typed/expression.rs index 9d148071769..cd0ba792370 100644 --- a/prusti-viper/src/encoder/high/to_typed/expression.rs +++ b/prusti-viper/src/encoder/high/to_typed/expression.rs @@ -1,7 +1,10 @@ use crate::encoder::errors::{SpannedEncodingError, SpannedEncodingResult}; use vir_crate::{ - high as vir_high, typed as vir_typed, - typed::operations::{HighToTypedExpressionLowerer, HighToTypedType}, + high as vir_high, + typed::{ + self as vir_typed, + operations::{HighToTypedExpressionLowerer, HighToTypedPredicateLowerer, HighToTypedType}, + }, }; impl<'v, 'tcx> HighToTypedExpressionLowerer for crate::encoder::Encoder<'v, 'tcx> { @@ -64,4 +67,11 @@ impl<'v, 'tcx> HighToTypedExpressionLowerer for crate::encoder::Encoder<'v, 'tcx index: variant_index.index, }) } + + fn high_to_typed_expression_predicate( + &mut self, + predicate: vir_high::Predicate, + ) -> Result { + self.high_to_typed_predicate_predicate(predicate) + } } diff --git a/prusti-viper/src/encoder/high/to_typed/type_decl.rs b/prusti-viper/src/encoder/high/to_typed/type_decl.rs index 45564009fff..6a005e66053 100644 --- a/prusti-viper/src/encoder/high/to_typed/type_decl.rs +++ b/prusti-viper/src/encoder/high/to_typed/type_decl.rs @@ -54,11 +54,13 @@ impl<'v, 'tcx> HighToTypedTypeDeclLowerer for crate::encoder::Encoder<'v, 'tcx> self.generate_tuple_name(&arguments)?, decl.lifetimes.high_to_typed_type(self)?, decl.const_parameters.high_to_typed_expression(self)?, + None, arguments .into_iter() .enumerate() .map(|(index, ty)| vir_typed::FieldDecl::new(format!("tuple_{index}"), index, ty)) .collect(), + Default::default(), )) } @@ -131,4 +133,11 @@ impl<'v, 'tcx> HighToTypedTypeDeclLowerer for crate::encoder::Encoder<'v, 'tcx> variants: decl.variants.high_to_typed_type_decl(self)?, }) } + + fn high_to_typed_type_decl_position( + &mut self, + position: vir_high::Position, + ) -> Result { + Ok(position) + } } diff --git a/prusti-viper/src/encoder/high/to_typed/types/interface.rs b/prusti-viper/src/encoder/high/to_typed/types/interface.rs index 286bf1a57bf..eba8fa936f6 100644 --- a/prusti-viper/src/encoder/high/to_typed/types/interface.rs +++ b/prusti-viper/src/encoder/high/to_typed/types/interface.rs @@ -35,7 +35,7 @@ impl<'v, 'tcx: 'v> HighToTypedTypeEncoderInterface ty: &vir_typed::Type, ) -> SpannedEncodingResult { let high_type = &self.typed_type_encoder_state.encoded_types_inverse[ty]; - let type_decl_high = self.encode_type_def_high(high_type)?; + let type_decl_high = self.encode_type_def_high(high_type, true)?; type_decl_high.high_to_typed_type_decl(self) } diff --git a/prusti-viper/src/encoder/high/types/interface.rs b/prusti-viper/src/encoder/high/types/interface.rs index 9984bffa455..af722ae870a 100644 --- a/prusti-viper/src/encoder/high/types/interface.rs +++ b/prusti-viper/src/encoder/high/types/interface.rs @@ -65,7 +65,7 @@ impl<'v, 'tcx: 'v> HighTypeEncoderInterfacePrivate for super::super::super::Enco let encoded_type = &self.high_type_encoder_state.lowered_types_inverse.borrow() [predicate_name] .clone(); - let encoded_type_decl = self.encode_type_def_high(encoded_type)?; + let encoded_type_decl = self.encode_type_def_high(encoded_type, false)?; // FIXME: Change not to use `with_default_span` here. let predicates = encoded_type_decl .lower(encoded_type, self) @@ -288,7 +288,7 @@ impl<'v, 'tcx: 'v> HighTypeEncoderInterface<'tcx> for super::super::super::Encod ) -> SpannedEncodingResult { let high_type = self.decode_type_mid_into_high(ty.erase_lifetimes().erase_const_generics())?; - let high_type_decl = self.encode_type_def_high(&high_type)?; + let high_type_decl = self.encode_type_def_high(&high_type, true)?; high_type_decl.high_to_middle(self) } } diff --git a/prusti-viper/src/encoder/middle/core_proof/addresses/encoder.rs b/prusti-viper/src/encoder/middle/core_proof/addresses/encoder.rs index 0d7790a1916..1bf0cc59ccf 100644 --- a/prusti-viper/src/encoder/middle/core_proof/addresses/encoder.rs +++ b/prusti-viper/src/encoder/middle/core_proof/addresses/encoder.rs @@ -1,15 +1,29 @@ use super::{super::utils::place_domain_encoder::PlaceExpressionDomainEncoder, AddressesInterface}; -use crate::encoder::{errors::SpannedEncodingResult, middle::core_proof::lowerer::Lowerer}; +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + lowerer::Lowerer, pointers::PointersInterface, references::ReferencesInterface, + snapshots::IntoProcedureSnapshot, + }, +}; use vir_crate::{ low as vir_low, - middle::{self as vir_mid}, + middle::{self as vir_mid, operations::ty::Typed}, }; -pub(super) struct PlaceAddressEncoder {} +pub(super) struct PlaceAddressEncoder { + old_label: Option, +} + +impl PlaceAddressEncoder { + pub(super) fn new() -> Self { + Self { old_label: None } + } +} impl PlaceExpressionDomainEncoder for PlaceAddressEncoder { - fn domain_name(&mut self, _lowerer: &mut Lowerer) -> &str { - "Address" + fn domain_name(&mut self, lowerer: &mut Lowerer) -> &str { + lowerer.address_domain() } fn encode_local( @@ -17,16 +31,25 @@ impl PlaceExpressionDomainEncoder for PlaceAddressEncoder { local: &vir_mid::expression::Local, lowerer: &mut Lowerer, ) -> SpannedEncodingResult { - lowerer.root_address(local) + lowerer.root_address(local, &self.old_label) } fn encode_deref( &mut self, - _deref: &vir_mid::expression::Deref, - _lowerer: &mut Lowerer, + deref: &vir_mid::expression::Deref, + lowerer: &mut Lowerer, _arg: vir_low::Expression, ) -> SpannedEncodingResult { - unreachable!("The address cannot be dereferenced; use the value instead.") + // FIXME: Code duplication with AddressesInterface::extract_root_address + // FIXME: Code duplication with AssertionEncoder. + let base_snapshot = deref.base.to_procedure_snapshot(lowerer)?; + let ty = deref.base.get_type(); + let result = if ty.is_reference() { + lowerer.reference_address(ty, base_snapshot, deref.position)? + } else { + lowerer.pointer_address(ty, base_snapshot, deref.position)? + }; + Ok(result) } fn encode_array_index_axioms( @@ -36,4 +59,12 @@ impl PlaceExpressionDomainEncoder for PlaceAddressEncoder { ) -> SpannedEncodingResult<()> { Ok(()) } + + fn encode_labelled_old( + &mut self, + _expression: &vir_mid::expression::LabelledOld, + _lowerer: &mut Lowerer, + ) -> SpannedEncodingResult { + todo!() + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/addresses/interface.rs b/prusti-viper/src/encoder/middle/core_proof/addresses/interface.rs index 93b37fb1b7c..75bf3e39c44 100644 --- a/prusti-viper/src/encoder/middle/core_proof/addresses/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/addresses/interface.rs @@ -4,23 +4,42 @@ use super::{ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ - lowerer::{DomainsLowererInterface, Lowerer, VariablesLowererInterface}, + lowerer::{DomainsLowererInterface, Lowerer}, + pointers::PointersInterface, references::ReferencesInterface, - snapshots::IntoProcedureSnapshot, + snapshots::{ + IntoProcedureSnapshot, IntoSnapshotLowerer, PredicateKind, + ProcedureExpressionToSnapshot, SnapshotVariablesInterface, + }, + type_layouts::TypeLayoutsInterface, }, }; use vir_crate::{ + common::{expression::QuantifierHelpers, position::Positioned}, low as vir_low, middle::{self as vir_mid, operations::ty::Typed}, }; pub(in super::super) trait AddressesInterface { + fn address_domain(&self) -> &'static str; fn address_type(&mut self) -> SpannedEncodingResult; + fn address_null( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn address_offset( + &mut self, + size: vir_low::Expression, + address: vir_low::Expression, + offset: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; /// Constructs a variable representing the address of the given MIR-level /// variable. fn root_address( &mut self, local: &vir_mid::expression::Local, + old_label: &Option, ) -> SpannedEncodingResult; /// Get the variable representing the root address of this place. fn extract_root_address( @@ -45,19 +64,96 @@ pub(in super::super) trait AddressesInterface { base_address: vir_low::Expression, position: vir_mid::Position, ) -> SpannedEncodingResult; + fn encode_index_address( + &mut self, + base_type: &vir_mid::Type, + base_address: vir_low::Expression, + index: vir_low::Expression, + position: vir_mid::Position, + ) -> SpannedEncodingResult; } impl<'p, 'v: 'p, 'tcx: 'v> AddressesInterface for Lowerer<'p, 'v, 'tcx> { + fn address_domain(&self) -> &'static str { + "Address" + } fn address_type(&mut self) -> SpannedEncodingResult { - self.domain_type("Address") + self.domain_type(self.address_domain()) + } + fn address_null( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let address_type = self.address_type()?; + self.create_domain_func_app( + "Address", + "null_address$", + Vec::new(), + address_type, + position, + ) + } + fn address_offset( + &mut self, + size: vir_low::Expression, + address: vir_low::Expression, + offset: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let address_type = self.address_type()?; + if !self.address_state.is_address_offset_axiom_encoded { + self.address_state.is_address_offset_axiom_encoded = true; + use vir_low::macros::*; + let size_type = self.size_type()?; + var_decls! { + address: Address, + index: Int, + size: {size_type.clone()} + } + let call = self.create_domain_func_app( + "Address", + "offset_address$", + vec![ + size.clone().into(), + address.clone().into(), + index.clone().into(), + ], + address_type.clone(), + position, + )?; + let injective_call = self.create_domain_func_app( + "Address", + "offset_address$inverse", + vec![address.clone().into(), call.clone()], + vir_low::Type::Int, + position, + )?; + let forall_body = expr! { + [injective_call] == index + }; + let body = vir_low::Expression::forall( + vec![size, address, index], + vec![vir_low::Trigger::new(vec![call.clone()])], + forall_body, + ); + let axiom = vir_low::DomainAxiomDecl::new(None, "offset_address$injective", body); + self.declare_axiom("Address", axiom)?; + } + self.create_domain_func_app( + "Address", + "offset_address$", + vec![size, address, offset], + address_type, + position, + ) } fn root_address( &mut self, local: &vir_mid::expression::Local, + old_label: &Option, ) -> SpannedEncodingResult { - let name = format!("{}$address", local.variable.name); - let ty = self.address_type()?; - let address_variable = self.create_variable(name, ty)?; + let address_variable = + self.address_variable_version_at_label(&local.variable.name, old_label)?; Ok(vir_low::Expression::local(address_variable, local.position)) } fn extract_root_address( @@ -66,11 +162,21 @@ impl<'p, 'v: 'p, 'tcx: 'v> AddressesInterface for Lowerer<'p, 'v, 'tcx> { ) -> SpannedEncodingResult { assert!(place.is_place()); let result = match place { - vir_mid::Expression::Local(local) => self.root_address(local)?, + vir_mid::Expression::Local(local) => self.root_address(local, &None)?, vir_mid::Expression::LabelledOld(_) => unimplemented!(), vir_mid::Expression::Deref(deref) => { - let base_snapshot = deref.base.to_procedure_snapshot(self)?; - self.reference_address(deref.base.get_type(), base_snapshot, Default::default())? + // FIXME: Code duplication with PlaceAddressEncoder + let mut place_encoder = + ProcedureExpressionToSnapshot::for_address(PredicateKind::Owned); + let base_snapshot = + place_encoder.expression_to_snapshot(self, &deref.base, false)?; + // let base_snapshot = deref.base.to_procedure_snapshot(self)?; + let ty = deref.base.get_type(); + if ty.is_reference() { + self.reference_address(ty, base_snapshot, place.position())? + } else { + self.pointer_address(ty, base_snapshot, place.position())? + } } _ => self.extract_root_address(place.get_parent_ref().unwrap())?, }; @@ -81,7 +187,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> AddressesInterface for Lowerer<'p, 'v, 'tcx> { &mut self, place: &vir_mid::Expression, ) -> SpannedEncodingResult { - let mut encoder = PlaceAddressEncoder {}; + let mut encoder = PlaceAddressEncoder::new(); encoder.encode_expression(place, self) } fn encode_field_address( @@ -108,4 +214,20 @@ impl<'p, 'v: 'p, 'tcx: 'v> AddressesInterface for Lowerer<'p, 'v, 'tcx> { position, ) } + fn encode_index_address( + &mut self, + base_type: &vir_mid::Type, + base_address: vir_low::Expression, + index: vir_low::Expression, + position: vir_mid::Position, + ) -> SpannedEncodingResult { + // FIXME: This implementation is most likely wrong. Test it properly. + let vir_mid::Type::Pointer(pointer_type) = base_type else { + unreachable!() + }; + let size = self + .encode_type_size_expression2(&*pointer_type.target_type, &*pointer_type.target_type)?; + let start_address = self.pointer_address(base_type, base_address, position)?; + self.address_offset(size, start_address, index.clone().into(), position) + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/addresses/mod.rs b/prusti-viper/src/encoder/middle/core_proof/addresses/mod.rs index 472846c6772..8285a383997 100644 --- a/prusti-viper/src/encoder/middle/core_proof/addresses/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/addresses/mod.rs @@ -1,4 +1,5 @@ mod encoder; mod interface; +mod state; -pub(super) use self::interface::AddressesInterface; +pub(super) use self::{interface::AddressesInterface, state::AddressState}; diff --git a/prusti-viper/src/encoder/middle/core_proof/addresses/state.rs b/prusti-viper/src/encoder/middle/core_proof/addresses/state.rs new file mode 100644 index 00000000000..dbb584af0e3 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/addresses/state.rs @@ -0,0 +1,4 @@ +#[derive(Default)] +pub(in super::super) struct AddressState { + pub(super) is_address_offset_axiom_encoded: bool, +} diff --git a/prusti-viper/src/encoder/middle/core_proof/adts/interface.rs b/prusti-viper/src/encoder/middle/core_proof/adts/interface.rs index 3b64a1054dc..a821bb2717a 100644 --- a/prusti-viper/src/encoder/middle/core_proof/adts/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/adts/interface.rs @@ -2,6 +2,7 @@ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::lowerer::{DomainsLowererInterface, Lowerer}, }; +use prusti_common::config; use rustc_hash::FxHashSet; use std::borrow::Cow; use vir_crate::{ @@ -265,18 +266,14 @@ impl<'p, 'v: 'p, 'tcx: 'v> AdtsInterface for Lowerer<'p, 'v, 'tcx> { } // Injectivity axioms. - if parameters.is_empty() { - // No need to generate injectivity axioms if the constructor has no parameters. - return Ok(()); - } - if generate_injectivity_axioms { // We do not generate injectivity axioms for alternative // constructors (that would be unsound). use vir_low::macros::*; // Bottom-up injectivity axiom. - { + if !parameters.is_empty() { + // We need something to quantify over, so parameters cannot be empty. let mut triggers = Vec::new(); let mut conjuncts = Vec::new(); let constructor_call = self.adt_constructor_variant_call( @@ -344,12 +341,27 @@ impl<'p, 'v: 'p, 'tcx: 'v> AdtsInterface for Lowerer<'p, 'v, 'tcx> { } let constructor_call = self.adt_constructor_variant_call(domain_name, variant_name, arguments)?; + if parameters.is_empty() { + if let Some(guard) = &trigger_guard { + triggers.push(vir_low::Trigger::new(vec![guard.clone()])); + } else { + unimplemented!("figure out what triggers to choose!"); + } + } + if !config::use_snapshot_parameters_in_predicates() && !parameters.is_empty() { + triggers.push(vir_low::Trigger::new(vec![constructor_call.clone()])); + } let equality = expr! { value == [constructor_call] }; let forall_body = if let Some(guard) = guard { expr! { [guard] ==> [equality] } } else { equality }; + assert!( + !triggers.is_empty(), + "empty triggers for {}", + constructor_name + ); let axiom = vir_low::DomainAxiomDecl { comment: None, name: format!("{constructor_name}$top_down_injectivity_axiom"), diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/assertion_encoder.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/assertion_encoder.rs new file mode 100644 index 00000000000..0a698105509 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/assertion_encoder.rs @@ -0,0 +1,449 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + high::types::HighTypeEncoderInterface, + middle::core_proof::{ + builtin_methods::CallContext, + lowerer::Lowerer, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface}, + references::ReferencesInterface, + snapshots::{IntoSnapshotLowerer, SnapshotValuesInterface, SnapshotVariablesInterface}, + }, +}; +use prusti_common::config; +use std::collections::BTreeMap; +use vir_crate::{ + common::{expression::BinaryOperationHelpers, position::Positioned}, + low::{self as vir_low}, + middle::{self as vir_mid, operations::ty::Typed}, +}; + +// TODO: Delete this file. +pub(in super::super::super) struct AssertionEncoder<'a> { + /// A map from field names to arguments that are being assigned to these + /// fields. + field_arguments: BTreeMap, + heap: &'a Option, + result_value: Option, + replace_self_with_result_value: bool, + in_function: bool, +} + +impl<'a> AssertionEncoder<'a> { + pub(in super::super::super) fn new( + decl: &vir_mid::type_decl::Struct, + operand_values: Vec, + heap: &'a Option, + ) -> Self { + let mut field_arguments = BTreeMap::default(); + // assert_eq!(decl.fields.len(), operand_values.len()); FIXME: Split + // into two assertion encoders: one that uses result value and one that + // usess field_arguments. + for (field, operand) in decl.fields.iter().zip(operand_values.into_iter()) { + assert!(field_arguments + .insert(field.name.clone(), operand) + .is_none()); + } + Self { + field_arguments, + heap, + result_value: None, + replace_self_with_result_value: false, + in_function: false, + } + } + + // FIXME: Code duplication. + fn pointer_deref_into_address<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + if let Some(deref_place) = place.get_last_dereferenced_pointer() { + let base_snapshot = self.expression_to_snapshot(lowerer, deref_place, true)?; + let ty = deref_place.get_type(); + lowerer.pointer_address(ty, base_snapshot, place.position()) + // match deref_place { + // vir_mid::Expression::Deref(deref) => { + // let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, true)?; + // let ty = deref.base.get_type(); + // assert!(ty.is_pointer()); + // lowerer.pointer_address(ty, base_snapshot, place.position()) + // } + // _ => unreachable!(), + // } + } else { + unreachable!() + } + // PlaceExpressionDomainEncoder::encode_expression(self, place, lowerer) + } + + pub(super) fn address_in_heap<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + let pointer = self.expression_to_snapshot(lowerer, pointer_place, true)?; + let address = + lowerer.pointer_address(pointer_place.get_type(), pointer, pointer_place.position())?; + let in_heap = vir_low::Expression::container_op_no_pos( + vir_low::ContainerOpKind::MapContains, + self.heap.as_ref().unwrap().ty.clone(), + vec![self.heap.clone().unwrap().into(), address], + ); + Ok(in_heap) + } + + // pub(in super::super::super) fn set_in_function(&mut self) { + // assert!(!self.in_function); + // self.in_function = true; + // } + + pub(in super::super::super) fn set_result_value( + &mut self, + result_value: vir_low::VariableDecl, + ) { + assert!(self.result_value.is_none()); + self.result_value = Some(result_value); + } + + pub(super) fn unset_result_value(&mut self) { + assert!(self.result_value.is_some()); + self.result_value = None; + } + + fn acc_predicate_to_snapshot_precondition<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + assert!(expect_math_bool); + let expression = match &*acc_predicate.predicate { + vir_mid::Predicate::OwnedNonAliased(predicate) => { + let ty = predicate.place.get_type(); + let place = lowerer.encode_expression_as_place(&predicate.place)?; + // eprintln!("predicate: {}", predicate); + let root_address = self.pointer_deref_into_address(lowerer, &predicate.place)?; + // eprintln!("root_address2: {}", root_address); + // let deref = predicate.place.clone().unwrap_deref(); + // let base_snapshot = + // self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; + // let snapshot = lowerer.pointer_target_snapshot_in_heap( + // deref.base.get_type(), + // self.heap.clone(), + // base_snapshot, + // deref.position, + // )?; + + let snapshot = if config::use_snapshot_parameters_in_predicates() { + self.expression_to_snapshot(lowerer, &predicate.place, expect_math_bool)? + } else { + // FIXME: cleanup code + if lowerer.use_heap_variable()? { + let deref = predicate.place.clone().unwrap_deref(); + let base_snapshot = + self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; + let snapshot = lowerer.pointer_target_snapshot_in_heap( + deref.base.get_type(), + self.heap.clone().unwrap(), + base_snapshot, + deref.position, + )?; + snapshot + } else { + true.into() + } + }; + + if lowerer.use_heap_variable()? { + // let snapshot = self.expression_to_snapshot(lowerer, &predicate.place, expect_math_bool)?; + lowerer.owned_non_aliased( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + snapshot, + None, + )? + } else { + lowerer.owned_non_aliased_predicate( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + snapshot, + None, + )? + } + } + vir_mid::Predicate::MemoryBlockHeap(predicate) => { + let place = lowerer.encode_expression_as_place(&predicate.address)?; + let root_address = self.pointer_deref_into_address(lowerer, &predicate.address)?; + use vir_low::macros::*; + let compute_address = ty!(Address); + let address = expr! { + ComputeAddress::compute_address([place], [root_address]) + }; + let size = + self.expression_to_snapshot(lowerer, &predicate.size, expect_math_bool)?; + lowerer.encode_memory_block_stack_acc(address, size, acc_predicate.position)? + } + vir_mid::Predicate::MemoryBlockHeapDrop(predicate) => { + let place = self.pointer_deref_into_address(lowerer, &predicate.address)?; + let size = + self.expression_to_snapshot(lowerer, &predicate.size, expect_math_bool)?; + lowerer.encode_memory_block_heap_drop_acc(place, size, acc_predicate.position)? + } + _ => unimplemented!("{acc_predicate}"), + }; + Ok(expression) + } + + fn acc_predicate_to_snapshot_postcondition<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + assert!(expect_math_bool); + let expression = match &*acc_predicate.predicate { + vir_mid::Predicate::OwnedNonAliased(predicate) => { + let position = predicate.position; + let ty = predicate.place.get_type(); + let place = lowerer.encode_expression_as_place(&predicate.place)?; + let old_value = self.replace_self_with_result_value; + self.replace_self_with_result_value = true; + let root_address_self = + self.pointer_deref_into_address(lowerer, &predicate.place)?; + self.replace_self_with_result_value = old_value; + let snap_call_self = lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + ty, + ty, + place.clone(), + root_address_self, + position, + )?; + if self.in_function { + let snap_call_result_value = + self.expression_to_snapshot(lowerer, &predicate.place, expect_math_bool)?; + vir_low::Expression::equals(snap_call_result_value, snap_call_self) + } else { + let root_address_parameter = + self.pointer_deref_into_address(lowerer, &predicate.place)?; + let snap_call_parameter = lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address_parameter, + position, + )?; + vir_low::Expression::equals( + snap_call_parameter, + vir_low::Expression::labelled_old(None, snap_call_self, position), + ) + } + } + vir_mid::Predicate::MemoryBlockHeap(_) | vir_mid::Predicate::MemoryBlockHeapDrop(_) => { + true.into() + } + _ => unimplemented!("{acc_predicate}"), + }; + Ok(expression) + } +} + +// impl<'a> PlaceExpressionDomainEncoder for AssertionEncoder<'a> { +// fn domain_name(&mut self, lowerer: &mut Lowerer) -> &str { +// lowerer.address_domain() +// } + +// fn encode_local( +// &mut self, +// local: &vir_mid::expression::Local, +// lowerer: &mut Lowerer, +// ) -> SpannedEncodingResult { +// lowerer.root_address(local, &None) +// } + +// fn encode_deref( +// &mut self, +// deref: &vir_mid::expression::Deref, +// lowerer: &mut Lowerer, +// _arg: vir_low::Expression, +// ) -> SpannedEncodingResult { +// let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, true)?; +// let ty = deref.base.get_type(); +// let result = if ty.is_reference() { +// lowerer.reference_address(ty, base_snapshot, deref.position)? +// } else { +// lowerer.pointer_address(ty, base_snapshot, deref.position)? +// }; +// Ok(result) +// } + +// fn encode_labelled_old( +// &mut self, +// _expression: &vir_mid::expression::LabelledOld, +// _lowerer: &mut Lowerer, +// ) -> SpannedEncodingResult { +// todo!() +// } + +// fn encode_array_index_axioms( +// &mut self, +// _base_type: &vir_mid::Type, +// _lowerer: &mut Lowerer, +// ) -> SpannedEncodingResult<()> { +// todo!() +// } +// } + +impl<'a, 'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for AssertionEncoder<'a> { + fn variable_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + variable: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + if self.replace_self_with_result_value || self.in_function { + assert!(variable.is_self_variable()); + Ok(self.result_value.clone().unwrap()) + } else { + Ok(vir_low::VariableDecl { + name: variable.name.clone(), + ty: self.type_to_snapshot(lowerer, &variable.ty)?, + }) + } + } + + fn labelled_old_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _old: &vir_mid::LabelledOld, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn func_app_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _app: &vir_mid::FuncApp, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let expression = if self.result_value.is_some() { + self.acc_predicate_to_snapshot_postcondition(lowerer, acc_predicate, expect_math_bool)? + } else { + self.acc_predicate_to_snapshot_precondition(lowerer, acc_predicate, expect_math_bool)? + }; + Ok(expression) + } + + fn field_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + field: &vir_mid::Field, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + match &*field.base { + vir_mid::Expression::Local(local) + if !self.replace_self_with_result_value && !self.in_function => + { + assert!(local.variable.is_self_variable()); + Ok(self.field_arguments[&field.field.name].clone()) + // if self.replace_self_with_result_value { + // Ok(self.result_value.clone().unwrap().into()) + // } else + // {Ok(self.field_arguments[&field.field.name].clone())} + } + _ => { + // FIXME: Code duplication because Rust does not have syntax for calling + // overriden methods. + let base_snapshot = + self.expression_to_snapshot(lowerer, &field.base, expect_math_bool)?; + let result = if field.field.is_discriminant() { + let ty = field.base.get_type(); + // FIXME: Create a method for obtainging the discriminant type. + let type_decl = lowerer.encoder.get_type_decl_mid(ty)?; + let enum_decl = type_decl.unwrap_enum(); + let discriminant_call = + lowerer.obtain_enum_discriminant(base_snapshot, ty, field.position)?; + lowerer.construct_constant_snapshot( + &enum_decl.discriminant_type, + discriminant_call, + field.position, + )? + } else { + lowerer.obtain_struct_field_snapshot( + field.base.get_type(), + &field.field, + base_snapshot, + field.position, + )? + }; + self.ensure_bool_expression(lowerer, field.get_type(), result, expect_math_bool) + } + } + } + + fn deref_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + deref: &vir_mid::Deref, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; + let ty = deref.base.get_type(); + let result = if ty.is_reference() { + lowerer.reference_target_current_snapshot(ty, base_snapshot, Default::default())? + } else if lowerer.use_heap_variable()? { + lowerer.pointer_target_snapshot_in_heap( + deref.base.get_type(), + self.heap.clone().unwrap(), + base_snapshot, + deref.position, + )? + } else { + // eprintln!("deref: {}", deref); + // unimplemented!() + true.into() // TODO + }; + self.ensure_bool_expression(lowerer, deref.get_type(), result, expect_math_bool) + } + + fn owned_non_aliased_snap( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _ty: &vir_mid::Type, + _pointer_snapshot: &vir_mid::Expression, + ) -> SpannedEncodingResult { + unimplemented!() + } + + fn call_context(&self) -> CallContext { + CallContext::BuiltinMethod + } + + // fn unfolding_to_snapshot( + // &mut self, + // _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // _unfolding: &vir_mid::Unfolding, + // _expect_math_bool: bool, + // ) -> SpannedEncodingResult { + // todo!() + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/change_unique_ref_place.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/change_unique_ref_place.rs index 2a11f96b51b..93bb81ce6fa 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/change_unique_ref_place.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/change_unique_ref_place.rs @@ -6,7 +6,7 @@ use crate::encoder::{ lifetimes::LifetimesInterface, lowerer::Lowerer, places::PlacesInterface, - predicates::UniqueRefUseBuilder, + predicates::PredicatesOwnedInterface, references::ReferencesInterface, snapshots::{IntoPureSnapshot, IntoSnapshot}, }, @@ -106,6 +106,11 @@ impl<'l, 'p, 'v, 'tcx> ChangeUniqueRefPlaceMethodBuilder<'l, 'p, 'v, 'tcx> { self.source_snapshot.clone().into(), self.inner.position, )?; + let slice_len = self.inner.lowerer.reference_slice_len( + self.inner.ty, + self.source_snapshot.clone().into(), + self.inner.position, + )?; let deref_source_place = self .inner .lowerer @@ -133,8 +138,7 @@ impl<'l, 'p, 'v, 'tcx> ChangeUniqueRefPlaceMethodBuilder<'l, 'p, 'v, 'tcx> { .lowerer .encode_lifetime_const_into_pure_is_alive_variable(lifetime)?; let lifetime = lifetime.to_pure_snapshot(self.inner.lowerer)?; - let mut builder = UniqueRefUseBuilder::new( - self.lowerer(), + let source_expression = self.lowerer().unique_ref( CallContext::BuiltinMethod, &target_type, &target_type_decl, @@ -143,14 +147,11 @@ impl<'l, 'p, 'v, 'tcx> ChangeUniqueRefPlaceMethodBuilder<'l, 'p, 'v, 'tcx> { current_snapshot.clone(), final_snapshot.clone(), lifetime.clone().into(), + slice_len.clone(), )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let source_expression = builder.build(); self.add_precondition(expr! { [lifetime_alive.clone().into()] ==> [source_expression] }); - let mut builder = UniqueRefUseBuilder::new( - self.lowerer(), + let target_expression = self.lowerer().unique_ref( CallContext::BuiltinMethod, &target_type, &target_type_decl, @@ -159,10 +160,8 @@ impl<'l, 'p, 'v, 'tcx> ChangeUniqueRefPlaceMethodBuilder<'l, 'p, 'v, 'tcx> { current_snapshot, final_snapshot, lifetime.into(), + slice_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let target_expression = builder.build(); self.add_postcondition(expr! { [lifetime_alive.into()] ==> [target_expression] }); Ok(()) } diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/copy_place.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/copy_place.rs index 2ae1cebfedc..71b0a9cfe30 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/copy_place.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/copy_place.rs @@ -115,10 +115,25 @@ impl<'l, 'p, 'v, 'tcx> CopyPlaceMethodBuilder<'l, 'p, 'v, 'tcx> { ) } + pub(in super::super::super::super) fn create_source_owned_predicate( + &mut self, + ) -> SpannedEncodingResult { + self.inner.inner.lowerer.owned_non_aliased_predicate( + CallContext::BuiltinMethod, + self.inner.inner.ty, + self.inner.inner.type_decl, + self.inner.source_place.clone().into(), + self.inner.source_root_address.clone().into(), + self.inner.source_snapshot.clone().into(), + Some(self.source_permission_amount.clone().into()), + ) + } + pub(in super::super::super::super) fn create_target_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { - self.inner.create_target_owned() + self.inner.create_target_owned(must_be_predicate) } pub(in super::super::super::super) fn add_target_validity_postcondition( diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/duplicate_frac_ref.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/duplicate_frac_ref.rs index 58f11ea5cff..dec734ff928 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/duplicate_frac_ref.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/duplicate_frac_ref.rs @@ -5,7 +5,7 @@ use crate::encoder::{ builtin_methods::CallContext, lowerer::Lowerer, places::PlacesInterface, - predicates::FracRefUseBuilder, + predicates::{FracRefUseBuilder, PredicatesOwnedInterface}, references::ReferencesInterface, snapshots::{IntoPureSnapshot, IntoSnapshot}, }, @@ -128,16 +128,29 @@ impl<'l, 'p, 'v, 'tcx> DuplicateFracRefMethodBuilder<'l, 'p, 'v, 'tcx> { &target_type_decl, deref_source_place, root_address.clone(), - current_snapshot.clone(), + // current_snapshot.clone(), lifetime.clone().into(), )?; builder.add_lifetime_arguments()?; builder.add_const_arguments()?; - let source_expression = builder.build(); + let source_expression = builder.build()?; self.add_precondition(source_expression.clone()); self.add_postcondition(source_expression); - let mut builder = FracRefUseBuilder::new( - self.lowerer(), + // let mut builder = FracRefUseBuilder::new( + // self.lowerer(), + // CallContext::BuiltinMethod, + // &target_type, + // &target_type_decl, + // deref_target_place, + // root_address, + // // current_snapshot, + // lifetime.into(), + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // let target_expression = builder.build(); + let TODO_target_slice_len = None; + let target_expression = self.inner.lowerer.frac_ref_predicate( CallContext::BuiltinMethod, &target_type, &target_type_decl, @@ -145,10 +158,8 @@ impl<'l, 'p, 'v, 'tcx> DuplicateFracRefMethodBuilder<'l, 'p, 'v, 'tcx> { root_address, current_snapshot, lifetime.into(), + TODO_target_slice_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let target_expression = builder.build(); self.add_postcondition(target_expression); Ok(()) } diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_into.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_into.rs index 7e753f27d06..b9eb2d29cc5 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_into.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_into.rs @@ -9,7 +9,7 @@ use crate::encoder::{ }, lowerer::Lowerer, places::PlacesInterface, - predicates::OwnedNonAliasedUseBuilder, + predicates::PredicatesOwnedInterface, snapshots::{IntoSnapshot, SnapshotValuesInterface}, }, }; @@ -73,19 +73,17 @@ impl<'l, 'p, 'v, 'tcx> IntoMemoryBlockMethodBuilder<'l, 'p, 'v, 'tcx> { // FIXME: Remove code duplication with create_source_owned. pub(in super::super::super::super) fn create_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + self.inner.lowerer.owned_non_aliased_full_vars( CallContext::BuiltinMethod, self.inner.ty, self.inner.type_decl, - self.place.clone().into(), - self.root_address.clone().into(), - self.snapshot.clone().into(), - )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - Ok(builder.build()) + &self.place, + &self.root_address, + &self.snapshot, + must_be_predicate, + ) } pub(in super::super::super::super) fn create_target_memory_block( diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_join.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_join.rs new file mode 100644 index 00000000000..9df0105fce7 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_join.rs @@ -0,0 +1,381 @@ +use super::{ + common::{BuiltinMethodBuilder, BuiltinMethodBuilderMethods}, + memory_block_range_split_join_common::MemoryBlockRangeSplitJoinMethodBuilder, + memory_block_split_join_common::BuiltinMethodSplitJoinBuilderMethods, +}; +use crate::encoder::{ + errors::{BuiltinMethodKind, SpannedEncodingResult}, + middle::core_proof::{ + addresses::AddressesInterface, + lowerer::{DomainsLowererInterface, Lowerer}, + predicates::PredicatesMemoryBlockInterface, + snapshots::{IntoSnapshot, SnapshotValuesInterface}, + type_layouts::TypeLayoutsInterface, + }, +}; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers}, + low::{self as vir_low}, + middle as vir_mid, +}; + +pub(in super::super::super::super) struct MemoryBlockRangeJoinMethodBuilder<'l, 'p, 'v, 'tcx> { + inner: MemoryBlockRangeSplitJoinMethodBuilder<'l, 'p, 'v, 'tcx>, +} + +impl<'l, 'p, 'v, 'tcx> BuiltinMethodBuilderMethods<'l, 'p, 'v, 'tcx> + for MemoryBlockRangeJoinMethodBuilder<'l, 'p, 'v, 'tcx> +{ + fn inner(&mut self) -> &mut BuiltinMethodBuilder<'l, 'p, 'v, 'tcx> { + self.inner.inner() + } +} + +impl<'l, 'p, 'v, 'tcx> BuiltinMethodSplitJoinBuilderMethods<'l, 'p, 'v, 'tcx> + for MemoryBlockRangeJoinMethodBuilder<'l, 'p, 'v, 'tcx> +{ +} + +impl<'l, 'p, 'v, 'tcx> MemoryBlockRangeJoinMethodBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + kind: vir_low::MethodKind, + method_name: &'l str, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + error_kind: BuiltinMethodKind, + ) -> SpannedEncodingResult { + Ok(Self { + inner: MemoryBlockRangeSplitJoinMethodBuilder::new( + lowerer, + kind, + method_name, + ty, + type_decl, + error_kind, + )?, + }) + } + + pub(in super::super::super::super) fn build(self) -> vir_low::MethodDecl { + self.inner.build() + } + + pub(in super::super::super::super) fn create_parameters( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.create_parameters() + } + + // pub(in super::super::super::super) fn add_permission_amount_positive_precondition( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // self.inner.add_permission_amount_positive_precondition() + // } + + pub(in super::super::super::super) fn add_whole_memory_block_postcondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let memory_block = self.inner.create_whole_block_acc()?; + self.add_postcondition(memory_block); + Ok(()) + } + + pub(in super::super::super::super) fn add_memory_block_range_precondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let memory_block_range = self.inner.create_memory_block_range_acc()?; + self.add_precondition(memory_block_range); + Ok(()) + } + + // FIXME: Code duplication. + pub(in super::super::super::super) fn add_byte_values_preserved_postcondition( + &mut self, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let element_size = self + .inner + .inner + .lowerer + .encode_type_size_expression2(self.inner.inner.ty, self.inner.inner.type_decl)?; + let length = self.inner.length()?; + let whole_size = self + .inner + .inner + .lowerer + .encode_type_size_expression_repetitions( + self.inner.inner.ty, + self.inner.inner.type_decl, + length, + self.inner.inner.position, + )?; + let size_type = self.inner.inner.lowerer.size_type_mid()?; + var_decls! { + index: Int, + byte_index: Int + } + let address: vir_low::Expression = self.inner.address.clone().into(); + let element_address = self.inner.inner.lowerer.address_offset( + element_size.clone(), + address.clone(), + index.clone().into(), + self.inner.inner.position, + )?; + let start_index = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + self.inner.start_index.clone().into(), + self.inner.inner.position, + )?; + let end_index = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + self.inner.end_index.clone().into(), + self.inner.inner.position, + )?; + let element_bytes = self + .inner + .inner + .lowerer + .encode_memory_block_bytes_expression(element_address.clone(), element_size.clone())?; + let whole_bytes = self + .inner + .inner + .lowerer + .encode_memory_block_bytes_expression(address.clone(), whole_size.clone())?; + let read_element_byte = self.inner.inner.lowerer.encode_read_byte_expression( + vir_low::Expression::labelled_old( + None, + element_bytes.clone(), + self.inner.inner.position, + ), + byte_index.clone().into(), + self.inner.inner.position, + )?; + let block_size = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + element_size.clone(), + self.inner.inner.position, + )?; + // let block_start_index = vir_low::Expression::multiply( + // block_size, + // index.clone().into(), + // ); + // let whole_byte_index =vir_low::Expression::add(block_start_index, byte_index.clone().into()); + let whole_byte_index = self.inner.inner.lowerer.create_domain_func_app( + "Arithmetic", + "mul_add", + vec![block_size, index.clone().into(), byte_index.clone().into()], + vir_low::Type::Int, + self.inner.inner.position, + )?; + let read_whole_byte = self.inner.inner.lowerer.encode_read_byte_expression( + whole_bytes.clone(), + whole_byte_index, + self.inner.inner.position, + )?; + let element_size_int = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + element_size.clone(), + self.inner.inner.position, + )?; + let body = expr!( + ((([start_index] <= index) && (index < [end_index])) && + (([0.into()] <= byte_index) && (byte_index < [element_size_int]))) ==> + ([read_element_byte.clone()] == [read_whole_byte]) + ); + let trigger = self.inner.inner.lowerer.encode_read_byte_expression( + element_bytes.clone(), + byte_index.clone().into(), + self.inner.inner.position, + )?; + let expression = vir_low::Expression::forall( + vec![index, byte_index], + vec![vir_low::Trigger::new(vec![trigger])], + body, + ); + self.add_postcondition(expression); + Ok(()) + } + + // pub(in super::super::super::super) fn add_padding_memory_block_precondition( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // let expression = self.inner.create_padding_memory_block_acc()?; + // self.add_precondition(expression); + // Ok(()) + // } + + // pub(in super::super::super::super) fn add_field_memory_block_precondition( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult<()> { + // let field_block = self.inner.create_field_memory_block_acc(field)?; + // self.add_precondition(field_block); + // Ok(()) + // } + + // pub(in super::super::super::super) fn add_discriminant_precondition( + // &mut self, + // decl: &vir_mid::type_decl::Enum, + // ) -> SpannedEncodingResult<()> { + // let discriminant_block = self.inner.create_discriminant_acc(decl)?; + // self.add_precondition(discriminant_block); + // Ok(()) + // } + + // pub(in super::super::super::super) fn add_variant_memory_block_precondition( + // &mut self, + // discriminant_value: vir_mid::DiscriminantValue, + // variant: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult<()> { + // let expression = self + // .inner + // .create_variant_memory_block_acc(discriminant_value, variant)?; + // self.add_precondition(expression); + // Ok(()) + // } + + // pub(in super::super::super::super) fn create_field_to_bytes_equality( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // let expression = self.inner.create_field_to_bytes_equality(field)?; + // Ok(vir_low::Expression::labelled_old_no_pos(None, expression)) + // } + + // pub(in super::super::super::super) fn add_fields_to_bytes_equalities_postcondition( + // &mut self, + // field_to_bytes_equalities: Vec, + // ) -> SpannedEncodingResult<()> { + // use vir_low::macros::*; + // let address = self.inner.address(); + // let inner = self.inner(); + // let to_bytes = ty! { Bytes }; + // let ty = inner.ty; + // let size_of = inner + // .lowerer + // .encode_type_size_expression2(inner.ty, inner.type_decl)?; + // let memory_block_bytes = inner + // .lowerer + // .encode_memory_block_bytes_expression(address, size_of)?; + // let bytes_quantifier = expr! { + // forall( + // snapshot: {ty.to_snapshot(inner.lowerer)?} :: + // [ { (Snap::to_bytes(snapshot)) } ] + // [ field_to_bytes_equalities.into_iter().conjoin() ] ==> + // ([memory_block_bytes] == (Snap::to_bytes(snapshot))) + // ) + // }; + // self.add_postcondition(bytes_quantifier); + // Ok(()) + // } + + // pub(in super::super::super::super) fn create_variant_to_bytes_equality( + // &mut self, + // discriminant_value: vir_mid::DiscriminantValue, + // variant: &vir_mid::type_decl::Struct, + // decl: &vir_mid::type_decl::Enum, + // safety: vir_mid::ty::EnumSafety, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let discriminant = self.inner.discriminant.as_ref().unwrap(); + // let ty = self.inner.inner.ty; + // let to_bytes = ty! { Bytes }; + // let snapshot: vir_low::Expression = + // var! { snapshot: {self.inner.inner.ty.to_snapshot(self.inner.inner.lowerer)?} }.into(); + // let variant_index = variant.name.clone().into(); + // let variant_snapshot = self.inner.inner.lowerer.obtain_enum_variant_snapshot( + // ty, + // &variant_index, + // snapshot.clone(), + // self.inner.inner.position, + // )?; + // let variant_address = self.inner.inner.lowerer.encode_enum_variant_address( + // self.inner.inner.ty, + // &variant_index, + // self.inner.address.clone().into(), + // self.inner.inner.position, + // )?; + // let variant_type = &self.inner.inner.ty.clone().variant(variant_index); + // let variant_size_of = self + // .inner + // .inner + // .lowerer + // .encode_type_size_expression2(variant_type, variant)?; + // let memory_block_variant_bytes = self + // .inner + // .inner + // .lowerer + // .encode_memory_block_bytes_expression(variant_address, variant_size_of)?; + // let memory_block_bytes = self + // .inner + // .inner + // .create_memory_block_bytes(self.inner.address.clone().into())?; + // let discriminant_to_bytes = if safety.is_enum() { + // let discriminant_type = &decl.discriminant_type; + // let discriminant_size_of = self + // .inner + // .inner + // .lowerer + // .encode_type_size_expression2(&decl.discriminant_type, &decl.discriminant_type)?; + // let discriminant_field = decl.discriminant_field(); + // let discriminant_address = self.inner.inner.lowerer.encode_field_address( + // self.inner.inner.ty, + // &discriminant_field, + // self.inner.address.clone().into(), + // self.inner.inner.position, + // )?; + // let memory_block_discriminant_bytes = self + // .inner + // .inner + // .lowerer + // .encode_memory_block_bytes_expression(discriminant_address, discriminant_size_of)?; + // let discriminant_call = self.inner.inner.lowerer.obtain_enum_discriminant( + // snapshot.clone(), + // self.inner.inner.ty, + // self.inner.inner.position, + // )?; + // let discriminant_snapshot = self.inner.inner.lowerer.construct_constant_snapshot( + // discriminant_type, + // discriminant_call, + // self.inner.inner.position, + // )?; + // expr! { + // ((old([memory_block_discriminant_bytes])) == + // (Snap::to_bytes([discriminant_snapshot]))) + // } + // } else { + // true.into() + // }; + // let expression = expr! { + // (discriminant == [discriminant_value.into()]) ==> + // ( + // ( + // [discriminant_to_bytes] && + // ((old([memory_block_variant_bytes])) == + // (Snap::to_bytes([variant_snapshot]))) + // ) ==> + // ([memory_block_bytes] == (Snap::to_bytes([snapshot]))) + // ) + // }; + // Ok(expression) + // } + + // pub(in super::super::super::super) fn add_variants_to_bytes_equalities_postcondition( + // &mut self, + // variant_to_bytes_equalities: Vec, + // ) -> SpannedEncodingResult<()> { + // use vir_low::macros::*; + // let ty = self.inner.inner.ty; + // let to_bytes = ty! { Bytes }; + // let expression = expr! { + // forall( + // snapshot: {ty.to_snapshot(self.inner.inner.lowerer)?} :: + // [ { (Snap::to_bytes(snapshot)) } ] + // [ variant_to_bytes_equalities.into_iter().conjoin() ] + // ) + // }; + // self.add_postcondition(expression); + // Ok(()) + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_split.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_split.rs new file mode 100644 index 00000000000..c3185021475 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_split.rs @@ -0,0 +1,193 @@ +use super::{ + common::{BuiltinMethodBuilder, BuiltinMethodBuilderMethods}, + memory_block_range_split_join_common::MemoryBlockRangeSplitJoinMethodBuilder, + memory_block_split_join_common::BuiltinMethodSplitJoinBuilderMethods, +}; +use crate::encoder::{ + errors::{BuiltinMethodKind, SpannedEncodingResult}, + middle::core_proof::{ + addresses::AddressesInterface, + lowerer::{DomainsLowererInterface, Lowerer}, + predicates::PredicatesMemoryBlockInterface, + snapshots::{IntoSnapshot, SnapshotValuesInterface}, + type_layouts::TypeLayoutsInterface, + }, +}; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers}, + low::{self as vir_low}, + middle as vir_mid, +}; + +pub(in super::super::super::super) struct MemoryBlockRangeSplitMethodBuilder<'l, 'p, 'v, 'tcx> { + inner: MemoryBlockRangeSplitJoinMethodBuilder<'l, 'p, 'v, 'tcx>, +} + +impl<'l, 'p, 'v, 'tcx> BuiltinMethodBuilderMethods<'l, 'p, 'v, 'tcx> + for MemoryBlockRangeSplitMethodBuilder<'l, 'p, 'v, 'tcx> +{ + fn inner(&mut self) -> &mut BuiltinMethodBuilder<'l, 'p, 'v, 'tcx> { + self.inner.inner() + } +} + +impl<'l, 'p, 'v, 'tcx> BuiltinMethodSplitJoinBuilderMethods<'l, 'p, 'v, 'tcx> + for MemoryBlockRangeSplitMethodBuilder<'l, 'p, 'v, 'tcx> +{ +} + +impl<'l, 'p, 'v, 'tcx> MemoryBlockRangeSplitMethodBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + kind: vir_low::MethodKind, + method_name: &'l str, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + error_kind: BuiltinMethodKind, + ) -> SpannedEncodingResult { + Ok(Self { + inner: MemoryBlockRangeSplitJoinMethodBuilder::new( + lowerer, + kind, + method_name, + ty, + type_decl, + error_kind, + )?, + }) + } + + pub(in super::super::super::super) fn build(self) -> vir_low::MethodDecl { + self.inner.build() + } + + pub(in super::super::super::super) fn create_parameters( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.create_parameters() + } + + // pub(in super::super::super::super) fn add_permission_amount_positive_precondition( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // self.inner.add_permission_amount_positive_precondition() + // } + + pub(in super::super::super::super) fn add_whole_memory_block_precondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let memory_block = self.inner.create_whole_block_acc()?; + self.add_precondition(memory_block); + Ok(()) + } + + pub(in super::super::super::super) fn add_memory_block_range_postcondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let memory_block_range = self.inner.create_memory_block_range_acc()?; + self.add_postcondition(memory_block_range); + Ok(()) + } + + // FIXME: Code duplication. + pub(in super::super::super::super) fn add_byte_values_preserved_postcondition( + &mut self, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let element_size = self + .inner + .inner + .lowerer + .encode_type_size_expression2(self.inner.inner.ty, self.inner.inner.type_decl)?; + let length = self.inner.length()?; + let whole_size = self + .inner + .inner + .lowerer + .encode_type_size_expression_repetitions( + self.inner.inner.ty, + self.inner.inner.type_decl, + length, + self.inner.inner.position, + )?; + let size_type = self.inner.inner.lowerer.size_type_mid()?; + var_decls! { + index: Int, + byte_index: Int + } + let address: vir_low::Expression = self.inner.address.clone().into(); + let element_address = self.inner.inner.lowerer.address_offset( + element_size.clone(), + address.clone(), + index.clone().into(), + self.inner.inner.position, + )?; + // let predicate = + // self.encode_memory_block_stack_acc(element_address.clone(), size.clone(), position)?; + let start_index = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + self.inner.start_index.clone().into(), + self.inner.inner.position, + )?; + let end_index = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + self.inner.end_index.clone().into(), + self.inner.inner.position, + )?; + let element_bytes = self + .inner + .inner + .lowerer + .encode_memory_block_bytes_expression(element_address.clone(), element_size.clone())?; + let whole_bytes = self + .inner + .inner + .lowerer + .encode_memory_block_bytes_expression(address.clone(), whole_size.clone())?; + let read_element_byte = self.inner.inner.lowerer.encode_read_byte_expression( + element_bytes.clone(), + byte_index.clone().into(), + self.inner.inner.position, + )?; + let block_size = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + element_size.clone(), + self.inner.inner.position, + )?; + // let block_start_index = vir_low::Expression::multiply( + // block_size, + // index.clone().into(), + // ); + // let whole_byte_index = vir_low::Expression::add(block_start_index, byte_index.clone().into()); + let whole_byte_index = self.inner.inner.lowerer.create_domain_func_app( + "Arithmetic", + "mul_add", + vec![block_size, index.clone().into(), byte_index.clone().into()], + vir_low::Type::Int, + self.inner.inner.position, + )?; + let read_whole_byte = self.inner.inner.lowerer.encode_read_byte_expression( + vir_low::Expression::labelled_old(None, whole_bytes.clone(), self.inner.inner.position), + whole_byte_index, + self.inner.inner.position, + )?; + let element_size_int = self.inner.inner.lowerer.obtain_constant_value( + &size_type, + element_size.clone(), + self.inner.inner.position, + )?; + let body = expr!( + ((([start_index] <= index) && (index < [end_index])) && + (([0.into()] <= byte_index) && (byte_index < [element_size_int]))) ==> + ([read_element_byte.clone()] == [read_whole_byte]) + ); + let trigger = read_element_byte; + let expression = vir_low::Expression::forall( + vec![index, byte_index], + vec![vir_low::Trigger::new(vec![trigger])], + body, + ); + self.add_postcondition(expression); + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_split_join_common.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_split_join_common.rs new file mode 100644 index 00000000000..3137d88e6f1 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/memory_block_range_split_join_common.rs @@ -0,0 +1,258 @@ +use super::common::{BuiltinMethodBuilder, BuiltinMethodBuilderMethods}; +use crate::encoder::{ + errors::{BuiltinMethodKind, SpannedEncodingResult}, + middle::core_proof::{ + addresses::AddressesInterface, + lowerer::Lowerer, + predicates::PredicatesMemoryBlockInterface, + snapshots::{IntoSnapshot, SnapshotBytesInterface, SnapshotValuesInterface}, + type_layouts::TypeLayoutsInterface, + }, +}; +use vir_crate::{ + low::{self as vir_low}, + middle as vir_mid, +}; + +pub(in super::super) struct MemoryBlockRangeSplitJoinMethodBuilder<'l, 'p, 'v, 'tcx> { + pub(super) inner: BuiltinMethodBuilder<'l, 'p, 'v, 'tcx>, + pub(super) address: vir_low::VariableDecl, + pub(super) start_index: vir_low::VariableDecl, + pub(super) end_index: vir_low::VariableDecl, +} + +impl<'l, 'p, 'v, 'tcx> BuiltinMethodBuilderMethods<'l, 'p, 'v, 'tcx> + for MemoryBlockRangeSplitJoinMethodBuilder<'l, 'p, 'v, 'tcx> +{ + fn inner(&mut self) -> &mut BuiltinMethodBuilder<'l, 'p, 'v, 'tcx> { + &mut self.inner + } +} + +pub(in super::super) trait BuiltinMethodSplitJoinBuilderMethods<'l, 'p, 'v, 'tcx>: + Sized + BuiltinMethodBuilderMethods<'l, 'p, 'v, 'tcx> +where + 'p: 'l, + 'v: 'p, + 'tcx: 'v, +{ +} + +impl<'l, 'p, 'v, 'tcx> MemoryBlockRangeSplitJoinMethodBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + kind: vir_low::MethodKind, + method_name: &'l str, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + error_kind: BuiltinMethodKind, + ) -> SpannedEncodingResult { + let address = vir_low::VariableDecl::new("address", lowerer.address_type()?); + let size_type = lowerer.size_type()?; + let start_index = vir_low::VariableDecl::new("start_index", size_type.clone()); + let end_index = vir_low::VariableDecl::new("end_index", size_type.clone()); + let inner = + BuiltinMethodBuilder::new(lowerer, kind, method_name, ty, type_decl, error_kind)?; + Ok(Self { + inner, + address, + start_index, + end_index, + }) + } + + pub(in super::super) fn build(self) -> vir_low::MethodDecl { + self.inner.build() + } + + // pub(in super::super) fn address(&self) -> vir_low::Expression { + // self.address.clone().into() + // } + + pub(in super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.address.clone()); + self.inner.parameters.push(self.start_index.clone()); + self.inner.parameters.push(self.end_index.clone()); + Ok(()) + } + + pub(in super::super) fn length(&mut self) -> SpannedEncodingResult { + let size_type = self.inner.lowerer.size_type_mid()?; + self.inner.lowerer.construct_binary_op_snapshot( + vir_mid::BinaryOpKind::Sub, + &size_type, + &size_type, + self.end_index.clone().into(), + self.start_index.clone().into(), + self.inner.position, + ) + } + + // pub(in super::super) fn add_permission_amount_positive_precondition( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // let expression = self + // .inner + // .create_permission_amount_positive(&self.permission_amount)?; + // self.add_precondition(expression); + // Ok(()) + // } + + pub(in super::super) fn create_whole_block_acc( + &mut self, + ) -> SpannedEncodingResult { + // self.create_memory_block(self.address.clone().into()) + use vir_low::macros::*; + let length = self.length()?; + let inner = self.inner(); + let size_of = inner.lowerer.encode_type_size_expression_repetitions( + inner.ty, + inner.type_decl, + length, + inner.position, + )?; + let address = &self.address; + Ok(expr! { + acc(MemoryBlock(address, [size_of])) + }) + } + + pub(in super::super) fn create_memory_block_range_acc( + &mut self, + ) -> SpannedEncodingResult { + // self.create_memory_block(self.address.clone().into()) + let size_of = self + .inner + .lowerer + .encode_type_size_expression2(self.inner.ty, self.inner.type_decl)?; + self.inner.lowerer.encode_memory_block_range_acc( + self.address.clone().into(), + size_of, + self.start_index.clone().into(), + self.end_index.clone().into(), + self.inner.position, + ) + } + + // pub(in super::super) fn padding_size(&mut self) -> SpannedEncodingResult { + // self.inner + // .lowerer + // .encode_type_padding_size_expression(self.inner.ty) + // } + + // pub(in super::super) fn create_padding_memory_block_acc( + // &mut self, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let address = self.address.clone().into(); + // let padding_size = self.padding_size()?; + // let permission_amount = self.permission_amount.clone().into(); + // let expression = expr! { + // acc(MemoryBlock([address], [padding_size]), [permission_amount]) + // }; + // Ok(expression) + // } + + // pub(in super::super) fn create_field_memory_block_acc( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let field_address = self.inner.lowerer.encode_field_address( + // self.inner.ty, + // field, + // self.address.clone().into(), + // self.inner.position, + // )?; + // let field_size_of = self + // .inner + // .lowerer + // .encode_type_size_expression2(&field.ty, &field.ty)?; + // let permission_amount = self.permission_amount.clone().into(); + // let field_block = expr! { + // acc(MemoryBlock([field_address], [field_size_of]), [permission_amount]) + // }; + // Ok(field_block) + // } + + // pub(in super::super) fn create_discriminant_acc( + // &mut self, + // decl: &vir_mid::type_decl::Enum, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let discriminant_size_of = self + // .inner + // .lowerer + // .encode_type_size_expression2(&decl.discriminant_type, &decl.discriminant_type)?; + // let discriminant_field = decl.discriminant_field(); + // let discriminant_address = self.inner.lowerer.encode_field_address( + // self.inner.ty, + // &discriminant_field, + // self.address.clone().into(), + // self.inner.position, + // )?; + // let discriminant_block = expr! { + // acc(MemoryBlock([discriminant_address], [discriminant_size_of])) + // }; + // Ok(discriminant_block) + // } + + // pub(in super::super) fn create_variant_memory_block_acc( + // &mut self, + // discriminant_value: vir_mid::DiscriminantValue, + // variant: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let variant_index = variant.name.clone().into(); + // let variant_address = self.inner.lowerer.encode_enum_variant_address( + // self.inner.ty, + // &variant_index, + // self.address.clone().into(), + // Default::default(), + // )?; + // let variant_type = self.inner.ty.clone().variant(variant_index); + // let variant_size_of = self + // .inner + // .lowerer + // // .encode_type_size_expression(&variant_type)?; + // // FIXME: This is probably wrong: test enums containing arrays. + // .encode_type_size_expression2(&variant_type, &variant_type)?; + // let discriminant = self.discriminant.as_ref().unwrap().clone().into(); + // let expression = expr! { + // ([discriminant] == [discriminant_value.into()]) ==> + // (acc(MemoryBlock([variant_address], [variant_size_of]))) + // }; + // Ok(expression) + // } + + // pub(in super::super) fn create_field_to_bytes_equality( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let address = self.address(); + // let inner = self.inner(); + // inner.lowerer.encode_snapshot_to_bytes_function(inner.ty)?; + // let field_address = + // inner + // .lowerer + // .encode_field_address(inner.ty, field, address, inner.position)?; + // let field_size_of = inner + // .lowerer + // .encode_type_size_expression2(&field.ty, &field.ty)?; + // let memory_block_field_bytes = inner + // .lowerer + // .encode_memory_block_bytes_expression(field_address, field_size_of)?; + // let snapshot = var! { snapshot: {inner.ty.to_snapshot(inner.lowerer)?} }.into(); + // let field_snapshot = inner.lowerer.obtain_struct_field_snapshot( + // inner.ty, + // field, + // snapshot, + // inner.position, + // )?; + // let to_bytes = ty! { Bytes }; + // Ok(expr! { + // (([memory_block_field_bytes])) == (Snap<(&field.ty)>::to_bytes([field_snapshot])) + // }) + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/mod.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/mod.rs index 2dcd5ae34cc..6238132d5a6 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/mod.rs @@ -5,9 +5,13 @@ pub(super) mod duplicate_frac_ref; pub(super) mod memory_block_copy; pub(super) mod memory_block_into; pub(super) mod memory_block_join; +pub(super) mod memory_block_range_join; +pub(super) mod memory_block_range_split_join_common; pub(super) mod memory_block_split; +pub(super) mod memory_block_range_split; pub(super) mod memory_block_split_join_common; pub(super) mod move_copy_place_common; pub(super) mod move_place; +pub(super) mod restore_raw_borrowed; pub(super) mod write_address_constant; pub(super) mod write_place_constant; diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_copy_place_common.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_copy_place_common.rs index 55d21ef39af..816277d2511 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_copy_place_common.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_copy_place_common.rs @@ -9,7 +9,7 @@ use crate::encoder::{ builtin_methods::{calls::interface::CallContext, BuiltinMethodsInterface}, lowerer::Lowerer, places::PlacesInterface, - predicates::OwnedNonAliasedUseBuilder, + predicates::PredicatesOwnedInterface, snapshots::{IntoSnapshot, SnapshotValidityInterface}, }, }; @@ -78,37 +78,33 @@ impl<'l, 'p, 'v, 'tcx> MoveCopyPlaceMethodBuilder<'l, 'p, 'v, 'tcx> { pub(in super::super::super::super) fn create_source_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + self.inner.lowerer.owned_non_aliased_full_vars( CallContext::BuiltinMethod, self.inner.ty, self.inner.type_decl, - self.source_place.clone().into(), - self.source_root_address.clone().into(), - self.source_snapshot.clone().into(), - )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - Ok(builder.build()) + &self.source_place, + &self.source_root_address, + &self.source_snapshot, + must_be_predicate, + ) } // FIXME: Remove duplicates with other builders. pub(in super::super::super::super) fn create_target_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + self.inner.lowerer.owned_non_aliased_full_vars( CallContext::BuiltinMethod, self.inner.ty, self.inner.type_decl, - self.target_place.clone().into(), - self.target_root_address.clone().into(), - self.source_snapshot.clone().into(), - )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - Ok(builder.build()) + &self.target_place, + &self.target_root_address, + &self.source_snapshot, + must_be_predicate, + ) } // FIXME: Remove duplicate with add_source_validity_precondition diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_place.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_place.rs index 79778481385..dba455c109c 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_place.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/move_place.rs @@ -9,12 +9,16 @@ use crate::encoder::{ builtin_methods::{ calls::interface::CallContext, BuiltinMethodCallsInterface, BuiltinMethodsInterface, }, + lifetimes::LifetimesInterface, lowerer::Lowerer, places::PlacesInterface, + predicates::PredicatesOwnedInterface, + references::ReferencesInterface, snapshots::SnapshotValuesInterface, }, }; use vir_crate::{ + common::expression::UnaryOperationHelpers, low::{self as vir_low}, middle as vir_mid, }; @@ -97,14 +101,16 @@ impl<'l, 'p, 'v, 'tcx> MovePlaceMethodBuilder<'l, 'p, 'v, 'tcx> { pub(in super::super::super::super) fn create_source_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { - self.inner.create_source_owned() + self.inner.create_source_owned(must_be_predicate) } pub(in super::super::super::super) fn create_target_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { - self.inner.create_target_owned() + self.inner.create_target_owned(must_be_predicate) } pub(in super::super::super::super) fn add_target_validity_postcondition( @@ -342,6 +348,71 @@ impl<'l, 'p, 'v, 'tcx> MovePlaceMethodBuilder<'l, 'p, 'v, 'tcx> { Ok(()) } + pub(in super::super::super::super) fn add_dead_lifetime_hack( + &mut self, + lifetime: &vir_mid::ty::LifetimeConst, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let lifetime_alive = self + .inner + .inner + .lowerer + .encode_lifetime_const_into_pure_is_alive_variable(lifetime)?; + let guard = vir_low::Expression::not(lifetime_alive.into()); + let source_current_snapshot = self.inner.inner.lowerer.reference_target_current_snapshot( + self.inner.inner.ty, + self.inner.source_snapshot.clone().into(), + self.inner.inner.position, + )?; + let source_final_snapshot = self.inner.inner.lowerer.reference_target_final_snapshot( + self.inner.inner.ty, + self.inner.source_snapshot.clone().into(), + self.inner.inner.position, + )?; + let target_snapshot = self.inner.inner.lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + self.inner.inner.ty, + self.inner.inner.type_decl, + self.inner.target_place.clone().into(), + self.inner.target_root_address.clone().into(), + self.inner.inner.position, + )?; + let target_current_snapshot = self.inner.inner.lowerer.reference_target_current_snapshot( + self.inner.inner.ty, + target_snapshot.clone(), + self.inner.inner.position, + )?; + let target_final_snapshot = self.inner.inner.lowerer.reference_target_final_snapshot( + self.inner.inner.ty, + target_snapshot, + self.inner.inner.position, + )?; + let body = vec![ + vir_low::Statement::comment( + "FIXME: This is a hack. Because the lifetime is dead, the reference \ + is dangling and there is no predicate that would witness that \ + the value of the dereference is the source of the dereference. \ + This is also the reason why it is sound just to assume that the \ + two are equal. A proper solution should use a custom equality function \ + that equates the targets only if lifetimes are alive." + .to_string(), + ), + stmtp! { self.inner.inner.position => + assume ([source_current_snapshot] == [target_current_snapshot]) + }, + stmtp! { self.inner.inner.position => + assume ([source_final_snapshot] == [target_final_snapshot]) + }, + // assume destructor$Snap$ref$Unique$slice$struct$m_T1$$$target_current(source_snapshot) == destructor$Snap$ref$Unique$slice$struct$m_T1$$$target_current(snap_owned_non_aliased$ref$Unique$slice$struct$m_T1$(target_place, target_root_address, lft_early_bound_0$alive, lft_early_bound_0)) + + // assume destructor$Snap$ref$Unique$slice$struct$m_T1$$$target_final(source_snapshot) == destructor$Snap$ref$Unique$slice$struct$m_T1$$$target_final(snap_owned_non_aliased$ref$Unique$slice$struct$m_T1$(target_place, target_root_address, lft_early_bound_0$alive, lft_early_bound_0)) + ]; + let statement = + vir_low::Statement::conditional(guard, body, Vec::new(), self.inner.inner.position); + self.add_statement(statement); + Ok(()) + } + pub(in super::super::super::super) fn duplicate_frac_ref( &mut self, ) -> SpannedEncodingResult<()> { diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/restore_raw_borrowed.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/restore_raw_borrowed.rs new file mode 100644 index 00000000000..74ab9df2afa --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/restore_raw_borrowed.rs @@ -0,0 +1,124 @@ +use super::common::{BuiltinMethodBuilder, BuiltinMethodBuilderMethods}; +use crate::encoder::{ + errors::{BuiltinMethodKind, SpannedEncodingResult}, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lowerer::Lowerer, + places::PlacesInterface, + predicates::{PredicatesOwnedInterface, RestorationInterface}, + snapshots::IntoSnapshot, + }, +}; +use vir_crate::{ + low::{self as vir_low}, + middle as vir_mid, +}; + +pub(in super::super::super::super) struct RestoreRawBorrowedMethodBuilder<'l, 'p, 'v, 'tcx> { + inner: BuiltinMethodBuilder<'l, 'p, 'v, 'tcx>, + borrowing_address: vir_low::VariableDecl, + restored_place: vir_low::VariableDecl, + restored_root_address: vir_low::VariableDecl, + snapshot: vir_low::VariableDecl, +} + +impl<'l, 'p, 'v, 'tcx> BuiltinMethodBuilderMethods<'l, 'p, 'v, 'tcx> + for RestoreRawBorrowedMethodBuilder<'l, 'p, 'v, 'tcx> +{ + fn inner(&mut self) -> &mut BuiltinMethodBuilder<'l, 'p, 'v, 'tcx> { + &mut self.inner + } +} + +impl<'l, 'p, 'v, 'tcx> RestoreRawBorrowedMethodBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + kind: vir_low::MethodKind, + method_name: &'l str, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + error_kind: BuiltinMethodKind, + ) -> SpannedEncodingResult { + let borrowing_address = + vir_low::VariableDecl::new("borrowing_address", lowerer.address_type()?); + let restored_place = vir_low::VariableDecl::new("restored_place", lowerer.place_type()?); + let restored_root_address = + vir_low::VariableDecl::new("restored_root_address", lowerer.address_type()?); + let snapshot = vir_low::VariableDecl::new("snapshot", ty.to_snapshot(lowerer)?); + let inner = + BuiltinMethodBuilder::new(lowerer, kind, method_name, ty, type_decl, error_kind)?; + Ok(Self { + inner, + borrowing_address, + restored_place, + restored_root_address, + snapshot, + }) + } + + pub(in super::super::super::super) fn build(self) -> vir_low::MethodDecl { + self.inner.build() + } + + pub(in super::super::super::super) fn create_parameters( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.borrowing_address.clone()); + self.inner.parameters.push(self.restored_place.clone()); + self.inner + .parameters + .push(self.restored_root_address.clone()); + self.inner.parameters.push(self.snapshot.clone()); + self.create_lifetime_parameters()?; + self.create_const_parameters()?; + Ok(()) + } + + pub(in super::super::super::super) fn add_aliased_source_precondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let aliased_root_place = self + .inner + .lowerer + .encode_aliased_place_root(self.inner.position)?; + unimplemented!(); + // let aliased_predicate = self.inner.lowerer.owned_aliased( + // CallContext::BuiltinMethod, + // self.inner.ty, + // self.inner.ty, + // aliased_root_place, + // self.borrowing_address.clone().into(), + // self.snapshot.clone().into(), + // None, + // )?; + // self.add_precondition(aliased_predicate); + Ok(()) + } + + pub(in super::super::super::super) fn add_shift_precondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let restore_raw_borrowed = self.inner.lowerer.restore_raw_borrowed( + self.inner.ty, + self.restored_place.clone().into(), + self.restored_root_address.clone().into(), + )?; + self.add_precondition(restore_raw_borrowed); + Ok(()) + } + + pub(crate) fn add_non_aliased_target_postcondition(&mut self) -> SpannedEncodingResult<()> { + let non_aliased_predicate = self.inner.lowerer.owned_non_aliased_full_vars( + CallContext::BuiltinMethod, + self.inner.ty, + self.inner.ty, + &self.restored_place, + &self.restored_root_address, + &self.snapshot, + false, + )?; + self.add_postcondition(non_aliased_predicate); + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/write_place_constant.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/write_place_constant.rs index 7d0a79c8a42..a092dd03954 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/write_place_constant.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/decls/write_place_constant.rs @@ -9,7 +9,7 @@ use crate::encoder::{ builtin_methods::{BuiltinMethodsInterface, CallContext}, lowerer::Lowerer, places::PlacesInterface, - predicates::{OwnedNonAliasedUseBuilder, PredicatesOwnedInterface}, + predicates::PredicatesOwnedInterface, snapshots::{IntoSnapshot, SnapshotValidityInterface, SnapshotValuesInterface}, }, }; @@ -91,22 +91,20 @@ impl<'l, 'p, 'v, 'tcx> WritePlaceConstantMethodBuilder<'l, 'p, 'v, 'tcx> { // FIXME: Remove duplicates with other builders. pub(in super::super::super::super) fn create_target_owned( &mut self, + must_be_predicate: bool, ) -> SpannedEncodingResult { self.inner .lowerer .mark_owned_non_aliased_as_unfolded(self.inner.ty)?; - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + self.inner.lowerer.owned_non_aliased_full_vars( CallContext::BuiltinMethod, self.inner.ty, self.inner.type_decl, - self.target_place.clone().into(), - self.target_root_address.clone().into(), - self.source_snapshot.clone().into(), - )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - Ok(builder.build()) + &self.target_place, + &self.target_root_address, + &self.source_snapshot, + must_be_predicate, + ) } pub(in super::super::super::super) fn add_source_validity_precondition( diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/mod.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/mod.rs index 5152118b533..acde36ccc11 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/builders/mod.rs @@ -8,7 +8,10 @@ pub(in super::super) use self::decls::{ memory_block_copy::MemoryBlockCopyMethodBuilder, memory_block_into::IntoMemoryBlockMethodBuilder, memory_block_join::MemoryBlockJoinMethodBuilder, + memory_block_range_join::MemoryBlockRangeJoinMethodBuilder, + memory_block_range_split::MemoryBlockRangeSplitMethodBuilder, memory_block_split::MemoryBlockSplitMethodBuilder, move_place::MovePlaceMethodBuilder, + restore_raw_borrowed::RestoreRawBorrowedMethodBuilder, write_address_constant::WriteAddressConstantMethodBuilder, write_place_constant::WritePlaceConstantMethodBuilder, }; diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/calls/interface.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/calls/interface.rs index 7b9234efc8d..53eec156539 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/calls/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/calls/interface.rs @@ -98,6 +98,21 @@ pub(in super::super::super) trait BuiltinMethodCallsInterface { ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments; + + #[allow(clippy::too_many_arguments)] + fn call_restore_raw_borrowed_method( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + position: vir_low::Position, + borrowing_address: vir_low::Expression, + restored_place: vir_low::Expression, + restored_root_address: vir_low::Expression, + snapshot: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; } impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodCallsInterface for Lowerer<'p, 'v, 'tcx> { @@ -249,4 +264,33 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodCallsInterface for Lowerer<'p, 'v, 'tcx> builder.add_const_arguments()?; Ok(builder.build()) } + + fn call_restore_raw_borrowed_method( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + position: vir_low::Position, + borrowing_address: vir_low::Expression, + restored_place: vir_low::Expression, + restored_root_address: vir_low::Expression, + snapshot: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = BuiltinMethodCallBuilder::new( + self, + context, + "restore_raw_borrowed", + ty, + generics, + position, + )?; + builder.add_argument(borrowing_address); + builder.add_argument(restored_place); + builder.add_argument(restored_root_address); + builder.add_argument(snapshot); + Ok(builder.build()) + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/interface.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/interface.rs index 30f6a32d4d0..78f7b4ae353 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/interface.rs @@ -1,7 +1,7 @@ use super::{ builders::{ ChangeUniqueRefPlaceMethodBuilder, DuplicateFracRefMethodBuilder, - MemoryBlockCopyMethodBuilder, + MemoryBlockCopyMethodBuilder, RestoreRawBorrowedMethodBuilder, }, BuiltinMethodCallsInterface, CallContext, }; @@ -13,25 +13,33 @@ use crate::encoder::{ block_markers::BlockMarkersInterface, builtin_methods::builders::{ BuiltinMethodBuilderMethods, CopyPlaceMethodBuilder, IntoMemoryBlockMethodBuilder, - MemoryBlockJoinMethodBuilder, MemoryBlockSplitMethodBuilder, MovePlaceMethodBuilder, - WriteAddressConstantMethodBuilder, WritePlaceConstantMethodBuilder, + MemoryBlockJoinMethodBuilder, MemoryBlockRangeJoinMethodBuilder, + MemoryBlockRangeSplitMethodBuilder, MemoryBlockSplitMethodBuilder, + MovePlaceMethodBuilder, WriteAddressConstantMethodBuilder, + WritePlaceConstantMethodBuilder, }, compute_address::ComputeAddressInterface, errors::ErrorsInterface, + footprint::FootprintInterface, + heap::HeapInterface, lifetimes::LifetimesInterface, lowerer::{ DomainsLowererInterface, Lowerer, MethodsLowererInterface, PredicatesLowererInterface, VariablesLowererInterface, }, places::PlacesInterface, + pointers::PointersInterface, predicates::{ - OwnedNonAliasedUseBuilder, PredicatesMemoryBlockInterface, PredicatesOwnedInterface, + OwnedNonAliasedSnapCallBuilder, OwnedNonAliasedUseBuilder, + PredicatesMemoryBlockInterface, PredicatesOwnedInterface, RestorationInterface, }, references::ReferencesInterface, snapshots::{ - BuiltinFunctionsInterface, IntoBuiltinMethodSnapshot, IntoProcedureFinalSnapshot, - IntoProcedureSnapshot, IntoPureSnapshot, IntoSnapshot, SnapshotBytesInterface, - SnapshotValidityInterface, SnapshotValuesInterface, SnapshotVariablesInterface, + AssertionToSnapshotConstructor, BuiltinFunctionsInterface, IntoBuiltinMethodSnapshot, + IntoProcedureFinalSnapshot, IntoProcedureSnapshot, IntoPureSnapshot, IntoSnapshot, + IntoSnapshotLowerer, PredicateKind, SelfFramingAssertionToSnapshot, + SnapshotBytesInterface, SnapshotValidityInterface, SnapshotValuesInterface, + SnapshotVariablesInterface, }, type_layouts::TypeLayoutsInterface, }, @@ -40,7 +48,10 @@ use itertools::Itertools; use rustc_hash::FxHashSet; use vir_crate::{ common::{ - expression::{ExpressionIterator, UnaryOperationHelpers}, + check_mode::CheckMode, + expression::{ + BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers, UnaryOperationHelpers, + }, identifier::WithIdentifier, }, low::{self as vir_low, macros::method_name}, @@ -50,6 +61,10 @@ use vir_crate::{ }, }; +// FIXME: Move this to some proper place. It is shared with the snap function +// encoder. +pub(in super::super) use super::assertion_encoder::AssertionEncoder; + #[derive(Default)] pub(in super::super) struct BuiltinMethodsState { encoded_write_place_constant_methods: FxHashSet, @@ -61,7 +76,9 @@ pub(in super::super) struct BuiltinMethodsState { encoded_owned_non_aliased_havoc_methods: FxHashSet, encoded_memory_block_copy_methods: FxHashSet, encoded_memory_block_split_methods: FxHashSet, + encoded_memory_block_range_split_methods: FxHashSet, encoded_memory_block_join_methods: FxHashSet, + encoded_memory_block_range_join_methods: FxHashSet, encoded_memory_block_havoc_methods: FxHashSet, encoded_into_memory_block_methods: FxHashSet, encoded_assign_methods: FxHashSet, @@ -73,7 +90,9 @@ pub(in super::super) struct BuiltinMethodsState { encoded_lft_tok_sep_take_methods: FxHashSet, encoded_lft_tok_sep_return_methods: FxHashSet, encoded_open_close_mut_ref_methods: FxHashSet, + encoded_restore_raw_borrowed_methods: FxHashSet, encoded_bor_shorten_methods: FxHashSet, + encoded_stashed_owned_aliased_predicates: FxHashSet, } trait Private { @@ -232,10 +251,46 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { let perm_amount = value .lifetime_token_permission .to_procedure_snapshot(self)?; - self.encode_place_arguments(arguments, &value.deref_place, false)?; - if value.uniqueness.is_unique() { - let snapshot_final = value.deref_place.to_procedure_final_snapshot(self)?; - arguments.push(snapshot_final); + arguments.push(self.encode_expression_as_place(&value.deref_place)?); + arguments.push(self.extract_root_address(&value.deref_place)?); + // self.encode_place_arguments(arguments, &value.deref_place, false)?; + if self.check_mode.unwrap() == CheckMode::PurificationFunctional { + arguments.push(value.deref_place.to_procedure_snapshot(self)?); + } else { + let place = self.encode_expression_as_place(&value.deref_place)?; + let root_address = self.extract_root_address(&value.deref_place)?; + let ty = value.deref_place.get_type(); + let TODO_target_slice_len = None; + match value.uniqueness { + vir_mid::ty::Uniqueness::Unique => { + let snapshot_current = self.unique_ref_snap( + CallContext::Procedure, + ty, + ty, + place, + root_address, + deref_lifetime.clone().into(), + TODO_target_slice_len, + false, + )?; + arguments.push(snapshot_current); + let snapshot_final = + value.deref_place.to_procedure_final_snapshot(self)?; + arguments.push(snapshot_final); + } + vir_mid::ty::Uniqueness::Shared => { + let snapshot_current = self.frac_ref_snap( + CallContext::Procedure, + ty, + ty, + place, + root_address, + deref_lifetime.clone().into(), + TODO_target_slice_len, + )?; + arguments.push(snapshot_current); + } + } } arguments.extend(self.create_lifetime_arguments( CallContext::Procedure, @@ -279,6 +334,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { }; arguments.push(len); } + vir_mid::Rvalue::Cast(value) => { + self.encode_operand_arguments(arguments, &value.operand, true)?; + } vir_mid::Rvalue::UnaryOp(value) => { self.encode_operand_arguments(arguments, &value.argument, true)?; } @@ -312,6 +370,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { for operand in &aggr_value.operands { self.encode_operand_arguments(arguments, operand, false)?; } + if self.use_heap_variable()? && aggr_value.ty.is_struct() { + let heap = self.heap_variable_version_at_label(&None)?; + arguments.push(heap.into()); + } } } Ok(()) @@ -515,6 +577,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { &target_place, &target_address, &result_value, + false, )?; posts.push(predicate); self.encode_assign_method_rvalue( @@ -627,19 +690,37 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { operand_address: Address, operand_value: { ty.to_snapshot(self)? } }; - let predicate = self.owned_non_aliased_full_vars( + let non_aliased_predicate = self.owned_non_aliased_full_vars( CallContext::BuiltinMethod, ty, ty, &operand_place, &operand_address, &operand_value, + false, )?; let compute_address = ty!(Address); let address = expr! { ComputeAddress::compute_address(operand_place, operand_address) }; - pres.push(predicate.clone()); - posts.push(predicate); + pres.push(non_aliased_predicate); + let aliased_root_place = self.encode_aliased_place_root(position)?; + let aliased_predicate: vir_low::Expression = unimplemented!(); + // let aliased_predicate = self.owned_aliased( + // CallContext::BuiltinMethod, + // ty, + // ty, + // aliased_root_place, + // address.clone(), + // operand_value.clone().into(), + // None, + // )?; + let restore_raw_borrowed = self.restore_raw_borrowed( + ty, + operand_place.clone().into(), + operand_address.clone().into(), + )?; + posts.push(aliased_predicate); + posts.push(restore_raw_borrowed); parameters.push(operand_place); parameters.push(operand_address); parameters.push(operand_value); @@ -650,6 +731,26 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { parameters.push(length.clone()); length.into() } + vir_mid::Rvalue::Cast(value) => { + let source_type = value.operand.expression.get_type(); + match (&value.ty, source_type) { + (vir_mid::Type::Pointer(_), vir_mid::Type::Pointer(_)) => { + let operand_value = self.encode_assign_operand( + parameters, + pres, + posts, + 1, + &value.operand, + position, + true, + )?; + let address = + self.pointer_address(source_type, operand_value.into(), position)?; + self.construct_constant_snapshot(result_type, address, position)? + } + (t, s) => unimplemented!("({t}) {s}"), + } + } vir_mid::Rvalue::UnaryOp(value) => { let operand_value = self.encode_assign_operand( parameters, @@ -775,7 +876,167 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { self.construct_enum_snapshot(&value.ty, variant_constructor, position)? } vir_mid::Type::Struct(_) => { - self.construct_struct_snapshot(&value.ty, arguments, position)? + // FIXME: Code duplication with encode_owned_non_aliased_snapshot. + let heap = if self.use_heap_variable()? { + let heap_name = self.heap_variable_name()?; + let heap = vir_low::VariableDecl::new(heap_name, self.heap_type()?); + parameters.push(heap.clone()); + Some(heap) + } else { + None + }; + // // FIXME: Do not hardcode variables here. + // // Instead pass in the original ones. + // var_decls!(target_place: Place, target_address: Address); + let decl = self.encoder.get_type_decl_mid(&value.ty)?.unwrap_struct(); + if let Some(invariant) = decl.structural_invariant { + assert_eq!(arguments.len(), decl.fields.len()); + // Assert the invariant for the struct in the precondition. + let mut invariant_encoder = + SelfFramingAssertionToSnapshot::for_assign_precondition( + arguments.clone(), + decl.fields.clone(), + heap.clone(), + ); + for assertion in &invariant { + let encoded_assertion = invariant_encoder + .expression_to_snapshot(self, assertion, true)?; + pres.push(encoded_assertion); + } + // Create the snapshot constructor. + let deref_fields = + self.structural_invariant_to_deref_fields(&invariant)?; + let mut constructor_encoder = + AssertionToSnapshotConstructor::for_assign_aggregate_postcondition( + result_type, + arguments, + decl.fields, + deref_fields, + heap, + position, + ); + let invariant_expression = invariant.into_iter().conjoin(); + let permission_expression = + invariant_expression.convert_into_permission_expression(); + constructor_encoder + .expression_to_snapshot_constructor(self, &permission_expression)? + } else { + self.construct_struct_snapshot(&value.ty, arguments, position)? + } + // if let Some(invariant) = &decl.structural_invariant { + // let mut assertion_encoder = + // AssertionEncoder::new(&decl, arguments.clone(), &heap); + // if self.use_heap_variable()? { + // // TODO: Add a postcondition that is the + // // structural invariant that equates snap + // // calls from the input predicates to + // // projections on `result_value`. + // let deref_fields = self.structural_invariant_to_deref_fields( + // invariant.clone(), + // &value.ty, + // &decl, + // )?; + // for deref in deref_fields { + // let base_snapshot = assertion_encoder.expression_to_snapshot( + // self, + // &deref.base, + // true, + // )?; + // let argument = if let Some(heap) = heap.as_ref() { + // let in_heap = + // assertion_encoder.address_in_heap(self, &deref.base)?; + // // let in_heap = self.address_in_heap(heap.clone(), &deref.base)?; + // pres.push(in_heap); + // self.pointer_target_snapshot_in_heap( + // deref.base.get_type(), + // heap.clone(), + // base_snapshot, + // position, + // )? + // } else { + // // let deref_value = assertion_encoder + // // .expression_to_snapshot(self, &deref.base, true)?; + // let address = self.pointer_address( + // deref.base.get_type(), + // base_snapshot, + // position, + // )?; + // let ty = &deref.ty; + // self.owned_aliased_snap( + // CallContext::BuiltinMethod, + // ty, + // ty, + // address, + // position, + // )? + // }; + // arguments.push(argument); + // // if self.use_heap_variable()? { + // // let in_heap = + // // assertion_encoder.address_in_heap(self, &deref.base)?; + // // // let in_heap = self.address_in_heap(heap.clone(), &deref.base)?; + // // pres.push(in_heap); + // // } + // } + // } else { + // assert_eq!(arguments.len(), decl.fields.len()); + // for (argument, field) in arguments.iter().zip(decl.fields.iter()) { + // let field_snapshot = self.obtain_struct_field_snapshot( + // result_type, + // field, + // result_value.clone().into(), + // position, + // )?; + // posts.push(vir_low::Expression::equals( + // argument.clone(), + // field_snapshot, + // )) + // } + // let predicate = self + // .owned_non_aliased_predicate( + // CallContext::BuiltinMethod, + // result_type, + // result_type, + // target_place.clone().into(), + // target_address.clone().into(), + // true.into(), + // None, + // )? + // .unwrap_predicate_access_predicate(); + // assertion_encoder.set_result_value(result_value.clone()); + // for assertion in invariant { + // let low_assertion = assertion_encoder + // .expression_to_snapshot(self, assertion, true)?; + // posts.push(vir_low::Expression::unfolding( + // predicate.clone(), + // low_assertion, + // position, + // )); + // } + // assertion_encoder.unset_result_value(); + // } + // for assertion in invariant { + // let low_assertion = assertion_encoder + // .expression_to_snapshot(self, assertion, true)?; + // pres.push(low_assertion); + // } + // } + + // if self.use_heap_variable()? { + // parameters.push(heap.clone().unwrap()); + // // TODO: add the invariant to the precondition with + // // deref fields being the parameters of this method + // self.construct_struct_snapshot(&value.ty, arguments, position)? + // } else { + // self.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // &value.ty, + // &value.ty, + // target_place.into(), + // target_address.into(), + // position, + // )? + // } } vir_mid::Type::Array(value_ty) => vir_low::Expression::container_op( vir_low::ContainerOpKind::SeqConstructor, @@ -785,9 +1046,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { ), ty => unimplemented!("{}", ty), }; - posts.push( - self.encode_snapshot_valid_call_for_type(assigned_value.clone(), result_type)?, - ); + // posts.push( + // self.encode_snapshot_valid_call_for_type(assigned_value.clone(), result_type)?, + // ); assigned_value } }; @@ -854,9 +1115,15 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { posts.push( expr! { acc(MemoryBlock([operation_result_address.clone()], [size_of_result.clone()])) }, ); - posts.push( - expr! { acc(OwnedNonAliased([flag_place], target_address, [flag_value.clone()])) }, - ); + posts.push(self.owned_non_aliased( + CallContext::BuiltinMethod, + flag_type, + flag_type, + flag_place, + target_address.into(), + flag_value.clone(), + None, + )?); let operand_left = self.encode_assign_operand(parameters, pres, posts, 1, &value.left, position, true)?; let operand_right = @@ -930,6 +1197,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { &operand_snapshot_current, &operand_snapshot_final, &deref_lifetime, + None, )? } else { self.frac_ref_full_vars( @@ -947,20 +1215,36 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { let new_reference_predicate = { self.mark_owned_non_aliased_as_unfolded(result_type)?; let mut builder = OwnedNonAliasedUseBuilder::new( + self, + CallContext::BuiltinMethod, + result_type, + ty, + target_place.clone().into(), + target_address.clone().into(), + )?; + // builder.add_snapshot_argument(result_value.clone().into())?; + builder.add_custom_argument(true.into())?; + builder.add_custom_argument(new_borrow_lifetime.clone().into())?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + let predicate = builder.build()?; + let mut builder = OwnedNonAliasedSnapCallBuilder::new( self, CallContext::BuiltinMethod, result_type, ty, target_place.into(), target_address.into(), - result_value.clone().into(), + position, )?; builder.add_custom_argument(true.into())?; builder.add_custom_argument(new_borrow_lifetime.clone().into())?; builder.add_lifetime_arguments()?; builder.add_const_arguments()?; - builder.build() + let snapshot = builder.build()?; + expr! { [predicate] && (result_value == [snapshot]) } }; + eprintln!("new_reference_predicate: {new_reference_predicate}"); let restoration = { let final_snapshot = self.reference_target_final_snapshot( result_type, @@ -1080,23 +1364,39 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { &operand_place, &operand_root_address, &operand_snapshot, + false, )?; let reference_predicate = { self.mark_owned_non_aliased_as_unfolded(result_type)?; let mut builder = OwnedNonAliasedUseBuilder::new( + self, + CallContext::BuiltinMethod, + result_type, + ty, + target_place.clone().into(), + target_address.clone().into(), + )?; + // builder.add_snapshot_argument(result_value.clone().into())?; + builder.add_custom_argument(true.into())?; + builder.add_custom_argument(new_borrow_lifetime.clone().into())?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + let predicate = builder.build()?; + let mut builder = OwnedNonAliasedSnapCallBuilder::new( self, CallContext::BuiltinMethod, result_type, ty, target_place.into(), target_address.into(), - result_value.clone().into(), + position, )?; builder.add_custom_argument(true.into())?; builder.add_custom_argument(new_borrow_lifetime.clone().into())?; builder.add_lifetime_arguments()?; builder.add_const_arguments()?; - builder.build() + let snapshot = builder.build()?; + expr! { [predicate] && (result_value == [snapshot]) } }; let restoration = { let restoration_snapshot = if value.uniqueness.is_unique() { @@ -1209,6 +1509,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { &place, &root_address, &snapshot, + false, )?; pres.push(predicate.clone()); let post_predicate = if operand.kind == vir_mid::OperandKind::Copy { @@ -1270,7 +1571,15 @@ pub(in super::super) trait BuiltinMethodsInterface { fn encode_memory_block_copy_method(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; fn encode_memory_block_split_method(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; + fn encode_memory_block_range_split_method( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()>; fn encode_memory_block_join_method(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; + fn encode_memory_block_range_join_method( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()>; fn encode_move_place_method(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; fn encode_change_unique_ref_place_method( @@ -1325,6 +1634,10 @@ pub(in super::super) trait BuiltinMethodsInterface { predicate: vir_mid::VariableDecl, position: vir_low::Position, ) -> SpannedEncodingResult<()>; + fn encode_restore_raw_borrowed_method( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()>; fn encode_open_frac_bor_atomic_method( &mut self, ty: &vir_mid::Type, @@ -1342,6 +1655,29 @@ pub(in super::super) trait BuiltinMethodsInterface { &mut self, ty_with_lifetime: &vir_mid::Type, ) -> SpannedEncodingResult<()>; + fn encode_stash_range_call( + &mut self, + statements: &mut Vec, + ty: &vir_mid::Type, + pointer_value: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + label: String, + position: vir_low::Position, + ) -> SpannedEncodingResult<()>; + + fn encode_restore_stash_range_call( + &mut self, + statements: &mut Vec, + ty: &vir_mid::Type, + old_pointer_value: vir_low::Expression, + old_start_index: vir_low::Expression, + old_end_index: vir_low::Expression, + label: String, + new_address: vir_low::Expression, + new_start_index: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult<()>; } impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { @@ -1451,9 +1787,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { ); let target_memory_block = builder.create_target_memory_block()?; builder.add_precondition(target_memory_block); - let source_owned = builder.create_source_owned()?; + let source_owned = builder.create_source_owned(false)?; builder.add_precondition(source_owned); - let target_owned = builder.create_target_owned()?; + let target_owned = builder.create_target_owned(false)?; builder.add_postcondition(target_owned); let source_memory_block = builder.create_source_memory_block()?; builder.add_postcondition(source_memory_block); @@ -1461,7 +1797,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { builder.add_target_validity_postcondition()?; if has_body { builder.create_body(); - let source_owned = builder.create_source_owned()?; + let source_owned = builder.create_source_owned(true)?; builder.add_statement(vir_low::Statement::unfold_no_pos(source_owned)); } match &type_decl { @@ -1511,8 +1847,20 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { _ => unimplemented!("{type_decl:?}"), } if has_body { - let target_owned = builder.create_target_owned()?; + let target_owned = builder.create_target_owned(true)?; builder.add_statement(vir_low::Statement::fold_no_pos(target_owned)); + if let vir_mid::TypeDecl::Reference(vir_mid::type_decl::Reference { + uniqueness, + lifetimes, + .. + }) = &type_decl + { + if uniqueness.is_unique() { + // FIXME: Have a getter for the first lifetime. + let lifetime = &lifetimes[0]; + builder.add_dead_lifetime_hack(lifetime)?; + } + } } let method = builder.build(); self.declare_method(method)?; @@ -1607,7 +1955,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { "copy_place", &normalized_type, &type_decl, - BuiltinMethodKind::MovePlace, + BuiltinMethodKind::CopyPlace, )?; builder.create_parameters()?; // FIXME: To generate body for arrays, we would need to generate a @@ -1624,12 +1972,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { let source_owned = builder.create_source_owned()?; builder.add_precondition(source_owned.clone()); builder.add_postcondition(source_owned); - let target_owned = builder.create_target_owned()?; + let target_owned = builder.create_target_owned(false)?; builder.add_postcondition(target_owned); builder.add_target_validity_postcondition()?; if has_body { builder.create_body(); - let source_owned = builder.create_source_owned()?; + let source_owned = builder.create_source_owned_predicate()?; builder.add_statement(vir_low::Statement::unfold_no_pos(source_owned)); } match &type_decl { @@ -1650,9 +1998,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { _ => unimplemented!("{type_decl:?}"), } if has_body { - let target_owned = builder.create_target_owned()?; + let target_owned = builder.create_target_owned(true)?; builder.add_statement(vir_low::Statement::fold_no_pos(target_owned)); - let source_owned = builder.create_source_owned()?; + let source_owned = builder.create_source_owned_predicate()?; builder.add_statement(vir_low::Statement::fold_no_pos(source_owned)); } let method = builder.build(); @@ -1698,7 +2046,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { let target_memory_block = builder.create_target_memory_block()?; builder.add_precondition(target_memory_block); builder.add_source_validity_precondition()?; - let target_owned = builder.create_target_owned()?; + let target_owned = builder.create_target_owned(false)?; builder.add_postcondition(target_owned); if has_body { builder.create_body(); @@ -1729,7 +2077,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { _ => unimplemented!("{type_decl:?}"), } if has_body { - let target_owned = builder.create_target_owned()?; + let target_owned = builder.create_target_owned(true)?; builder.add_statement(vir_low::Statement::fold_no_pos(target_owned)); } let method = builder.build(); @@ -1770,6 +2118,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { &place, &root_address, &old_snapshot, + false, )?; let predicate_out = self.owned_non_aliased_full_vars( CallContext::BuiltinMethod, @@ -1778,6 +2127,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { &place, &root_address, &fresh_snapshot, + false, )?; let method = vir_low::MethodDecl::new( method_name, @@ -1810,6 +2160,8 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { .encoded_memory_block_split_methods .insert(ty_identifier); + self.encode_compute_address(ty)?; + let type_decl = self.encoder.get_type_decl_mid(ty)?; let normalized_type = ty.normalize_type(); let mut builder = MemoryBlockSplitMethodBuilder::new( @@ -1851,6 +2203,44 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { } Ok(()) } + fn encode_memory_block_range_split_method( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()> { + let ty_identifier = ty.get_identifier(); + if !self + .builtin_methods_state + .encoded_memory_block_range_split_methods + .contains(&ty_identifier) + { + assert!( + !ty.is_trusted() && !ty.is_type_var(), + "Trying to split an abstract type." + ); + self.builtin_methods_state + .encoded_memory_block_range_split_methods + .insert(ty_identifier); + + self.encode_compute_address(ty)?; + let type_decl = self.encoder.get_type_decl_mid(ty)?; + let normalized_type = ty.normalize_type(); + let mut builder = MemoryBlockRangeSplitMethodBuilder::new( + self, + vir_low::MethodKind::LowMemoryOperation, + "memory_block_range_split", + &normalized_type, + &type_decl, + BuiltinMethodKind::JoinMemoryBlock, + )?; + builder.create_parameters()?; + builder.add_whole_memory_block_precondition()?; + builder.add_memory_block_range_postcondition()?; + builder.add_byte_values_preserved_postcondition()?; + let method = builder.build(); + self.declare_method(method)?; + } + Ok(()) + } fn encode_memory_block_join_method(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()> { let ty_identifier = ty.get_identifier(); if !self @@ -1866,6 +2256,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { .encoded_memory_block_join_methods .insert(ty_identifier); + self.encode_compute_address(ty)?; let type_decl = self.encoder.get_type_decl_mid(ty)?; let normalized_type = ty.normalize_type(); let mut builder = MemoryBlockJoinMethodBuilder::new( @@ -1919,6 +2310,44 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { } Ok(()) } + fn encode_memory_block_range_join_method( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()> { + let ty_identifier = ty.get_identifier(); + if !self + .builtin_methods_state + .encoded_memory_block_range_join_methods + .contains(&ty_identifier) + { + assert!( + !ty.is_trusted() && !ty.is_type_var(), + "Trying to join an abstract type." + ); + self.builtin_methods_state + .encoded_memory_block_range_join_methods + .insert(ty_identifier); + + self.encode_compute_address(ty)?; + let type_decl = self.encoder.get_type_decl_mid(ty)?; + let normalized_type = ty.normalize_type(); + let mut builder = MemoryBlockRangeJoinMethodBuilder::new( + self, + vir_low::MethodKind::LowMemoryOperation, + "memory_block_range_join", + &normalized_type, + &type_decl, + BuiltinMethodKind::JoinMemoryBlock, + )?; + builder.create_parameters()?; + builder.add_memory_block_range_precondition()?; + builder.add_whole_memory_block_postcondition()?; + builder.add_byte_values_preserved_postcondition()?; + let method = builder.build(); + self.declare_method(method)?; + } + Ok(()) + } fn encode_havoc_memory_block_method_name( &mut self, ty: &vir_mid::Type, @@ -1996,7 +2425,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { )?; builder.create_parameters()?; builder.add_const_parameters_validity_precondition()?; - let predicate = builder.create_owned()?; + let predicate = builder.create_owned(false)?; builder.add_precondition(predicate); let memory_block = builder.create_target_memory_block()?; builder.add_postcondition(memory_block); @@ -2011,7 +2440,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { ); if has_body { builder.create_body(); - let predicate = builder.create_owned()?; + let predicate = builder.create_owned(true)?; builder.add_statement(vir_low::Statement::unfold_no_pos(predicate)); } match &type_decl { @@ -2108,13 +2537,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { source_snapshot.clone(), source_permission_amount, )?]; - let new_snapshot = self.new_snapshot_variable_version(&target.get_base(), position)?; - self.encode_snapshot_update_with_new_snapshot( + let new_snapshot = self.encode_snapshot_update_with_new_snapshot( &mut copy_place_statements, &target, source_snapshot, position, - Some(new_snapshot.clone()), + // Some(new_snapshot.clone()), )?; if let Some(conditions) = value.use_field { let mut disjuncts = Vec::new(); @@ -2122,12 +2550,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { disjuncts.push(self.lower_block_marker_condition(condition)?); } let mut else_branch = vec![assign_statement]; - self.encode_snapshot_update_with_new_snapshot( + let new_snapshot_else_branch = self.encode_snapshot_update_with_new_snapshot( &mut else_branch, &target, result_value.into(), position, - Some(new_snapshot), + // Some(new_snapshot), )?; statements.push(vir_low::Statement::conditional( disjuncts.into_iter().disjoin(), @@ -2135,6 +2563,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { else_branch, position, )); + statements.push(vir_low::Statement::assume( + vir_low::Expression::equals(new_snapshot, new_snapshot_else_branch), + position, + )); } else { // Use field unconditionally. statements.extend(copy_place_statements); @@ -2147,21 +2579,69 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { result_value.clone().into(), position, )?; - if let vir_mid::Rvalue::Ref(value) = value { - let snapshot = if value.uniqueness.is_unique() { - self.reference_target_final_snapshot( - target.get_type(), - result_value.into(), - position, - )? - } else { - self.reference_target_current_snapshot( - target.get_type(), - result_value.into(), - position, - )? - }; - self.encode_snapshot_update(statements, &value.place, snapshot, position)?; + // if let vir_mid::Rvalue::Ref(value) = value { + // let snapshot = if value.uniqueness.is_unique() { + // self.reference_target_final_snapshot( + // target.get_type(), + // result_value.into(), + // position, + // )? + // } else { + // self.reference_target_current_snapshot( + // target.get_type(), + // result_value.into(), + // position, + // )? + // }; + // self.encode_snapshot_update(statements, &value.place, snapshot, position)?; + // } + match value { + vir_mid::Rvalue::Ref(value) => { + let snapshot = if value.uniqueness.is_unique() { + self.reference_target_final_snapshot( + target.get_type(), + result_value.into(), + position, + )? + } else { + self.reference_target_current_snapshot( + target.get_type(), + result_value.into(), + position, + )? + }; + self.encode_snapshot_update(statements, &value.place, snapshot, position)?; + } + // vir_mid::Rvalue::AddressOf(value) => { + // let address = self.pointer_address( + // target.get_type(), + // result_value.clone().into(), + // position, + // )?; + // let heap = self.heap_variable_version_at_label(&None)?; + // // statements.push(vir_low::Statement::assume( + // // vir_low::Expression::container_op_no_pos( + // // vir_low::ContainerOpKind::MapContains, + // // heap.ty.clone(), + // // vec![heap.into(), address], + // // ), + // // position, + // // )); + // let heap_chunk = self.pointer_target_snapshot( + // target.get_type(), + // &None, + // result_value.into(), + // position, + // )?; + // statements.push(vir_low::Statement::assume( + // vir_low::Expression::equals( + // heap_chunk, + // value.place.to_procedure_snapshot(self)?, + // ), + // position, + // )); + // } + _ => {} } } Ok(()) @@ -2263,11 +2743,15 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { lifetime_perm: Perm, owned_perm: Perm, place: Place, - root_address: Address, - current_snapshot: {ty.to_snapshot(self)?} + root_address: Address + }; + let current_snapshot = if self.check_mode.unwrap() == CheckMode::PurificationSoudness { + Some(var! { current_snapshot: {ty.to_snapshot(self)?} }) + } else { + None }; let lifetime_access = expr! { acc(LifetimeToken(lifetime), lifetime_perm) }; - let frac_ref_access = self.frac_ref_full_vars( + let frac_ref_access = self.frac_ref_full_vars_opt( CallContext::BuiltinMethod, ty, ty, @@ -2276,25 +2760,45 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { ¤t_snapshot, &lifetime, )?; + let TODO_target_slice_len = None; + let prestate_snapshot = vir_low::Expression::labelled_old_no_pos( + None, + self.frac_ref_snap( + CallContext::BuiltinMethod, + ty, + ty, + place.clone().into(), + root_address.clone().into(), + lifetime.clone().into(), + TODO_target_slice_len, + )?, + ); let owned_access = self.owned_non_aliased( CallContext::BuiltinMethod, ty, &type_decl, place.clone().into(), root_address.clone().into(), - current_snapshot.clone().into(), + prestate_snapshot, + Some(owned_perm.clone().into()), + )?; + let owned_access_magic_wand = self.owned_non_aliased_predicate( + CallContext::BuiltinMethod, + ty, + &type_decl, + place.clone().into(), + root_address.clone().into(), + true.into(), // TODO: Remove. Some(owned_perm.clone().into()), )?; + let mut parameters = vec![lifetime, lifetime_perm.clone(), place, root_address]; + if let Some(current_snapshot) = current_snapshot { + parameters.push(current_snapshot); + } let method = vir_low::MethodDecl::new( self.encode_open_frac_bor_atomic_method_name(ty)?, vir_low::MethodKind::MirOperation, - vec![ - lifetime, - lifetime_perm.clone(), - place, - root_address, - current_snapshot, - ], + parameters, vec![owned_perm.clone()], vec![ expr! { @@ -2310,9 +2814,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { expr! { [vir_low::Expression::no_permission()] < owned_perm }, - owned_access.clone(), + owned_access, vir_low::Expression::magic_wand_no_pos( - owned_access, + owned_access_magic_wand, expr! { [lifetime_access] && [frac_ref_access] }, ), ], @@ -2323,6 +2827,43 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { Ok(()) } + fn encode_restore_raw_borrowed_method( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()> { + let ty_identifier = ty.get_identifier(); + if !self + .builtin_methods_state + .encoded_restore_raw_borrowed_methods + .contains(&ty_identifier) + { + self.builtin_methods_state + .encoded_restore_raw_borrowed_methods + .insert(ty_identifier); + + self.encode_restore_raw_borrowed_transition_predicate(ty)?; + + let type_decl = self.encoder.get_type_decl_mid(ty)?; + let normalized_type = ty.normalize_type(); + + let mut builder = RestoreRawBorrowedMethodBuilder::new( + self, + vir_low::MethodKind::LowMemoryOperation, + "restore_raw_borrowed", + &normalized_type, + &type_decl, + BuiltinMethodKind::RestoreRawBorrowed, + )?; + builder.create_parameters()?; + builder.add_aliased_source_precondition()?; + builder.add_shift_precondition()?; + builder.add_non_aliased_target_postcondition()?; + let method = builder.build(); + self.declare_method(method)?; + } + Ok(()) + } + fn encode_lft_tok_sep_take_method(&mut self, lft_count: usize) -> SpannedEncodingResult<()> { if !self .builtin_methods_state @@ -2566,6 +3107,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { &place, &root_address, ¤t_snapshot, + false, )?; let unique_ref_predicate = self.unique_ref_full_vars( CallContext::BuiltinMethod, @@ -2576,6 +3118,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { ¤t_snapshot, &final_snapshot, &lifetime, + None, )?; let open_method = vir_low::MethodDecl::new( method_name! { open_mut_ref }, @@ -2629,6 +3172,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { { let close_mut_ref_predicate = vir_low::PredicateDecl::new( predicate_name! { CloseMutRef }, + vir_low::PredicateKind::WithoutSnapshotWhole, vec![ lifetime.clone(), lifetime_perm.clone(), @@ -2726,6 +3270,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { ¤t_snapshot, &final_snapshot, &old_lft, + None, )?); posts.push(self.unique_ref_full_vars( CallContext::BuiltinMethod, @@ -2736,6 +3281,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { ¤t_snapshot, &final_snapshot, &lft, + None, )?); } else { pres.push(self.frac_ref_full_vars( @@ -2771,4 +3317,368 @@ impl<'p, 'v: 'p, 'tcx: 'v> BuiltinMethodsInterface for Lowerer<'p, 'v, 'tcx> { } Ok(()) } + + fn encode_stash_range_call( + &mut self, + statements: &mut Vec, + ty: &vir_mid::Type, + pointer_value: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + label: String, + position: vir_low::Position, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + statements.push(vir_low::Statement::comment(format!( + "Stash range call for {}", + label + ))); + // statements.push(vir_low::Statement::label(label.clone(), position)); + let exhale_owned = vir_low::Statement::exhale( + self.owned_aliased_range( + CallContext::Procedure, + ty, + ty, + pointer_value.clone(), + start_index.clone(), + end_index.clone(), + None, + )?, + position, + ); + statements.push(exhale_owned); + let vir_mid::Type::Pointer(pointer_type) = ty else { + unreachable!("ty: {}", ty); + }; + let target_type = &*pointer_type.target_type; + let ty_identifier = target_type.get_identifier(); + if !self + .builtin_methods_state + .encoded_stashed_owned_aliased_predicates + .contains(&ty_identifier) + { + self.builtin_methods_state + .encoded_stashed_owned_aliased_predicates + .insert(ty_identifier); + let predicate = vir_low::PredicateDecl::new( + predicate_name! { StashedOwnedAliased }, + vir_low::PredicateKind::WithoutSnapshotWhole, + vec![ + var! { index: Int }, + var! { bytes: Bytes }, + var! { snapshot: { target_type.to_snapshot(self)? } }, + ], + None, + ); + self.declare_predicate(predicate)?; + } + let start_address = self.pointer_address(ty, pointer_value, position)?; + let size = self.encode_type_size_expression2(target_type, target_type)?; + let inhale_raw = vir_low::Statement::inhale( + self.encode_memory_block_range_acc( + start_address.clone(), + size.clone(), + start_index.clone(), + end_index.clone(), + position, + )?, + position, + ); + statements.push(inhale_raw); + let inhale_stash = { + let size_type = self.size_type_mid()?; + var_decls! { + index: Int + } + // let start_address = self.pointer_address( + // ty, + // pointer_value, + // position, + // )?; + let element_address = + self.address_offset(size.clone(), start_address, index.clone().into(), position)?; + let start_index = self.obtain_constant_value(&size_type, start_index, position)?; + let end_index = self.obtain_constant_value(&size_type, end_index, position)?; + let bytes = self.encode_memory_block_bytes_expression(element_address.clone(), size)?; + let snapshot = vir_low::Expression::labelled_old( + Some(label.clone()), + self.owned_aliased_snap( + CallContext::Procedure, + target_type, + target_type, + element_address.clone(), + position, + )?, + position, + ); + let stash_predicate = expr! { + acc(StashedOwnedAliased( + index, + [bytes], + [snapshot] + )) + }; + let body = expr!( + (([start_index] <= index) && (index < [end_index])) ==> + [stash_predicate] + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![element_address])], + body, + ); + vir_low::Statement::inhale(expression, position) + }; + statements.push(inhale_stash); + // statements.push(vir_low::Statement::label( + // format!("{}$post", label), + // position, + // )); + Ok(()) + } + + fn encode_restore_stash_range_call( + &mut self, + statements: &mut Vec, + ty: &vir_mid::Type, + old_pointer_value: vir_low::Expression, + old_start_index: vir_low::Expression, + old_end_index: vir_low::Expression, + label: String, + new_pointer_value: vir_low::Expression, + new_start_index: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult<()> { + statements.push(vir_low::Statement::comment(format!( + "Restore stash for {}", + label + ))); + let label_post = format!("{}$post", label); + use vir_low::macros::*; + let vir_mid::Type::Pointer(pointer_type) = ty else { + unreachable!("ty: {}", ty); + }; + let size_type = self.size_type_mid()?; + let target_type = &*pointer_type.target_type; + let size = self.encode_type_size_expression2(target_type, target_type)?; + let old_start_address = self.pointer_address(ty, old_pointer_value, position)?; + let new_start_address = self.pointer_address(ty, new_pointer_value, position)?; + let new_start_index = + self.obtain_constant_value(&size_type, new_start_index.clone(), position)?; + let new_end_index = vir_low::Expression::add( + new_start_index.clone(), + vir_low::Expression::labelled_old( + Some(label.clone()), + vir_low::Expression::subtract( + self.obtain_constant_value(&size_type, old_end_index.clone(), position)?, + self.obtain_constant_value(&size_type, old_start_index.clone(), position)?, + ), + position, + ), + ); + let assume_extensionality = { + // For performance reasons, we do not have global extensionality + // assumptions, but assume them when needed. + + // FIXME: Instead of having the assumption as a quantifier, assert + // that bytes are equal for the entire range and then assume that + // the byte blocks are equal. + var_decls! { + index: Int, + byte_index: Int + } + let new_element_address = self.address_offset( + size.clone(), + new_start_address.clone(), + index.clone().into(), + position, + )?; + let old_index = vir_low::Expression::add( + vir_low::Expression::labelled_old( + Some(label.clone()), + self.obtain_constant_value(&size_type, old_start_index.clone(), position)?, + position, + ), + vir_low::Expression::subtract(index.clone().into(), new_start_index.clone()), + ); + let old_element_address = self.address_offset( + size.clone(), + old_start_address.clone(), + old_index.clone(), + position, + )?; + let new_block_bytes = self + .encode_memory_block_bytes_expression(new_element_address.clone(), size.clone())?; + let old_block_bytes = self + .encode_memory_block_bytes_expression(old_element_address.clone(), size.clone())?; + let old_block_bytes = + vir_low::Expression::labelled_old(Some(label_post), old_block_bytes, position); + let element_size_int = + self.obtain_constant_value(&size_type, size.clone(), position)?; + let new_read_element_byte = self.encode_read_byte_expression( + new_block_bytes.clone(), + byte_index.clone().into(), + position, + )?; + let old_read_element_byte = self.encode_read_byte_expression( + old_block_bytes.clone(), + byte_index.clone().into(), + position, + )?; + let bytes_equal_body = expr!( + ((([new_start_index.clone()] <= index) && (index < [new_end_index.clone()])) && + (([0.into()] <= byte_index) && (byte_index < [element_size_int]))) ==> + ([new_read_element_byte.clone()] == [old_read_element_byte.clone()]) + ); + let bytes_equal = vir_low::Expression::forall( + vec![index.clone(), byte_index], + vec![vir_low::Trigger::new(vec![new_read_element_byte])], + bytes_equal_body, + ); + let assert_byte_equality = vir_low::Statement::assert(bytes_equal, position); + statements.push(assert_byte_equality); + let body = expr!( + (([new_start_index.clone()] <= index) && (index < [new_end_index.clone()])) ==> + // ( + ([new_block_bytes] == [old_block_bytes]) + // == [bytes_equal]) + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![new_element_address])], + body, + ); + vir_low::Statement::assume(expression, position) + }; + statements.push(assume_extensionality); + let exhale_stash = { + var_decls! { + index: Int + } + // let start_address = self.pointer_address( + // ty, + // pointer_value, + // position, + // )?; + let old_index = vir_low::Expression::add( + vir_low::Expression::labelled_old( + Some(label.clone()), + self.obtain_constant_value(&size_type, old_start_index.clone(), position)?, + position, + ), + vir_low::Expression::subtract(index.clone().into(), new_start_index.clone()), + ); + let old_element_address = + self.address_offset(size.clone(), old_start_address, old_index.clone(), position)?; + let new_element_address = self.address_offset( + size.clone(), + new_start_address.clone(), + index.clone().into(), + position, + )?; + // let start_index = self.obtain_constant_value(&size_type, start_index, position)?; + // let end_index = self.obtain_constant_value(&size_type, end_index, position)?; + let bytes = self + .encode_memory_block_bytes_expression(new_element_address.clone(), size.clone())?; + let snapshot = vir_low::Expression::labelled_old( + Some(label), + self.owned_aliased_snap( + CallContext::Procedure, + target_type, + target_type, + old_element_address.clone(), + position, + )?, + position, + ); + let stash_predicate = expr! { + acc(StashedOwnedAliased( + [old_index], + [bytes], + [snapshot] + )) + }; + let body = expr!( + (([new_start_index.clone()] <= index) && (index < [new_end_index.clone()])) ==> + [stash_predicate] + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![new_element_address])], + body, + ); + vir_low::Statement::exhale(expression, position) + }; + statements.push(exhale_stash); + // FIXME: Code duplication with encode_memory_block_range_acc. + let exhale_raw = { + var_decls! { + index: Int + } + let element_address = self.address_offset( + size.clone(), + new_start_address.clone(), + index.clone().into(), + position, + )?; + let predicate = self.encode_memory_block_stack_acc( + element_address.clone(), + size.clone(), + position, + )?; + // let new_start_index = self.obtain_constant_value(&size_type, new_start_address.clone(), position)?; + // let end_index = self.obtain_constant_value(&size_type, end_index, position)?; + let body = expr!( + (([new_start_index.clone()] <= index) && (index < [new_end_index.clone()])) ==> [predicate] + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![element_address])], + body, + ); + vir_low::Statement::exhale( + expression, + // self.encode_memory_block_range_acc( + // new_start_address.clone(), + // size, + // new_start_index.clone(), + // new_end_index.clone(), + // position, + // )?, + position, + ) + }; + statements.push(exhale_raw); + let inhale_owned = { + var_decls! { + index: Int + } + let element_address = self.address_offset( + size, + new_start_address.clone(), + index.clone().into(), + position, + )?; + let predicate = self.owned_aliased( + CallContext::Procedure, + target_type, + target_type, + element_address.clone(), + None, + )?; + // let new_start_index = self.obtain_constant_value(&size_type, new_start_address.clone(), position)?; + // let end_index = self.obtain_constant_value(&size_type, end_index, position)?; + let body = expr!( + (([new_start_index.clone()] <= index) && (index < [new_end_index.clone()])) ==> [predicate] + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![element_address])], + body, + ); + vir_low::Statement::inhale(expression, position) + }; + statements.push(inhale_owned); + Ok(()) + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/mod.rs b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/mod.rs index a04aa21f1cd..7f84d0111ec 100644 --- a/prusti-viper/src/encoder/middle/core_proof/builtin_methods/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/builtin_methods/mod.rs @@ -1,3 +1,4 @@ +mod assertion_encoder; mod builders; mod calls; mod interface; diff --git a/prusti-viper/src/encoder/middle/core_proof/footprint/interface.rs b/prusti-viper/src/encoder/middle/core_proof/footprint/interface.rs new file mode 100644 index 00000000000..da65901a1b8 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/footprint/interface.rs @@ -0,0 +1,393 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + high::types::HighTypeEncoderInterface, + middle::core_proof::{lowerer::Lowerer, snapshots::IntoSnapshot}, +}; +use std::collections::{BTreeMap, BTreeSet}; +use vir_crate::{ + common::position::Positioned, + low as vir_low, + middle::{self as vir_mid, operations::ty::Typed, visitors::ExpressionFolder}, +}; + +struct FootprintComputation<'l, 'p, 'v, 'tcx> { + _lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + // parameters: &'l BTreeMap, + deref_field_fresh_index_counters: BTreeMap, + deref_field_indices: BTreeMap, + derefs: Vec, +} + +// FIXME: Delete. +impl<'l, 'p, 'v, 'tcx> FootprintComputation<'l, 'p, 'v, 'tcx> { + fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + parameters: &'l BTreeMap, + ) -> Self { + let deref_field_fresh_index_counters = parameters + .iter() + .map(|(parameter, decl)| (parameter.clone(), decl.fields.len())) + .collect(); + Self { + _lowerer: lowerer, + // parameters, + deref_field_fresh_index_counters, + deref_field_indices: Default::default(), + derefs: Default::default(), + } + } + + fn extract_base_variable<'a>( + &self, + place: &'a vir_mid::Expression, + ) -> &'a vir_mid::VariableDecl { + match place { + vir_mid::Expression::Local(expression) => &expression.variable, + _ => unimplemented!(), + } + } + + // FIXME: This should be using `own` places. + fn create_deref_field(&mut self, deref: &vir_mid::Deref) -> vir_mid::Expression { + match &*deref.base { + vir_mid::Expression::Field(expression) => { + let variable = self.extract_base_variable(&expression.base); + let deref_field_name = format!("{}$deref", expression.field.name); + let deref_variable = vir_mid::VariableDecl::new(deref_field_name, deref.ty.clone()); + let field_index = self.compute_deref_field_index(deref, variable, &deref_variable); + vir_mid::Expression::field( + (*expression.base).clone(), + vir_mid::FieldDecl { + name: deref_variable.name, + index: field_index, + ty: deref_variable.ty, + }, + expression.position, + ) + } + _ => unimplemented!(), + } + } + + fn compute_deref_field_index( + &mut self, + deref: &vir_mid::Deref, + variable: &vir_mid::VariableDecl, + deref_variable: &vir_mid::VariableDecl, + ) -> usize { + if let Some(index) = self.deref_field_indices.get(deref_variable) { + *index + } else { + let counter = self + .deref_field_fresh_index_counters + .get_mut(variable) + .unwrap(); + let index = *counter; + *counter += 1; + self.deref_field_indices + .insert(deref_variable.clone(), index); + self.derefs.push(deref.clone()); + index + } + } + + fn into_deref_fields(self) -> Vec<(vir_mid::VariableDecl, usize)> { + let mut deref_fields: Vec<_> = self.deref_field_indices.into_iter().collect(); + deref_fields.sort_by_key(|(_, index)| *index); + deref_fields + } + + fn into_derefs(self) -> Vec { + self.derefs + } +} + +impl<'l, 'p, 'v, 'tcx> vir_mid::visitors::ExpressionFolder + for FootprintComputation<'l, 'p, 'v, 'tcx> +{ + fn fold_acc_predicate_enum( + &mut self, + acc_predicate: vir_mid::AccPredicate, + ) -> vir_mid::Expression { + match *acc_predicate.predicate { + vir_mid::Predicate::LifetimeToken(_) => { + unimplemented!() + } + vir_mid::Predicate::MemoryBlockStack(_) + | vir_mid::Predicate::MemoryBlockStackDrop(_) + | vir_mid::Predicate::MemoryBlockHeap(_) + | vir_mid::Predicate::MemoryBlockHeapRange(_) + | vir_mid::Predicate::MemoryBlockHeapDrop(_) => true.into(), + vir_mid::Predicate::OwnedNonAliased(predicate) => { + let position = predicate.place.position(); + let place = self.fold_expression(predicate.place); + vir_mid::Expression::builtin_func_app( + vir_mid::BuiltinFunc::IsValid, + Vec::new(), + vec![place], + vir_mid::Type::Bool, + position, + ) + // match predicate.place { + // vir_mid::Expression::Deref(deref) => { + // let deref_field = self.create_deref_field(&deref); + // let app = vir_mid::Expression::builtin_func_app( + // vir_mid::BuiltinFunc::IsValid, + // Vec::new(), + // vec![deref_field], + // vir_mid::Type::Bool, + // deref.position, + // ); + // app + // }} + // _ => unimplemented!(), + } + vir_mid::Predicate::OwnedRange(predicate) => { + unimplemented!("predicate: {}", predicate); + } + vir_mid::Predicate::OwnedSet(predicate) => { + unimplemented!("predicate: {}", predicate); + } + } + } + + fn fold_deref_enum(&mut self, deref: vir_mid::Deref) -> vir_mid::Expression { + if deref.base.get_type().is_pointer() { + self.create_deref_field(&deref) + } else { + vir_mid::Expression::Deref(self.fold_deref(deref)) + } + } +} + +struct Predicate<'l, 'p, 'v, 'tcx> { + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, +} + +impl<'l, 'p, 'v, 'tcx> Predicate<'l, 'p, 'v, 'tcx> { + fn new(lowerer: &'l mut Lowerer<'p, 'v, 'tcx>) -> Self { + Self { lowerer } + } + + // FIXME: Code duplication. + fn extract_base_variable<'a>( + &self, + place: &'a vir_mid::Expression, + ) -> &'a vir_mid::VariableDecl { + match place { + vir_mid::Expression::Local(expression) => &expression.variable, + _ => unimplemented!(), + } + } +} + +impl<'l, 'p, 'v, 'tcx> vir_mid::visitors::ExpressionFolder for Predicate<'l, 'p, 'v, 'tcx> { + // fn fold_field_enum(&mut self, field: vir_mid::Field) -> vir_mid::Expression { + // match &*field.base { + // vir_mid::Expression::Local(local) => { + // assert!(local.variable.is_self_variable()); + // let position = field.position; + // let app = vir_mid::Expression::builtin_func_app( + // vir_mid::BuiltinFunc::GetSnapshot, + // Vec::new(), + // vec![deref_field], + // vir_mid::Type::Bool, + // position, + // ); + // app + // } + // _ => vir_mid::visitors::default_fold_field(self, field), + // } + // } + // fn fold_acc_predicate_enum( + // &mut self, + // acc_predicate: vir_mid::AccPredicate, + // ) -> vir_mid::Expression { + // match *acc_predicate.predicate { + // vir_mid::Predicate::LifetimeToken(_) => { + // unimplemented!() + // } + // vir_mid::Predicate::MemoryBlockStack(_) + // | vir_mid::Predicate::MemoryBlockStackDrop(_) + // | vir_mid::Predicate::MemoryBlockHeap(_) + // | vir_mid::Predicate::MemoryBlockHeapDrop(_) => true.into(), + // vir_mid::Predicate::OwnedNonAliased(predicate) => match predicate.place { + // vir_mid::Expression::Deref(deref) => { + // let deref_field = self.create_deref_field(&deref); + // let app = vir_mid::Expression::builtin_func_app( + // vir_mid::BuiltinFunc::IsValid, + // Vec::new(), + // vec![deref_field], + // vir_mid::Type::Bool, + // deref.position, + // ); + // app + // } + // _ => unimplemented!(), + // }, + // } + // } +} + +pub(in super::super) trait FootprintInterface { + // fn purify_expressions( + // &mut self, + // expressions: Vec, + // parameters: &BTreeMap, + // ) -> SpannedEncodingResult>; + + // /// Rewrite expression so that it is based only on the snapshots, removing + // /// all predicates. + // fn structural_invariant_to_pure_expression( + // &mut self, + // expressions: Vec, + // ty: &vir_mid::Type, + // decl: &vir_mid::type_decl::Struct, + // fields: &mut Vec, + // ) -> SpannedEncodingResult>; + + // fn structural_invariant_to_predicate_assertion( + // &mut self, + // expressions: Vec, + // ) -> SpannedEncodingResult>; + + // fn framing_variable_deref_fields( + // &mut self, + // framing_variables: &[vir_mid::VariableDecl], + // ) -> SpannedEncodingResult>; + + fn structural_invariant_to_deref_fields( + &mut self, + invariant: &[vir_mid::Expression], + ) -> SpannedEncodingResult>; + + fn compute_deref_field_from_place( + &mut self, + place: &vir_mid::Expression, + ) -> SpannedEncodingResult<(String, vir_low::Type)>; +} + +impl<'p, 'v: 'p, 'tcx: 'v> FootprintInterface for Lowerer<'p, 'v, 'tcx> { + // fn framing_variable_deref_fields( + // &mut self, + // framing_variables: &[vir_mid::VariableDecl], + // ) -> SpannedEncodingResult> { + // let mut deref_fields = Vec::new(); + // for variable in framing_variables { + // let type_decl = self.encoder.get_type_decl_mid(&variable.ty)?; + // if let vir_mid::TypeDecl::Struct(vir_mid::type_decl::Struct { + // structural_invariant: Some(invariant), + // .. + // }) = &type_decl + // { + // deref_fields.extend(self.structural_invariant_to_deref_fields(invariant)?); + // } + // } + // Ok(deref_fields) + // } + + /// For the given invariant, compute the deref fields. This is done by + /// finding all `own` predicates and creating variables for them. + /// + /// The order of the returned fields is guaranteed to be the same for the + /// same invariant. + fn structural_invariant_to_deref_fields( + &mut self, + invariant: &[vir_mid::Expression], + ) -> SpannedEncodingResult> { + let mut owned_places = BTreeSet::default(); + for expression in invariant { + owned_places.extend(expression.collect_owned_places()); + } + let mut fields = Vec::new(); + for owned_place in owned_places { + let (name, ty) = self.compute_deref_field_from_place(&owned_place)?; + fields.push((owned_place, name, ty)); + } + Ok(fields) + } + + /// Computes the parameter that corresponds to the value of + /// the dereferenced place. + fn compute_deref_field_from_place( + &mut self, + place: &vir_mid::Expression, + ) -> SpannedEncodingResult<(String, vir_low::Type)> { + let mut parameter_name = String::new(); + fn compute_name( + place: &vir_mid::Expression, + parameter_name: &mut String, + ) -> SpannedEncodingResult<()> { + match place { + vir_mid::Expression::Deref(expression) => { + compute_name(&expression.base, parameter_name)?; + parameter_name.push_str("$deref"); + } + vir_mid::Expression::Field(expression) => { + compute_name(&expression.base, parameter_name)?; + parameter_name.push('$'); + parameter_name.push_str(&expression.field.name); + } + vir_mid::Expression::Local(expression) => { + assert!(expression.variable.is_self_variable()); + } + _ => { + unimplemented!("{place}"); + } + } + Ok(()) + } + compute_name(place, &mut parameter_name)?; + Ok((parameter_name, place.get_type().to_snapshot(self)?)) + } + // fn purify_expressions( + // &mut self, + // expressions: Vec, + // parameters: &BTreeMap, + // ) -> SpannedEncodingResult> { + // let mut computation = FootprintComputation::new(self, parameters); + // let mut purified_expressions = Vec::with_capacity(expressions.len()); + // for expression in expressions { + // purified_expressions.push(computation.fold_expression(expression)); + // } + // Ok(purified_expressions) + // } + + // FIXME: Delete. + // fn structural_invariant_to_pure_expression( + // &mut self, + // expressions: Vec, + // ty: &vir_mid::Type, + // decl: &vir_mid::type_decl::Struct, + // fields: &mut Vec, + // ) -> SpannedEncodingResult> { + // let self_variable = vir_mid::VariableDecl::self_variable(ty.clone()); + // let mut parameters = BTreeMap::new(); + // parameters.insert(self_variable, decl); + // let mut computation = FootprintComputation::new(self, ¶meters); + // let mut purified_expressions = Vec::with_capacity(expressions.len()); + // for expression in expressions { + // purified_expressions.push(computation.fold_expression(expression)); + // } + // let deref_fields = computation.into_deref_fields(); + // for (deref_field, _) in deref_fields { + // fields.push(vir_low::VariableDecl::new( + // deref_field.name, + // deref_field.ty.to_snapshot(self)?, + // )); + // } + // Ok(purified_expressions) + // } + + // fn structural_invariant_to_predicate_assertion( + // &mut self, + // expressions: Vec, + // ) -> SpannedEncodingResult> { + // let mut converter = Predicate::new(self); + // let mut new_expressions = Vec::with_capacity(expressions.len()); + // for expression in expressions { + // new_expressions.push(converter.fold_expression(expression)); + // } + // Ok(new_expressions) + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/footprint/mod.rs b/prusti-viper/src/encoder/middle/core_proof/footprint/mod.rs new file mode 100644 index 00000000000..95e926ab446 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/footprint/mod.rs @@ -0,0 +1,3 @@ +mod interface; + +pub(super) use self::interface::FootprintInterface; diff --git a/prusti-viper/src/encoder/middle/core_proof/heap/interface.rs b/prusti-viper/src/encoder/middle/core_proof/heap/interface.rs new file mode 100644 index 00000000000..00bc0a392fe --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/heap/interface.rs @@ -0,0 +1,167 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::lowerer::{DomainsLowererInterface, Lowerer}, +}; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, QuantifierHelpers}, + low as vir_low, middle as vir_mid, + middle::operations::lifetimes::WithLifetimes, +}; + +const HEAP_DOMAIN_NAME: &str = "Heap$"; +const HEAP_LOOKUP_FUNCTION_NAME: &str = "heap$lookup"; +const HEAP_UPDATE_FUNCTION_NAME: &str = "heap$update"; +const HEAP_CHUNK_TYPE_NAME: &str = "HeapChunk$"; + +pub(in super::super) trait Private { + fn encode_heap_axioms(&mut self) -> SpannedEncodingResult<()>; +} + +impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { + fn encode_heap_axioms(&mut self) -> SpannedEncodingResult<()> { + if !self.heap_state.is_heap_encoded { + self.heap_state.is_heap_encoded = true; + + let position = vir_low::Position::default(); + use vir_low::macros::*; + let heap_type = self.heap_type()?; + let heap_chunk_type = self.heap_chunk_type()?; + var_decls!( + heap: {heap_type.clone()}, + address: Address, + chunk: {heap_chunk_type.clone()} + ); + let update_call = self.create_domain_func_app( + HEAP_DOMAIN_NAME, + HEAP_UPDATE_FUNCTION_NAME, + vec![ + heap.clone().into(), + address.clone().into(), + chunk.clone().into(), + ], + heap_type, + position, + )?; + { + let lookup_call = self.create_domain_func_app( + HEAP_DOMAIN_NAME, + HEAP_LOOKUP_FUNCTION_NAME, + vec![update_call.clone(), address.clone().into()], + heap_chunk_type.clone(), + position, + )?; + + // forall heap: Heap$, addr: Address, chunk: HeapChunk$ :: + // { heap$lookup(heap$update(heap, addr, chunk), addr) } + // heap$lookup(heap$update(heap, addr, chunk), addr) == chunk + let axiom_update_value = vir_low::DomainAxiomDecl { + comment: None, + name: "heap$update_value$axiom".to_string(), + body: QuantifierHelpers::forall( + vec![heap.clone(), address.clone(), chunk.clone()], + vec![vir_low::Trigger::new(vec![lookup_call.clone()])], + expr! { + [lookup_call] == chunk + }, + ), + }; + self.declare_axiom(HEAP_DOMAIN_NAME, axiom_update_value)?; + } + { + var_decls!(address2: Address); + let lookup_call_original = self.create_domain_func_app( + HEAP_DOMAIN_NAME, + HEAP_LOOKUP_FUNCTION_NAME, + vec![heap.clone().into(), address2.clone().into()], + heap_chunk_type.clone(), + position, + )?; + let lookup_call_updated = self.create_domain_func_app( + HEAP_DOMAIN_NAME, + HEAP_LOOKUP_FUNCTION_NAME, + vec![update_call, address2.clone().into()], + heap_chunk_type, + position, + )?; + // forall heap: Heap$, addr1: Address, addr2: Address, chunk: HeapChunk$ :: + // { heap$lookup(heap$update(heap, addr1, chunk), addr2) } + // addr1 != addr2 ==> + // heap$lookup(heap$update(heap, addr1, chunk), addr2) == heap$lookup(heap, addr2) + let axiom_preserve_value = vir_low::DomainAxiomDecl { + name: "heap$update_preserve_value$axiom".to_string(), + comment: None, + body: QuantifierHelpers::forall( + vec![heap, address.clone(), address2.clone(), chunk.clone()], + vec![vir_low::Trigger::new(vec![lookup_call_updated.clone()])], + expr! { + (address != address2) ==> + ([lookup_call_updated] == [lookup_call_original]) + }, + ), + }; + self.declare_axiom(HEAP_DOMAIN_NAME, axiom_preserve_value)?; + } + } + Ok(()) + } +} + +pub(in super::super) trait HeapInterface { + fn heap_lookup( + &mut self, + heap: vir_low::Expression, + address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn heap_update( + &mut self, + heap: vir_low::Expression, + address: vir_low::Expression, + value: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn heap_chunk_type(&mut self) -> SpannedEncodingResult; + fn heap_type(&mut self) -> SpannedEncodingResult; +} + +impl<'p, 'v: 'p, 'tcx: 'v> HeapInterface for Lowerer<'p, 'v, 'tcx> { + fn heap_lookup( + &mut self, + heap: vir_low::Expression, + address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + self.encode_heap_axioms()?; + let return_type = self.heap_chunk_type()?; + self.create_domain_func_app( + HEAP_DOMAIN_NAME, + HEAP_LOOKUP_FUNCTION_NAME, + vec![heap, address], + return_type, + position, + ) + } + fn heap_update( + &mut self, + heap: vir_low::Expression, + address: vir_low::Expression, + value: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + self.encode_heap_axioms()?; + let return_type = self.heap_type()?; + self.create_domain_func_app( + HEAP_DOMAIN_NAME, + HEAP_UPDATE_FUNCTION_NAME, + vec![heap, address, value], + return_type, + position, + ) + } + fn heap_chunk_type(&mut self) -> SpannedEncodingResult { + self.domain_type(HEAP_CHUNK_TYPE_NAME) + } + fn heap_type(&mut self) -> SpannedEncodingResult { + self.domain_type(HEAP_DOMAIN_NAME) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/heap/mod.rs b/prusti-viper/src/encoder/middle/core_proof/heap/mod.rs new file mode 100644 index 00000000000..0548285daf0 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/heap/mod.rs @@ -0,0 +1,4 @@ +mod interface; +mod state; + +pub(super) use self::{interface::HeapInterface, state::HeapState}; diff --git a/prusti-viper/src/encoder/middle/core_proof/heap/state.rs b/prusti-viper/src/encoder/middle/core_proof/heap/state.rs new file mode 100644 index 00000000000..2faba742830 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/heap/state.rs @@ -0,0 +1,4 @@ +#[derive(Default)] +pub(in super::super) struct HeapState { + pub(super) is_heap_encoded: bool, +} diff --git a/prusti-viper/src/encoder/middle/core_proof/interface.rs b/prusti-viper/src/encoder/middle/core_proof/interface.rs index 1821c52762d..b50bf426cf2 100644 --- a/prusti-viper/src/encoder/middle/core_proof/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/interface.rs @@ -42,29 +42,62 @@ impl<'v, 'tcx: 'v> MidCoreProofEncoderInterface<'tcx> for super::super::super::E ); return Ok(()); } - let procedure = self.encode_procedure_core_proof(proc_def_id, check_mode)?; - let super::lowerer::LoweringResult { - procedures, - domains, - functions, - predicates, - methods, - } = super::lowerer::lower_procedure(self, proc_def_id, procedure)?; - let mut program = vir_low::Program { - name: self.env().name.get_absolute_item_name(proc_def_id), - check_mode, - procedures, - domains, - predicates, - functions, - methods, - }; - if config::inline_caller_for() { - super::transformations::inline_functions::inline_caller_for(&mut program); + for procedure in self.encode_procedure_core_proof(proc_def_id, check_mode)? { + let name = procedure.name.clone(); + let super::lowerer::LoweringResult { + procedures, + domains, + functions, + predicates, + methods, + predicates_info, + } = super::lowerer::lower_procedure(self, proc_def_id, procedure)?; + let mut program = vir_low::Program { + name, + // name: self.env().name.get_absolute_item_name(proc_def_id), + check_mode, + procedures, + domains, + predicates, + functions, + methods, + }; + let source_filename = self.env().name.source_file_name(); + if config::trace_with_symbolic_execution() || config::custom_heap_encoding() { + program = super::transformations::desugar_method_calls::desugar_method_calls( + &source_filename, + program, + ); + program = super::transformations::desugar_conditionals::desugar_conditionals( + &source_filename, + program, + ); + } + if config::trace_with_symbolic_execution() { + program = + super::transformations::symbolic_execution::purify_with_symbolic_execution( + &source_filename, + program, + predicates_info.clone(), + )?; + } + if config::custom_heap_encoding() { + super::transformations::custom_heap_encoding::custom_heap_encoding( + self, + &mut program, + predicates_info, + )?; + } + if config::inline_caller_for() { + super::transformations::inline_functions::inline_caller_for( + &source_filename, + &mut program, + ); + } + self.mid_core_proof_encoder_state + .encoded_programs + .push(program); } - self.mid_core_proof_encoder_state - .encoded_programs - .push(program); Ok(()) } @@ -99,6 +132,7 @@ impl<'v, 'tcx: 'v> MidCoreProofEncoderInterface<'tcx> for super::super::super::E functions, predicates, methods, + predicates_info: _, } = super::lowerer::lower_type(self, def_id, ty, check_copy)?; assert!(procedures.is_empty()); let mut program = vir_low::Program { @@ -111,7 +145,11 @@ impl<'v, 'tcx: 'v> MidCoreProofEncoderInterface<'tcx> for super::super::super::E methods, }; if config::inline_caller_for() { - super::transformations::inline_functions::inline_caller_for(&mut program); + let source_filename = self.env().name.source_file_name(); + super::transformations::inline_functions::inline_caller_for( + &source_filename, + &mut program, + ); } self.mid_core_proof_encoder_state .encoded_programs diff --git a/prusti-viper/src/encoder/middle/core_proof/into_low/cfg.rs b/prusti-viper/src/encoder/middle/core_proof/into_low/cfg.rs index e413408a0be..13afbfcebb4 100644 --- a/prusti-viper/src/encoder/middle/core_proof/into_low/cfg.rs +++ b/prusti-viper/src/encoder/middle/core_proof/into_low/cfg.rs @@ -6,18 +6,27 @@ use crate::encoder::{ addresses::AddressesInterface, block_markers::BlockMarkersInterface, builtin_methods::{BuiltinMethodCallsInterface, BuiltinMethodsInterface, CallContext}, + labels::LabelsInterface, lifetimes::LifetimesInterface, lowerer::{Lowerer, VariablesLowererInterface}, places::PlacesInterface, + pointers::PointersInterface, predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface}, references::ReferencesInterface, snapshots::{ - IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, IntoProcedureSnapshot, - SnapshotValidityInterface, SnapshotValuesInterface, SnapshotVariablesInterface, + IntoProcedureAssertion, IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, + IntoProcedureSnapshot, IntoSnapshotLowerer, PredicateKind, + ProcedureExpressionToSnapshot, ProcedureSnapshot, SnapshotValidityInterface, + SnapshotValuesInterface, SnapshotVariablesInterface, }, + type_layouts::TypeLayoutsInterface, }, }; use vir_crate::{ + common::{ + check_mode::CheckMode, + expression::{BinaryOperationHelpers, QuantifierHelpers}, + }, low::{self as vir_low}, middle::{self as vir_mid, operations::ty::Typed}, }; @@ -43,11 +52,15 @@ impl IntoLow for vir_mid::Statement { use vir_low::{macros::*, Statement}; match self { Self::Comment(statement) => Ok(vec![Statement::comment(statement.comment)]), - Self::OldLabel(label) => { - lowerer.save_old_label(label.name)?; - Ok(Vec::new()) + Self::OldLabel(statement) => { + lowerer.save_old_label(statement.name.clone())?; + lowerer.save_custom_label(statement.name.clone())?; + Ok(vec![vir_low::Statement::label( + statement.name, + statement.position, + )]) } - Self::Inhale(statement) => { + Self::InhalePredicate(statement) => { if let vir_mid::Predicate::OwnedNonAliased(owned) = &statement.predicate { lowerer.mark_owned_non_aliased_as_unfolded(owned.place.get_type())?; } @@ -56,7 +69,7 @@ impl IntoLow for vir_mid::Statement { statement.position, )]) } - Self::Exhale(statement) => { + Self::ExhalePredicate(statement) => { if let vir_mid::Predicate::OwnedNonAliased(owned) = &statement.predicate { lowerer.mark_owned_non_aliased_as_unfolded(owned.place.get_type())?; } @@ -92,13 +105,41 @@ impl IntoLow for vir_mid::Statement { )?; Ok(statements) } - Self::Assume(statement) => Ok(vec![Statement::assume( - statement.expression.to_procedure_bool_expression(lowerer)?, + Self::HeapHavoc(statement) => { + let new_version = lowerer.new_heap_variable_version(statement.position)?; + Ok(vec![vir_low::Statement::comment(format!( + "new heap version: {}", + new_version + ))]) + } + Self::InhaleExpression(statement) => Ok(vec![Statement::inhale( + statement.expression.to_procedure_assertion(lowerer)?, statement.position, )]), + Self::ExhaleExpression(statement) => { + let assertion = statement.expression.to_procedure_assertion(lowerer)?; + let exhale = Statement::exhale(assertion, statement.position); + Ok(vec![exhale]) + } + Self::Assume(statement) => { + assert!( + statement.expression.is_pure(), + "must be pure: {}", + statement.expression + ); + Ok(vec![Statement::assume( + statement.expression.to_procedure_assertion(lowerer)?, + statement.position, + )]) + } Self::Assert(statement) => { + assert!( + statement.expression.is_pure(), + "must be pure: {}", + statement.expression + ); let assert = Statement::assert( - statement.expression.to_procedure_bool_expression(lowerer)?, + statement.expression.to_procedure_assertion(lowerer)?, statement.position, ); let low_statement = if let Some(condition) = statement.condition { @@ -122,8 +163,12 @@ impl IntoLow for vir_mid::Statement { lowerer.mark_owned_non_aliased_as_unfolded(ty)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let root_address = lowerer.extract_root_address(&statement.place)?; - let snapshot = statement.place.to_procedure_snapshot(lowerer)?; - let predicate = lowerer.owned_non_aliased( + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::Owned); + let snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + // let snapshot = statement.place.to_procedure_snapshot(lowerer)?; + let predicate = lowerer.owned_non_aliased_predicate( CallContext::Procedure, ty, ty, @@ -132,6 +177,7 @@ impl IntoLow for vir_mid::Statement { snapshot, None, )?; + assert!(predicate.is_predicate_access_predicate()); let mut low_statement = vir_low::Statement::fold_no_pos(predicate); if let Some(condition) = statement.condition { let low_condition = lowerer.lower_block_marker_condition(condition)?; @@ -146,10 +192,14 @@ impl IntoLow for vir_mid::Statement { Self::UnfoldOwned(statement) => { let ty = statement.place.get_type(); lowerer.mark_owned_non_aliased_as_unfolded(ty)?; - let place = lowerer.encode_expression_as_place(&statement.place)?; let root_address = lowerer.extract_root_address(&statement.place)?; - let snapshot = statement.place.to_procedure_snapshot(lowerer)?; - let predicate = lowerer.owned_non_aliased( + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::Owned); + let snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + // let snapshot = statement.place.to_procedure_snapshot(lowerer)?; + let place = lowerer.encode_expression_as_place(&statement.place)?; + let predicate = lowerer.owned_non_aliased_predicate( CallContext::Procedure, ty, ty, @@ -158,6 +208,7 @@ impl IntoLow for vir_mid::Statement { snapshot, None, )?; + assert!(predicate.is_predicate_access_predicate()); let mut low_statement = vir_low::Statement::unfold_no_pos(predicate); if let Some(condition) = statement.condition { let low_condition = lowerer.lower_block_marker_condition(condition)?; @@ -176,9 +227,15 @@ impl IntoLow for vir_mid::Statement { lowerer.encode_lifetime_const_into_procedure_variable(statement.lifetime)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let root_address = lowerer.extract_root_address(&statement.place)?; - let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; + // let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; let predicate = if statement.uniqueness.is_shared() { - lowerer.frac_ref( + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::FracRef { + lifetime: lifetime.clone().into(), + }); + let current_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + lowerer.frac_ref_predicate( CallContext::Procedure, ty, ty, @@ -186,10 +243,25 @@ impl IntoLow for vir_mid::Statement { root_address, current_snapshot, lifetime.into(), + None, )? } else { - let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; - lowerer.unique_ref( + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: false, + }); + let current_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: false, + }); + let final_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, true)?; + // let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; + lowerer.unique_ref_predicate( CallContext::Procedure, ty, ty, @@ -198,6 +270,7 @@ impl IntoLow for vir_mid::Statement { current_snapshot, final_snapshot, lifetime.into(), + None, )? }; let mut low_statement = vir_low::Statement::fold_no_pos(predicate); @@ -218,9 +291,15 @@ impl IntoLow for vir_mid::Statement { lowerer.encode_lifetime_const_into_procedure_variable(statement.lifetime)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let root_address = lowerer.extract_root_address(&statement.place)?; - let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; + // let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; let predicate = if statement.uniqueness.is_shared() { - lowerer.frac_ref( + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::FracRef { + lifetime: lifetime.clone().into(), + }); + let current_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + lowerer.frac_ref_predicate( CallContext::Procedure, ty, ty, @@ -228,10 +307,25 @@ impl IntoLow for vir_mid::Statement { root_address, current_snapshot, lifetime.into(), + None, )? } else { - let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; - lowerer.unique_ref( + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: false, + }); + let current_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: false, + }); + let final_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, true)?; + // let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; + lowerer.unique_ref_predicate( CallContext::Procedure, ty, ty, @@ -240,6 +334,7 @@ impl IntoLow for vir_mid::Statement { current_snapshot, final_snapshot, lifetime.into(), + None, )? }; let mut low_statement = vir_low::Statement::unfold_no_pos(predicate); @@ -303,6 +398,26 @@ impl IntoLow for vir_mid::Statement { }; Ok(vec![low_statement]) } + Self::JoinRange(statement) => { + let ty = statement.address.get_type(); + let vir_mid::Type::Pointer(pointer_type) = ty else { + unreachable!() + }; + let target_type = &*pointer_type.target_type; + lowerer.encode_memory_block_range_join_method(target_type)?; + let pointer_value = statement.address.to_procedure_snapshot(lowerer)?; + let start_address = + lowerer.pointer_address(ty, pointer_value, statement.position)?; + let start_index = statement.start_index.to_procedure_snapshot(lowerer)?; + let end_index = statement.end_index.to_procedure_snapshot(lowerer)?; + let low_statement = stmtp! { + statement.position => + call memory_block_range_join( + [start_address], [start_index], [end_index] + ) + }; + Ok(vec![low_statement]) + } Self::SplitBlock(statement) => { let ty = statement.place.get_type(); lowerer.encode_memory_block_split_method(ty)?; @@ -353,12 +468,36 @@ impl IntoLow for vir_mid::Statement { }; Ok(vec![low_statement]) } + Self::SplitRange(statement) => { + let ty = statement.address.get_type(); + let vir_mid::Type::Pointer(pointer_type) = ty else { + unreachable!() + }; + let target_type = &*pointer_type.target_type; + lowerer.encode_memory_block_range_split_method(target_type)?; + let pointer_value = statement.address.to_procedure_snapshot(lowerer)?; + let start_address = + lowerer.pointer_address(ty, pointer_value, statement.position)?; + let start_index = statement.start_index.to_procedure_snapshot(lowerer)?; + let end_index = statement.end_index.to_procedure_snapshot(lowerer)?; + let low_statement = stmtp! { + statement.position => + call memory_block_range_split( + [start_address], [start_index], [end_index] + ) + }; + Ok(vec![low_statement]) + } Self::ConvertOwnedIntoMemoryBlock(statement) => { let ty = statement.place.get_type(); lowerer.encode_into_memory_block_method(ty)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let root_address = lowerer.extract_root_address(&statement.place)?; - let snapshot = statement.place.to_procedure_snapshot(lowerer)?; + // let snapshot = statement.place.to_procedure_snapshot(lowerer)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::Owned); + let snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; let low_condition = statement .condition .map(|condition| lowerer.lower_block_marker_condition(condition)) @@ -403,6 +542,7 @@ impl IntoLow for vir_mid::Statement { current_snapshot, final_snapshot, deref_lifetime, + None, // FIXME )? } else { lowerer.frac_ref( @@ -448,6 +588,39 @@ impl IntoLow for vir_mid::Statement { }; Ok(vec![low_statement]) } + Self::RestoreRawBorrowed(statement) => { + let ty = statement.restored_place.get_type(); + lowerer.encode_restore_raw_borrowed_method(ty)?; + let borrowing_place_parent = statement.borrowing_place.get_parent_ref().unwrap(); + let borrowing_snapshot = borrowing_place_parent.to_procedure_snapshot(lowerer)?; + let borrowing_address = lowerer.pointer_address( + borrowing_place_parent.get_type(), + borrowing_snapshot, + statement.position, + )?; + let restored_place = + lowerer.encode_expression_as_place(&statement.restored_place)?; + let restored_root_address = + lowerer.extract_root_address(&statement.restored_place)?; + let snapshot = statement.borrowing_place.to_procedure_snapshot(lowerer)?; + let mut statements = vec![lowerer.call_restore_raw_borrowed_method( + CallContext::Procedure, + ty, + ty, + statement.position, + borrowing_address, + restored_place, + restored_root_address, + snapshot.clone(), + )?]; + lowerer.encode_snapshot_update( + &mut statements, + &statement.restored_place, + snapshot, + statement.position, + )?; + Ok(statements) + } Self::MovePlace(statement) => { // TODO: Remove code duplication with Self::CopyPlace let target_ty = statement.target.get_type(); @@ -460,8 +633,19 @@ impl IntoLow for vir_mid::Statement { let target_root_address = lowerer.extract_root_address(&statement.target)?; let source_place = lowerer.encode_expression_as_place(&statement.source)?; let source_root_address = lowerer.extract_root_address(&statement.source)?; - let source_snapshot = statement.source.to_procedure_snapshot(lowerer)?; - let mut statements = vec![lowerer.call_move_place_method( + // let source_snapshot = statement.source.to_procedure_snapshot(lowerer)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::Owned); + let source_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.source, false)?; + let mut statements = Vec::new(); + lowerer.encode_snapshot_update( + &mut statements, + &statement.target, + source_snapshot.clone(), + statement.position, + )?; + statements.push(lowerer.call_move_place_method( CallContext::Procedure, target_ty, target_ty, @@ -471,13 +655,7 @@ impl IntoLow for vir_mid::Statement { source_place, source_root_address, source_snapshot.clone(), - )?]; - lowerer.encode_snapshot_update( - &mut statements, - &statement.target, - source_snapshot, - statement.position, - )?; + )?); Ok(statements) } Self::CopyPlace(statement) => { @@ -496,8 +674,19 @@ impl IntoLow for vir_mid::Statement { } else { vir_low::Expression::full_permission() }; - let source_snapshot = statement.source.to_procedure_snapshot(lowerer)?; - let mut statements = vec![lowerer.call_copy_place_method( + // let source_snapshot = statement.source.to_procedure_snapshot(lowerer)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::Owned); + let source_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.source, false)?; + let mut statements = Vec::new(); + lowerer.encode_snapshot_update( + &mut statements, + &statement.target, + source_snapshot.clone(), + statement.position, + )?; + statements.push(lowerer.call_copy_place_method( CallContext::Procedure, target_ty, target_ty, @@ -508,13 +697,7 @@ impl IntoLow for vir_mid::Statement { source_root_address, source_snapshot.clone(), source_permission_amount, - )?]; - lowerer.encode_snapshot_update( - &mut statements, - &statement.target, - source_snapshot, - statement.position, - )?; + )?); Ok(statements) } Self::WritePlace(statement) => { @@ -574,12 +757,7 @@ impl IntoLow for vir_mid::Statement { let variant_index = variant_place.clone().unwrap_variant().variant_index; let union_place = variant_place.get_parent_ref().unwrap(); let mut statements = Vec::new(); - lowerer.encode_snapshot_havoc( - &mut statements, - union_place, - statement.position, - None, - )?; + lowerer.encode_snapshot_havoc(&mut statements, union_place, statement.position)?; let snapshot = union_place.to_procedure_snapshot(lowerer)?; let discriminant = lowerer.obtain_enum_discriminant( snapshot, @@ -606,6 +784,48 @@ impl IntoLow for vir_mid::Statement { )?; Ok(stmts) } + Self::StashRange(statement) => { + let ty = statement.address.get_type(); + let pointer_value = statement.address.to_procedure_snapshot(lowerer)?; + let start_index = statement.start_index.to_procedure_snapshot(lowerer)?; + let end_index = statement.end_index.to_procedure_snapshot(lowerer)?; + let mut statements = Vec::new(); + lowerer.encode_stash_range_call( + &mut statements, + ty, + pointer_value, + start_index, + end_index, + statement.label, + statement.position, + )?; + Ok(statements) + } + Self::StashRangeRestore(statement) => { + assert_eq!( + statement.old_address.get_type(), + statement.new_address.get_type() + ); + let ty = statement.old_address.get_type(); + let old_pointer_value = statement.old_address.to_procedure_snapshot(lowerer)?; + let old_start_index = statement.old_start_index.to_procedure_snapshot(lowerer)?; + let old_end_index = statement.old_end_index.to_procedure_snapshot(lowerer)?; + let new_address = statement.new_address.to_procedure_snapshot(lowerer)?; + let new_start_index = statement.new_start_index.to_procedure_snapshot(lowerer)?; + let mut statements = Vec::new(); + lowerer.encode_restore_stash_range_call( + &mut statements, + ty, + old_pointer_value, + old_start_index, + old_end_index, + statement.old_label, + new_address, + new_start_index, + statement.position, + )?; + Ok(statements) + } Self::NewLft(statement) => { let targets = vec![vir_low::Expression::local_no_pos( statement.target.to_procedure_snapshot(lowerer)?, @@ -651,6 +871,7 @@ impl IntoLow for vir_mid::Statement { current_snapshot.clone(), final_snapshot.clone(), lifetime.into(), + None, // FIXME: This should be a proper value )?; lowerer.mark_unique_ref_as_used(ty)?; let mut statements = vec![ @@ -731,10 +952,16 @@ impl IntoLow for vir_mid::Statement { .unwrap() .to_procedure_snapshot(lowerer)?, ); - let statements = vec![Statement::assign( - lowerer - .new_snapshot_variable_version(&statement.target, statement.position)?, - value, + let statements = vec![Statement::assume( + vir_low::Expression::equals( + lowerer + .new_snapshot_variable_version( + &statement.target, + statement.position, + )? + .into(), + value, + ), statement.position, )]; Ok(statements) @@ -797,20 +1024,18 @@ impl IntoLow for vir_mid::Statement { .to_procedure_snapshot(lowerer)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let address = lowerer.extract_root_address(&statement.place)?; - let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; let targets = vec![statement .predicate_permission_amount .to_procedure_snapshot(lowerer)? .into()]; + let mut arguments = vec![lifetime.into(), perm_amount, place, address]; + if lowerer.check_mode.unwrap() == CheckMode::PurificationSoudness { + let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; + arguments.push(current_snapshot); + } Ok(vec![Statement::method_call( method_name!(frac_bor_atomic_acc), - vec![ - lifetime.into(), - perm_amount, - place, - address, - current_snapshot, - ], + arguments, targets, statement.position, )]) @@ -824,20 +1049,25 @@ impl IntoLow for vir_mid::Statement { .to_procedure_snapshot(lowerer)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let root_address = lowerer.extract_root_address(&statement.place)?; - let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; let tmp_frac_ref_perm = statement .predicate_permission_amount .to_procedure_snapshot(lowerer)?; - let owned_predicate = lowerer.owned_non_aliased( + let owned_predicate = lowerer.owned_non_aliased_predicate( CallContext::Procedure, ty, ty, place.clone(), root_address.clone(), - current_snapshot.clone(), + true.into(), // FIXME: Not used. Some(tmp_frac_ref_perm.into()), )?; - let frac_predicate = lowerer.frac_ref( + let current_snapshot = + if lowerer.check_mode.unwrap() == CheckMode::PurificationSoudness { + Some(statement.place.to_procedure_snapshot(lowerer)?) + } else { + None + }; + let frac_predicate = lowerer.frac_ref_opt( CallContext::Procedure, ty, ty, @@ -866,8 +1096,22 @@ impl IntoLow for vir_mid::Statement { .to_procedure_snapshot(lowerer)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let address = lowerer.extract_root_address(&statement.place)?; - let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; - let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: false, + }); + let current_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + // let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; + // let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: true, + }); + let final_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; let statements = vec![stmtp! { statement.position => call open_mut_ref( lifetime, @@ -890,8 +1134,19 @@ impl IntoLow for vir_mid::Statement { .to_procedure_snapshot(lowerer)?; let place = lowerer.encode_expression_as_place(&statement.place)?; let address = lowerer.extract_root_address(&statement.place)?; - let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; - let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; + // let current_snapshot = statement.place.to_procedure_snapshot(lowerer)?; + // let final_snapshot = statement.place.to_procedure_final_snapshot(lowerer)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::Owned); + let current_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; + let mut place_encoder = + ProcedureExpressionToSnapshot::for_place(PredicateKind::UniqueRef { + lifetime: lifetime.clone().into(), + is_final: true, + }); + let final_snapshot = + place_encoder.expression_to_snapshot(lowerer, &statement.place, false)?; let statements = vec![stmtp! { statement.position => call close_mut_ref( lifetime, @@ -977,7 +1232,7 @@ impl IntoLow for vir_mid::Predicate { lowerer.encode_memory_block_predicate()?; let place = lowerer.encode_expression_as_place_address(&predicate.place)?; let size = predicate.size.to_procedure_snapshot(lowerer)?; - expr! { acc(MemoryBlock([place], [size]))}.set_default_position(predicate.position) + lowerer.encode_memory_block_stack_acc(place, size, predicate.position)? } Predicate::MemoryBlockStackDrop(predicate) => { let place = lowerer.encode_expression_as_place_address(&predicate.place)?; @@ -987,6 +1242,9 @@ impl IntoLow for vir_mid::Predicate { Predicate::MemoryBlockHeap(predicate) => { unimplemented!("predicate: {}", predicate); } + Predicate::MemoryBlockHeapRange(predicate) => { + unimplemented!("predicate: {}", predicate); + } Predicate::MemoryBlockHeapDrop(predicate) => { unimplemented!("predicate: {}", predicate); } @@ -1011,6 +1269,8 @@ impl IntoLow for vir_mid::Predicate { [valid] } } + Predicate::OwnedRange(_) => todo!(), + Predicate::OwnedSet(_) => todo!(), }; Ok(result) } diff --git a/prusti-viper/src/encoder/middle/core_proof/labels/interface.rs b/prusti-viper/src/encoder/middle/core_proof/labels/interface.rs new file mode 100644 index 00000000000..c9a44c173b4 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/labels/interface.rs @@ -0,0 +1,21 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::lowerer::{DomainsLowererInterface, Lowerer}, +}; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, QuantifierHelpers}, + low as vir_low, middle as vir_mid, + middle::operations::lifetimes::WithLifetimes, +}; + +pub(in super::super) trait LabelsInterface { + fn save_custom_label(&mut self, label: String) -> SpannedEncodingResult<()>; +} + +impl<'p, 'v: 'p, 'tcx: 'v> LabelsInterface for Lowerer<'p, 'v, 'tcx> { + fn save_custom_label(&mut self, label: String) -> SpannedEncodingResult<()> { + let label = vir_low::Label::new(label); + assert!(self.labels_state.labels.insert(label)); + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/labels/mod.rs b/prusti-viper/src/encoder/middle/core_proof/labels/mod.rs new file mode 100644 index 00000000000..381f69bf617 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/labels/mod.rs @@ -0,0 +1,4 @@ +mod interface; +mod state; + +pub(super) use self::{interface::LabelsInterface, state::LabelsState}; diff --git a/prusti-viper/src/encoder/middle/core_proof/labels/state.rs b/prusti-viper/src/encoder/middle/core_proof/labels/state.rs new file mode 100644 index 00000000000..85f7c0702ad --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/labels/state.rs @@ -0,0 +1,13 @@ +use std::collections::BTreeSet; +use vir_crate::low as vir_low; + +#[derive(Default)] +pub(in super::super) struct LabelsState { + pub(super) labels: BTreeSet, +} + +impl LabelsState { + pub(crate) fn destruct(self) -> Vec { + self.labels.into_iter().collect() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/lifetimes/interface.rs b/prusti-viper/src/encoder/middle/core_proof/lifetimes/interface.rs index c4c01f4bf16..f8d959596e0 100644 --- a/prusti-viper/src/encoder/middle/core_proof/lifetimes/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/lifetimes/interface.rs @@ -339,6 +339,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> LifetimesInterface for Lowerer<'p, 'v, 'tcx> { self.lifetimes_state.is_lifetime_token_encoded = true; let predicate = vir_low::PredicateDecl::new( "LifetimeToken", + vir_low::PredicateKind::WithoutSnapshotFrac, vec![vir_low::VariableDecl::new( "lifetime", self.lifetime_type()?, @@ -348,6 +349,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> LifetimesInterface for Lowerer<'p, 'v, 'tcx> { self.declare_predicate(predicate)?; let predicate = vir_low::PredicateDecl::new( "DeadLifetimeToken", + vir_low::PredicateKind::WithoutSnapshotWhole, vec![vir_low::VariableDecl::new( "lifetime", self.lifetime_type()?, diff --git a/prusti-viper/src/encoder/middle/core_proof/lowerer/functions/interface.rs b/prusti-viper/src/encoder/middle/core_proof/lowerer/functions/interface.rs index 3c428d40d8a..f7340571785 100644 --- a/prusti-viper/src/encoder/middle/core_proof/lowerer/functions/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/lowerer/functions/interface.rs @@ -2,10 +2,13 @@ use crate::encoder::{ errors::SpannedEncodingResult, high::pure_functions::HighPureFunctionEncoderInterface, middle::core_proof::{ + footprint::FootprintInterface, function_gas::FunctionGasInterface, lowerer::{DomainsLowererInterface, Lowerer}, snapshots::{ - IntoPureBoolExpression, IntoPureSnapshot, IntoSnapshot, SnapshotValidityInterface, + FramedExpressionToSnapshot, IntoFramedPureSnapshot, IntoPureBoolExpression, + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, ProcedureExpressionToSnapshot, + SnapshotValidityInterface, }, types::TypesInterface, }, @@ -124,7 +127,19 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { return_type, ); let body = if let Some(body) = function_decl.body { - expr! { ([call.clone()] == [body.to_pure_snapshot(self)?]) } + eprintln!("body: {}", body); + let framing_variables = &function_decl.parameters; + for variable in framing_variables { + eprintln!("variable: {}", variable); + } + // let deref_fields = self.framing_variable_deref_fields(framing_variables)?; + // for (e, name, ty) in &deref_fields { + // eprintln!("field: {} {} {}", e, name, ty); + // } + let mut body_encoder = + FramedExpressionToSnapshot::for_function_body(framing_variables); + let encoded_body = body_encoder.expression_to_snapshot(self, &body, false)?; + expr! { ([call.clone()] == [encoded_body]) } } else { true.into() }; diff --git a/prusti-viper/src/encoder/middle/core_proof/lowerer/mod.rs b/prusti-viper/src/encoder/middle/core_proof/lowerer/mod.rs index 341045a7d4f..4d91aac8f46 100644 --- a/prusti-viper/src/encoder/middle/core_proof/lowerer/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/lowerer/mod.rs @@ -3,23 +3,31 @@ use self::{ predicates::PredicatesLowererState, variables::VariablesLowererState, }; use super::{ + addresses::AddressState, adts::AdtsState, + block_markers::BlockMarkersInterface, builtin_methods::BuiltinMethodsState, compute_address::ComputeAddressState, + heap::HeapState, into_low::IntoLow, + labels::LabelsState, lifetimes::LifetimesState, places::PlacesState, - predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface, PredicatesState}, + predicates::{ + PredicateInfo, PredicatesMemoryBlockInterface, PredicatesOwnedInterface, PredicatesState, + }, snapshots::{SnapshotVariablesInterface, SnapshotsState}, types::TypesState, }; use crate::encoder::{ - errors::SpannedEncodingResult, middle::core_proof::builtin_methods::BuiltinMethodsInterface, + errors::{ErrorCtxt, SpannedEncodingResult}, + middle::core_proof::builtin_methods::BuiltinMethodsInterface, + mir::errors::ErrorInterface, Encoder, }; use prusti_rustc_interface::hir::def_id::DefId; use rustc_hash::{FxHashMap, FxHashSet}; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use vir_crate::{ common::{cfg::Cfg, check_mode::CheckMode, graphviz::ToGraphviz}, low::{self as vir_low, operations::ty::Typed}, @@ -44,6 +52,7 @@ pub(super) struct LoweringResult { pub(super) functions: Vec, pub(super) predicates: Vec, pub(super) methods: Vec, + pub(super) predicates_info: BTreeMap, } pub(super) fn lower_procedure<'p, 'v: 'p, 'tcx: 'v>( @@ -102,6 +111,9 @@ pub(super) struct Lowerer<'p, 'v: 'p, 'tcx: 'v> { pub(super) adts_state: AdtsState, pub(super) lifetimes_state: LifetimesState, pub(super) places_state: PlacesState, + pub(super) heap_state: HeapState, + pub(super) address_state: AddressState, + pub(super) labels_state: LabelsState, } impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { @@ -123,6 +135,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { adts_state: Default::default(), lifetimes_state: Default::default(), places_state: Default::default(), + heap_state: Default::default(), + address_state: Default::default(), + labels_state: Default::default(), } } @@ -131,6 +146,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { def_id: DefId, mut procedure: vir_mid::ProcedureDecl, ) -> SpannedEncodingResult { + assert!( + !procedure.position.is_default(), + "procedure {def_id:?} without position" + ); self.def_id = Some(def_id); let mut basic_blocks_map = BTreeMap::new(); let mut basic_block_edges = BTreeMap::new(); @@ -142,22 +161,26 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { self.set_current_block_for_snapshots(label, &predecessors, &mut basic_block_edges)?; let basic_block = procedure.basic_blocks.remove(label).unwrap(); let marker = self.create_block_marker(label)?; - marker_initialisation.push(vir_low::Statement::assign_no_pos( + marker_initialisation.push(vir_low::Statement::assign( marker.clone(), false.into(), + procedure.position, )); let mut statements = vec![ - vir_low::Statement::assign_no_pos(marker.clone(), true.into()), + vir_low::Statement::assign(marker.clone(), true.into(), procedure.position), // We need to use a function call here because Silicon optimizes // out assignments to pure variables and our Z3 wrapper does not // see them. - vir_low::Statement::log_event(self.create_domain_func_app( - "MarkerCalls", - format!("basic_block_marker${}", marker.name), - vec![], - vir_low::Type::Bool, - Default::default(), - )?), + vir_low::Statement::log_event( + self.create_domain_func_app( + "MarkerCalls", + format!("basic_block_marker${}", marker.name), + vec![], + vir_low::Type::Bool, + procedure.position, + )?, + procedure.position, + ), ]; for statement in basic_block.statements { statements.extend(statement.into_low(&mut self)?); @@ -170,52 +193,98 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { std::mem::swap(entry_block_statements, &mut marker_initialisation); entry_block_statements.extend(marker_initialisation); - let mut basic_blocks = Vec::new(); + let mut basic_blocks = BTreeMap::new(); for basic_block_id in traversal_order { let (statements, mut successor) = basic_blocks_map.remove(&basic_block_id).unwrap(); let label = basic_block_id.clone().into_low(&mut self)?; if let Some(intermediate_blocks) = basic_block_edges.remove(&basic_block_id) { - for (successor_label, successor_statements) in intermediate_blocks { + for (successor_label, equalities) in intermediate_blocks { let successor_label = successor_label.into_low(&mut self)?; let intermediate_block_label = vir_low::Label::new(format!( "label__from__{}__to__{}", label.name, successor_label.name )); successor.replace_label(&successor_label, intermediate_block_label.clone()); - basic_blocks.push(vir_low::BasicBlock { - label: intermediate_block_label, - statements: successor_statements, - successor: vir_low::Successor::Goto(successor_label), - }); + let mut successor_statements = Vec::new(); + for (variable_name, ty, position, old_version, new_version) in equalities { + let new_variable = self.create_snapshot_variable_low( + &variable_name, + ty.clone(), + new_version, + )?; + let old_variable = self.create_snapshot_variable_low( + &variable_name, + ty.clone(), + old_version, + )?; + let position = self.encoder.change_error_context( + // FIXME: Get a more precise span. + position, + ErrorCtxt::Unexpected, + ); + let statement = vir_low::macros::stmtp! { + position => assume (new_variable == old_variable) + }; + successor_statements.push(statement); + } + basic_blocks.insert( + intermediate_block_label, + vir_low::BasicBlock { + statements: successor_statements, + successor: vir_low::Successor::Goto(successor_label), + }, + ); } } - basic_blocks.push(vir_low::BasicBlock { + basic_blocks.insert( label, - statements, - successor, - }); + vir_low::BasicBlock { + statements, + successor, + }, + ); } + let entry = procedure.entry.clone().into_low(&mut self)?; + let exit = procedure.exit.clone().into_low(&mut self)?; let mut removed_functions = FxHashSet::default(); - if procedure.check_mode == CheckMode::Specifications { + if procedure.check_mode == CheckMode::PurificationFunctional { removed_functions.insert(self.encode_memory_block_bytes_function_name()?); } - let mut predicates = self.collect_owned_predicate_decls()?; - basic_blocks[0].statements.splice( + let (mut predicates, predicates_info) = self.collect_owned_predicate_decls()?; + basic_blocks.get_mut(&entry).unwrap().statements.splice( 0..0, self.lifetimes_state.lifetime_is_alive_initialization(), ); + if prusti_common::config::dump_debug_info() { + let source_filename = self.encoder.env().name.source_file_name(); + prusti_common::report::log::report_with_writer( + "graphviz_method_vir_low_before_perm_desugaring", + format!("{}.{}.dot", source_filename, procedure.name), + |writer| basic_blocks.to_graphviz(writer).unwrap(), + ); + } let mut domains = self.domains_state.destruct(); domains.extend(self.compute_address_state.destruct()); predicates.extend(self.predicates_state.destruct()); let mut lowered_procedure = vir_low::ProcedureDecl { name: procedure.name, + position: procedure.position, locals: self.variables_state.destruct(), + custom_labels: self.labels_state.destruct(), basic_blocks, + entry, + exit, }; let mut methods = self.methods_state.destruct(); let mut functions = self.functions_state.destruct(); - if procedure.check_mode == CheckMode::Specifications { + if procedure.check_mode == CheckMode::PurificationFunctional { + removed_functions.extend( + functions + .iter() + .filter(|function| function.kind == vir_low::FunctionKind::Snap) + .map(|function| function.name.clone()), + ); super::transformations::remove_predicates::remove_predicates( &mut lowered_procedure, &mut methods, @@ -224,13 +293,15 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { ); functions.retain(|function| !removed_functions.contains(&function.name)); }; - Ok(LoweringResult { + let result = LoweringResult { procedures: vec![lowered_procedure], domains, functions, predicates, methods, - }) + predicates_info, + }; + Ok(result) } fn create_parameters(&self, arguments: &[vir_low::Expression]) -> Vec { @@ -243,12 +314,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { .collect() } - fn create_block_marker( - &mut self, - label: &vir_mid::BasicBlockId, - ) -> SpannedEncodingResult { - self.create_variable(format!("{label}$marker"), vir_low::Type::Bool) - } + // fn create_block_marker( + // &mut self, + // label: &vir_mid::BasicBlockId, + // ) -> SpannedEncodingResult { + // self.create_variable(format!("{label}$marker"), vir_low::Type::Bool) + // } /// If `check_copy` is true, encode `copy` builtin method. fn lower_type( @@ -263,7 +334,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { if check_copy { self.encode_copy_place_method(&ty)?; } - let mut predicates = self.collect_owned_predicate_decls()?; + let (mut predicates, predicates_info) = self.collect_owned_predicate_decls()?; let mut domains = self.domains_state.destruct(); domains.extend(self.compute_address_state.destruct()); predicates.extend(self.predicates_state.destruct()); @@ -273,6 +344,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Lowerer<'p, 'v, 'tcx> { functions: self.functions_state.destruct(), predicates, methods: self.methods_state.destruct(), + predicates_info, }) } } diff --git a/prusti-viper/src/encoder/middle/core_proof/lowerer/variables/interface.rs b/prusti-viper/src/encoder/middle/core_proof/lowerer/variables/interface.rs index 12374a09824..5b594acd781 100644 --- a/prusti-viper/src/encoder/middle/core_proof/lowerer/variables/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/lowerer/variables/interface.rs @@ -24,6 +24,7 @@ pub(in super::super::super) trait VariablesLowererInterface { name: String, ty: vir_low::Type, ) -> SpannedEncodingResult; + fn register_variable(&mut self, variable: &vir_low::VariableDecl) -> SpannedEncodingResult<()>; fn create_new_temporary_variable( &mut self, ty: vir_low::Type, @@ -43,6 +44,14 @@ impl<'p, 'v: 'p, 'tcx: 'v> VariablesLowererInterface for Lowerer<'p, 'v, 'tcx> { } Ok(vir_low::VariableDecl::new(name, ty)) } + fn register_variable(&mut self, variable: &vir_low::VariableDecl) -> SpannedEncodingResult<()> { + if !self.variables_state.variables.contains_key(&variable.name) { + self.variables_state + .variables + .insert(variable.name.clone(), variable.ty.clone()); + } + Ok(()) + } fn create_new_temporary_variable( &mut self, ty: vir_low::Type, diff --git a/prusti-viper/src/encoder/middle/core_proof/mod.rs b/prusti-viper/src/encoder/middle/core_proof/mod.rs index 2a36c53a2d3..ae6fe67660a 100644 --- a/prusti-viper/src/encoder/middle/core_proof/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/mod.rs @@ -5,12 +5,14 @@ mod builtin_methods; mod compute_address; mod const_generics; mod errors; +mod footprint; mod function_gas; mod interface; mod into_low; mod lifetimes; mod lowerer; mod places; +mod pointers; mod predicates; mod references; mod snapshots; @@ -18,5 +20,7 @@ mod transformations; mod type_layouts; mod types; mod utils; +mod heap; +mod labels; pub(crate) use self::interface::{MidCoreProofEncoderInterface, MidCoreProofEncoderState}; diff --git a/prusti-viper/src/encoder/middle/core_proof/places/encoder.rs b/prusti-viper/src/encoder/middle/core_proof/places/encoder.rs index 4204479138d..cd57e864c75 100644 --- a/prusti-viper/src/encoder/middle/core_proof/places/encoder.rs +++ b/prusti-viper/src/encoder/middle/core_proof/places/encoder.rs @@ -8,7 +8,7 @@ use crate::encoder::{ }; use vir_crate::{ low as vir_low, - middle::{self as vir_mid}, + middle::{self as vir_mid, operations::ty::Typed}, }; pub(super) struct PlaceEncoder {} @@ -42,7 +42,12 @@ impl PlaceExpressionDomainEncoder for PlaceEncoder { lowerer: &mut Lowerer, arg: vir_low::Expression, ) -> SpannedEncodingResult { - lowerer.encode_deref_place(arg, deref.position) + if deref.base.get_type().is_reference() { + lowerer.encode_deref_place(arg, deref.position) + } else { + assert!(deref.base.get_type().is_pointer()); + lowerer.encode_aliased_place_root(deref.position) + } } fn encode_array_index_axioms( @@ -52,4 +57,12 @@ impl PlaceExpressionDomainEncoder for PlaceEncoder { ) -> SpannedEncodingResult<()> { lowerer.encode_place_array_index_axioms(ty) } + + fn encode_labelled_old( + &mut self, + expression: &vir_mid::expression::LabelledOld, + lowerer: &mut Lowerer, + ) -> SpannedEncodingResult { + self.encode_expression(&expression.base, lowerer) + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/places/interface.rs b/prusti-viper/src/encoder/middle/core_proof/places/interface.rs index 66796753217..f7d5ef21b66 100644 --- a/prusti-viper/src/encoder/middle/core_proof/places/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/places/interface.rs @@ -4,6 +4,7 @@ use super::{ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ + compute_address::ComputeAddressInterface, lowerer::{DomainsLowererInterface, Lowerer}, type_layouts::TypeLayoutsInterface, }, @@ -54,6 +55,10 @@ pub(in super::super) trait PlacesInterface { position: vir_mid::Position, ) -> SpannedEncodingResult; fn encode_place_array_index_axioms(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; + fn encode_aliased_place_root( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult; } impl<'p, 'v: 'p, 'tcx: 'v> PlacesInterface for Lowerer<'p, 'v, 'tcx> { @@ -159,4 +164,19 @@ impl<'p, 'v: 'p, 'tcx: 'v> PlacesInterface for Lowerer<'p, 'v, 'tcx> { } Ok(()) } + fn encode_aliased_place_root( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let return_type = self.place_type()?; + let place_root = self.create_domain_func_app( + "Place", + "aliased_place_root", + vec![], + return_type, + position, + )?; + self.encode_compute_address_for_place_root(&place_root)?; + Ok(place_root) + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/pointers/interface.rs b/prusti-viper/src/encoder/middle/core_proof/pointers/interface.rs new file mode 100644 index 00000000000..62db159bf81 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/pointers/interface.rs @@ -0,0 +1,184 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + heap::HeapInterface, + lowerer::{DomainsLowererInterface, Lowerer}, + snapshots::{IntoSnapshot, SnapshotValuesInterface, SnapshotVariablesInterface}, + type_layouts::TypeLayoutsInterface, + }, +}; +use vir_crate::{ + common::identifier::WithIdentifier, + low as vir_low, + middle::{self as vir_mid}, +}; + +pub(in super::super) trait PointersInterface { + fn pointer_address( + &mut self, + pointer_type: &vir_mid::Type, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn pointer_slice_len( + &mut self, + pointer_type: &vir_mid::Type, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn pointer_target_snapshot_in_heap( + &mut self, + ty: &vir_mid::Type, + heap: vir_low::VariableDecl, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn pointer_target_snapshot( + &mut self, + ty: &vir_mid::Type, + old_label: &Option, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn pointer_target_as_snapshot_field( + &mut self, + framing_type: &vir_mid::Type, + deref_field_name: &str, + deref_type: vir_low::Type, + framing_place_snapshot: vir_low::Expression, + position: vir_mid::Position, + ) -> SpannedEncodingResult; + fn heap_chunk_to_snapshot( + &mut self, + ty: &vir_mid::Type, + heap_chunk: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn address_in_heap( + &mut self, + heap: vir_low::VariableDecl, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult; +} + +impl<'p, 'v: 'p, 'tcx: 'v> PointersInterface for Lowerer<'p, 'v, 'tcx> { + fn pointer_address( + &mut self, + pointer_type: &vir_mid::Type, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + assert!(pointer_type.is_pointer()); + // self.obtain_constant_value(pointer_type, snapshot, position) + let address_type = self.address_type()?; + self.obtain_parameter_snapshot(pointer_type, "address", address_type, snapshot, position) + } + fn pointer_slice_len( + &mut self, + pointer_type: &vir_mid::Type, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + assert!(pointer_type.is_pointer_to_slice()); + let len_type = self.size_type()?; + self.obtain_parameter_snapshot(pointer_type, "len", len_type, snapshot, position) + } + fn pointer_target_snapshot_in_heap( + &mut self, + ty: &vir_mid::Type, + heap: vir_low::VariableDecl, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let address = self.pointer_address(ty, snapshot, position)?; + let heap_chunk = self.heap_lookup(heap.into(), address, position)?; + // let heap_chunk = vir_low::Expression::container_op_no_pos( + // vir_low::ContainerOpKind::MapLookup, + // heap.ty.clone(), + // vec![heap.into(), address], + // ); + let pointer_type = ty.clone().unwrap_pointer(); + self.heap_chunk_to_snapshot(&pointer_type.target_type, heap_chunk, position) + } + fn pointer_target_snapshot( + &mut self, + ty: &vir_mid::Type, + old_label: &Option, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + if self.use_heap_variable()? { + // let address = self.pointer_address(ty, snapshot, position)?; + let heap = self.heap_variable_version_at_label(old_label)?; + // let heap_chunk = vir_low::Expression::container_op_no_pos( + // vir_low::ContainerOpKind::MapLookup, + // heap.ty.clone(), + // vec![heap.into(), address], + // ); + // let pointer_type = ty.clone().unwrap_pointer(); + // self.heap_chunk_to_snapshot(&pointer_type.target_type, heap_chunk, position) + self.pointer_target_snapshot_in_heap(ty, heap, snapshot, position) + } else { + unimplemented!(); + // let address = self.pointer_address(ty, snapshot, position)?; + // let pointer_type = ty.clone().unwrap_pointer(); + // let target_type = &*pointer_type.target_type; + // self.owned_aliased_snap( + // CallContext::Procedure, + // target_type, + // target_type, + // address, + // position, + // ) + } + } + fn pointer_target_as_snapshot_field( + &mut self, + framing_type: &vir_mid::Type, + deref_field_name: &str, + deref_type: vir_low::Type, + framing_place_snapshot: vir_low::Expression, + position: vir_mid::Position, + ) -> SpannedEncodingResult { + self.obtain_parameter_snapshot( + framing_type, + deref_field_name, + deref_type, + framing_place_snapshot, + position, + ) + } + fn heap_chunk_to_snapshot( + &mut self, + ty: &vir_mid::Type, + heap_chunk: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let return_type = ty.to_snapshot(self)?; + self.create_domain_func_app( + // FIXME: Use HEAP_CHUNK_TYPE_NAME here. + "HeapChunk$", + format!("heap_chunk_to${}", ty.get_identifier()), + vec![heap_chunk], + return_type, + position, + ) + } + fn address_in_heap( + &mut self, + _heap: vir_low::VariableDecl, + _pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + todo!("Delete"); + // let pointer = pointer_place.to_pure_snapshot(self)?; + // let address = + // self.pointer_address(pointer_place.get_type(), pointer, pointer_place.position())?; + // let in_heap = vir_low::Expression::container_op_no_pos( + // vir_low::ContainerOpKind::MapContains, + // heap.ty.clone(), + // vec![heap.into(), address], + // ); + // Ok(in_heap) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/pointers/mod.rs b/prusti-viper/src/encoder/middle/core_proof/pointers/mod.rs new file mode 100644 index 00000000000..0e0b37ac78b --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/pointers/mod.rs @@ -0,0 +1,3 @@ +mod interface; + +pub(super) use self::interface::PointersInterface; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/memory_block/interface.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/memory_block/interface.rs index 8bbf051fcdb..fe66a0c096d 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/memory_block/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/memory_block/interface.rs @@ -5,11 +5,12 @@ use crate::encoder::{ lowerer::{ DomainsLowererInterface, FunctionsLowererInterface, Lowerer, PredicatesLowererInterface, }, + snapshots::SnapshotValuesInterface, type_layouts::TypeLayoutsInterface, }, }; use rustc_hash::FxHashSet; -use vir_crate::low as vir_low; +use vir_crate::{common::expression::QuantifierHelpers, low as vir_low, middle as vir_mid}; #[derive(Default)] pub(in super::super) struct PredicatesMemoryBlockState { @@ -21,10 +22,12 @@ trait Private { fn encode_generic_memory_block_predicate( &mut self, predicate_name: &str, + predicate_kind: vir_low::PredicateKind, ) -> SpannedEncodingResult<()>; fn encode_generic_memory_block_acc( &mut self, predicate_name: &str, + predicate_kind: vir_low::PredicateKind, place: vir_low::Expression, size: vir_low::Expression, position: vir_low::Position, @@ -35,6 +38,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { fn encode_generic_memory_block_predicate( &mut self, predicate_name: &str, + predicate_kind: vir_low::PredicateKind, ) -> SpannedEncodingResult<()> { if !self .predicates_encoding_state @@ -48,6 +52,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { .insert(predicate_name.to_string()); let predicate = vir_low::PredicateDecl::new( predicate_name, + predicate_kind, vec![ vir_low::VariableDecl::new("address", self.address_type()?), vir_low::VariableDecl::new("size", self.size_type()?), @@ -61,11 +66,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { fn encode_generic_memory_block_acc( &mut self, predicate_name: &str, + predicate_kind: vir_low::PredicateKind, place: vir_low::Expression, size: vir_low::Expression, position: vir_low::Position, ) -> SpannedEncodingResult { - self.encode_generic_memory_block_predicate(predicate_name)?; + self.encode_generic_memory_block_predicate(predicate_name, predicate_kind)?; let expression = vir_low::Expression::predicate_access_predicate( predicate_name.to_string(), vec![place, size], @@ -78,13 +84,40 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { pub(in super::super::super) trait PredicatesMemoryBlockInterface { fn bytes_type(&mut self) -> SpannedEncodingResult; + fn byte_type(&mut self) -> SpannedEncodingResult; + fn encode_read_byte_expression( + &mut self, + bytes: vir_low::Expression, + index: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; fn encode_memory_block_predicate(&mut self) -> SpannedEncodingResult<()>; + fn encode_memory_block_stack_acc( + &mut self, + place: vir_low::Expression, + size: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn encode_memory_block_range_acc( + &mut self, + address: vir_low::Expression, + size: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; fn encode_memory_block_stack_drop_acc( &mut self, place: vir_low::Expression, size: vir_low::Expression, position: vir_low::Position, ) -> SpannedEncodingResult; + fn encode_memory_block_heap_drop_acc( + &mut self, + place: vir_low::Expression, + size: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; fn encode_memory_block_bytes_function_name(&mut self) -> SpannedEncodingResult; fn encode_memory_block_bytes_expression( &mut self, @@ -97,8 +130,72 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesMemoryBlockInterface for Lowerer<'p, 'v, 't fn bytes_type(&mut self) -> SpannedEncodingResult { self.domain_type("Bytes") } + fn byte_type(&mut self) -> SpannedEncodingResult { + self.domain_type("Byte") + } + fn encode_read_byte_expression( + &mut self, + bytes: vir_low::Expression, + index: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let byte_type = self.byte_type()?; + self.create_domain_func_app( + "Byte", + "Byte$read_byte", + vec![bytes, index], + byte_type, + position, + ) + } fn encode_memory_block_predicate(&mut self) -> SpannedEncodingResult<()> { - self.encode_generic_memory_block_predicate("MemoryBlock") + self.encode_generic_memory_block_predicate( + "MemoryBlock", + vir_low::PredicateKind::MemoryBlock, + ) + } + fn encode_memory_block_stack_acc( + &mut self, + place: vir_low::Expression, + size: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + self.encode_generic_memory_block_acc( + "MemoryBlock", + vir_low::PredicateKind::MemoryBlock, + place, + size, + position, + ) + } + fn encode_memory_block_range_acc( + &mut self, + address: vir_low::Expression, + size: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + use vir_low::macros::*; + let size_type = self.size_type_mid()?; + var_decls! { + index: Int + } + let element_address = + self.address_offset(size.clone(), address, index.clone().into(), position)?; + let predicate = + self.encode_memory_block_stack_acc(element_address.clone(), size.clone(), position)?; + let start_index = self.obtain_constant_value(&size_type, start_index, position)?; + let end_index = self.obtain_constant_value(&size_type, end_index, position)?; + let body = expr!( + (([start_index] <= index) && (index < [end_index])) ==> [predicate] + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![element_address])], + body, + ); + Ok(expression) } fn encode_memory_block_stack_drop_acc( &mut self, @@ -106,7 +203,27 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesMemoryBlockInterface for Lowerer<'p, 'v, 't size: vir_low::Expression, position: vir_low::Position, ) -> SpannedEncodingResult { - self.encode_generic_memory_block_acc("MemoryBlockStackDrop", place, size, position) + self.encode_generic_memory_block_acc( + "MemoryBlockStackDrop", + vir_low::PredicateKind::WithoutSnapshotWhole, + place, + size, + position, + ) + } + fn encode_memory_block_heap_drop_acc( + &mut self, + place: vir_low::Expression, + size: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + self.encode_generic_memory_block_acc( + "MemoryBlockHeapDrop", + vir_low::PredicateKind::WithoutSnapshotWhole, + place, + size, + position, + ) } fn encode_memory_block_bytes_function_name(&mut self) -> SpannedEncodingResult { Ok("MemoryBlock$bytes".to_string()) diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/mod.rs index 95565515e2e..31edc4c7c80 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/mod.rs @@ -1,11 +1,14 @@ mod memory_block; mod owned; +mod restoration; mod state; pub(super) use self::{ memory_block::PredicatesMemoryBlockInterface, owned::{ - FracRefUseBuilder, OwnedNonAliasedUseBuilder, PredicatesOwnedInterface, UniqueRefUseBuilder, + FracRefUseBuilder, OwnedAliasedSnapCallBuilder, OwnedNonAliasedSnapCallBuilder, + OwnedNonAliasedUseBuilder, PredicateInfo, PredicatesOwnedInterface, }, + restoration::RestorationInterface, state::PredicatesState, }; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/function_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/function_decl.rs new file mode 100644 index 00000000000..7934c27e520 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/function_decl.rs @@ -0,0 +1,251 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + footprint::FootprintInterface, + lifetimes::LifetimesInterface, + lowerer::Lowerer, + places::PlacesInterface, + snapshots::{ + AssertionToSnapshotConstructor, IntoPureSnapshot, IntoSnapshot, PredicateKind, + SnapshotValidityInterface, SnapshotValuesInterface, + }, + type_layouts::TypeLayoutsInterface, + }, +}; +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator}, + identifier::WithIdentifier, + }, + low::{self as vir_low}, + middle as vir_mid, +}; + +pub(in super::super::super) struct FunctionDeclBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super) lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + pub(in super::super) function_name: &'l str, + pub(in super::super) ty: &'l vir_mid::Type, + pub(in super::super) type_decl: &'l vir_mid::TypeDecl, + pub(in super::super) parameters: Vec, + pub(in super::super) pres: Vec, + pub(in super::super) posts: Vec, + pub(in super::super) conjuncts: Option>, + pub(in super::super) position: vir_low::Position, + pub(in super::super) place: vir_low::VariableDecl, +} + +impl<'l, 'p, 'v, 'tcx> FunctionDeclBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + function_name: &'l str, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let place = vir_low::VariableDecl::new("place", lowerer.place_type()?); + Ok(Self { + function_name, + ty, + type_decl, + parameters: Vec::new(), + pres: Vec::new(), + posts: Vec::new(), + conjuncts: None, + position, + lowerer, + place, + }) + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + let return_type = self.ty.to_snapshot(self.lowerer)?; + let function = vir_low::FunctionDecl { + name: format!("{}${}", self.function_name, self.ty.get_identifier()), + kind: vir_low::FunctionKind::Snap, + parameters: self.parameters, + body: self + .conjuncts + .map(|conjuncts| conjuncts.into_iter().conjoin()), + pres: self.pres, + posts: self.posts, + return_type, + }; + Ok(function) + } + + pub(in super::super) fn create_lifetime_parameters(&mut self) -> SpannedEncodingResult<()> { + self.parameters + .extend(self.lowerer.create_lifetime_parameters(self.type_decl)?); + Ok(()) + } + + pub(in super::super) fn create_const_parameters(&mut self) -> SpannedEncodingResult<()> { + for parameter in self.type_decl.get_const_parameters() { + self.parameters + .push(parameter.to_pure_snapshot(self.lowerer)?); + } + Ok(()) + } + + pub(in super::super) fn add_precondition( + &mut self, + assertion: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + self.pres.push(assertion); + Ok(()) + } + + pub(in super::super) fn add_postcondition( + &mut self, + assertion: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + self.posts.push(assertion); + Ok(()) + } + + pub(in super::super) fn array_length_int( + &mut self, + array_length_mid: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + let array_length = array_length_mid.to_pure_snapshot(self.lowerer)?; + let size_type_mid = self.lowerer.size_type_mid()?; + self.lowerer + .obtain_constant_value(&size_type_mid, array_length.into(), self.position) + } + + pub(in super::super) fn result_type(&mut self) -> SpannedEncodingResult { + self.ty.to_snapshot(self.lowerer) + } + + pub(in super::super) fn result(&mut self) -> SpannedEncodingResult { + Ok(vir_low::VariableDecl::new("__result", self.result_type()?)) + } + + pub(in super::super) fn add_validity_postcondition(&mut self) -> SpannedEncodingResult<()> { + let result = self.result()?; + let validity = self + .lowerer + .encode_snapshot_valid_call_for_type(result.into(), self.ty)?; + self.add_postcondition(validity) + } + + pub(in super::super) fn add_snapshot_len_equal_to_postcondition( + &mut self, + array_length_mid: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let snapshot = self.result()?; + let snapshot_length = self + .lowerer + .obtain_array_len_snapshot(snapshot.into(), self.position)?; + let array_length_int = self.array_length_int(array_length_mid)?; + let expression = expr! { + ([array_length_int] == [snapshot_length]) + }; + self.add_postcondition(expression) + } + + pub(in super::super) fn create_field_snap_call( + &mut self, + field: &vir_mid::FieldDecl, + snap_call: impl FnOnce( + &mut Self, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + ) -> SpannedEncodingResult { + let field_place = self.lowerer.encode_field_place( + self.ty, + field, + self.place.clone().into(), + self.position, + )?; + snap_call(self, field, field_place.into()) + // let target_slice_len = self.slice_len_expression()?; + // self.lowerer.frac_ref_snap( + // CallContext::BuiltinMethod, + // &field.ty, + // &field.ty, + // field_place, + // self.root_address.clone().into(), + // self.reference_lifetime.clone().into(), + // target_slice_len, + // ) + } + + pub(in super::super) fn create_field_snapshot_equality( + &mut self, + field: &vir_mid::FieldDecl, + snap_call: impl FnOnce( + &mut Self, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + ) -> SpannedEncodingResult { + use vir_low::macros::*; + let result = self.result()?; + let field_snapshot = self.lowerer.obtain_struct_field_snapshot( + self.ty, + field, + result.into(), + self.position, + )?; + let snap_call = self.create_field_snap_call(&field, snap_call)?; + Ok(expr! { + [field_snapshot] == [snap_call] + }) + } + + pub(in super::super::super) fn add_unfolding_postcondition( + &mut self, + precondition_predicate: vir_low::Expression, + body: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + let unfolding = precondition_predicate.into_unfolding(body); + self.add_postcondition(unfolding) + } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + precondition_predicate: Option, + predicate_kind: PredicateKind, + snap_call: &impl Fn( + &mut Self, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + ) -> SpannedEncodingResult<()> { + if let Some(invariant) = decl.structural_invariant.clone() { + let mut regular_field_arguments = Vec::new(); + for field in &decl.fields { + let field_snap_call = self.create_field_snap_call(field, snap_call)?; + regular_field_arguments.push(field_snap_call); + // regular_field_arguments.push(self.create_field_snap_call(field)?); + } + let result = self.result()?; + let deref_fields = self + .lowerer + .structural_invariant_to_deref_fields(&invariant)?; + let mut constructor_encoder = AssertionToSnapshotConstructor::for_function_body( + predicate_kind, + self.ty, + regular_field_arguments, + decl.fields.clone(), + deref_fields, + self.position, + ); + let invariant_expression = invariant.into_iter().conjoin(); + let permission_expression = invariant_expression.convert_into_permission_expression(); + let constructor = constructor_encoder + .expression_to_snapshot_constructor(self.lowerer, &permission_expression)?; + let body = vir_low::Expression::equals(result.into(), constructor); + if let Some(precondition_predicate) = precondition_predicate { + self.add_unfolding_postcondition(precondition_predicate, body)?; + } else { + self.add_postcondition(body)?; + } + } + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/function_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/function_use.rs new file mode 100644 index 00000000000..2ddd0f84e31 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/function_use.rs @@ -0,0 +1,80 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, + lifetimes::LifetimesInterface, + lowerer::Lowerer, + snapshots::{IntoPureSnapshot, IntoSnapshot}, + }, +}; +use vir_crate::{ + common::identifier::WithIdentifier, + low::{self as vir_low}, + middle as vir_mid, + middle::operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, +}; + +pub(in super::super) struct FunctionCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super) lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + pub(in super::super) function_name: &'l str, + pub(in super::super) context: CallContext, + pub(in super::super) ty: &'l vir_mid::Type, + pub(in super::super) generics: &'l G, + pub(in super::super) arguments: Vec, + pub(in super::super) position: vir_low::Position, +} + +impl<'l, 'p, 'v, 'tcx, G> FunctionCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + function_name: &'l str, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + arguments: Vec, + position: vir_low::Position, + ) -> SpannedEncodingResult { + Ok(Self { + lowerer, + function_name, + context, + ty, + generics, + arguments, + position, + }) + } + + pub(in super::super) fn build(self) -> SpannedEncodingResult { + let return_type = self.ty.to_snapshot(self.lowerer)?; + let call = vir_low::Expression::function_call( + format!("{}${}", self.function_name, self.ty.get_identifier()), + self.arguments, + return_type, + ); + Ok(call.set_default_position(self.position)) + } + + pub(in super::super) fn add_lifetime_arguments(&mut self) -> SpannedEncodingResult<()> { + self.arguments.extend( + self.lowerer + .create_lifetime_arguments(self.context, self.generics)?, + ); + Ok(()) + } + + pub(in super::super) fn add_const_arguments(&mut self) -> SpannedEncodingResult<()> { + // FIXME: remove code duplication with other add_const_arguments methods + for argument in self.generics.get_const_arguments() { + self.arguments + .push(argument.to_pure_snapshot(self.lowerer)?); + } + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/mod.rs index ef427252419..6bcb70532ad 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/mod.rs @@ -1,2 +1,4 @@ +pub(super) mod function_decl; +pub(super) mod function_use; pub(super) mod predicate_decl; pub(super) mod predicate_use; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_decl.rs index ab23d9db1f2..557225a27e1 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_decl.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_decl.rs @@ -1,17 +1,27 @@ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ + addresses::AddressesInterface, builtin_methods::CallContext, lifetimes::LifetimesInterface, lowerer::Lowerer, - predicates::owned::builders::{ - unique_ref::predicate_use::UniqueRefUseBuilder, FracRefUseBuilder, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{ + owned::builders::FracRefUseBuilder, PredicatesMemoryBlockInterface, + PredicatesOwnedInterface, }, references::ReferencesInterface, - snapshots::{IntoPureSnapshot, SnapshotValidityInterface, SnapshotValuesInterface}, + snapshots::{ + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, PredicateKind, + SelfFramingAssertionToSnapshot, SnapshotBytesInterface, SnapshotValidityInterface, + SnapshotValuesInterface, + }, type_layouts::TypeLayoutsInterface, + types::TypesInterface, }, }; +use prusti_common::config; use vir_crate::{ common::{expression::ExpressionIterator, identifier::WithIdentifier}, low::{self as vir_low}, @@ -26,6 +36,12 @@ pub(in super::super::super) struct PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { pub(in super::super) parameters: Vec, pub(in super::super) conjuncts: Option>, pub(in super::super) position: vir_low::Position, + /// `place` is used by subtypes that cannot be aliased. + pub(in super::super) place: vir_low::VariableDecl, + /// `root_address` is used by subtypes that cannot be aliased. + pub(in super::super) root_address: vir_low::VariableDecl, + /// `address` is used by subtypes that can be aliased. + pub(in super::super) address: vir_low::VariableDecl, } impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { @@ -37,6 +53,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { position: vir_low::Position, ) -> SpannedEncodingResult { Ok(Self { + place: vir_low::VariableDecl::new("place", lowerer.place_type()?), + root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), + address: vir_low::VariableDecl::new("address", lowerer.address_type()?), ty, predicate_name, type_decl, @@ -50,6 +69,7 @@ impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { pub(in super::super) fn build(self) -> vir_low::PredicateDecl { vir_low::PredicateDecl { name: format!("{}${}", self.predicate_name, self.ty.get_identifier()), + kind: vir_low::PredicateKind::Owned, parameters: self.parameters, body: self .conjuncts @@ -89,36 +109,137 @@ impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { self.add_conjunct(validity) } + pub(in super::super) fn add_unique_ref_pointer_predicate( + &mut self, + lifetime: &vir_mid::ty::LifetimeConst, + place: vir_low::VariableDecl, + root_address: vir_low::VariableDecl, + // _snapshot: &vir_low::VariableDecl, + ) -> SpannedEncodingResult { + let lifetime = lifetime.to_pure_snapshot(self.lowerer)?; + // let pointer_type = &self.lowerer.reference_address_type(self.ty)?; + let pointer_type = { + let reference_type = self.type_decl.clone().unwrap_reference(); + vir_mid::Type::pointer(reference_type.target_type) + }; + self.lowerer.ensure_type_definition(&pointer_type)?; + let current_snapshot = true.into(); // FIXME + let final_snapshot = true.into(); // FIXME + let expression = self.lowerer.unique_ref_predicate( + CallContext::BuiltinMethod, + &pointer_type, + &pointer_type, + place.clone().into(), + root_address.clone().into(), + current_snapshot, + final_snapshot, + lifetime.into(), + None, // FIXME + )?; + self.add_conjunct(expression)?; + Ok(pointer_type) + } + + /// `is_unique_ref` – whether the predicate is used in `UniqueRef` or `Owned`. pub(in super::super) fn add_unique_ref_target_predicate( &mut self, target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, - place: &vir_low::VariableDecl, - snapshot: &vir_low::VariableDecl, + place: vir_low::Expression, + root_address: vir_low::VariableDecl, + // snapshot: &vir_low::VariableDecl, + is_unique_ref: bool, // FIXME: Refactor to not use this flag. ) -> SpannedEncodingResult<()> { use vir_low::macros::*; let deref_place = self .lowerer .reference_deref_place(place.clone().into(), self.position)?; - let target_address = - self.lowerer - .reference_address(self.ty, snapshot.clone().into(), self.position)?; - let current_snapshot = self.lowerer.reference_target_current_snapshot( - self.ty, - snapshot.clone().into(), - self.position, - )?; - let final_snapshot = self.lowerer.reference_target_final_snapshot( - self.ty, - snapshot.clone().into(), - self.position, - )?; let lifetime_alive = self .lowerer .encode_lifetime_const_into_pure_is_alive_variable(lifetime)?; let lifetime = lifetime.to_pure_snapshot(self.lowerer)?; - let mut builder = UniqueRefUseBuilder::new( - self.lowerer, + let (target_address, target_len) = if config::use_snapshot_parameters_in_predicates() { + unimplemented!("TODO: Delete this branch"); + // // FIXME: target_len should be the length of the target slice. + // ( + // self.lowerer + // .reference_address(self.ty, snapshot.clone().into(), self.position)?, + // None, + // ) + } else { + let pointer_type = &self.lowerer.reference_address_type(self.ty)?; + let pointer_snapshot = if is_unique_ref { + self.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + pointer_type, + pointer_type, + place.clone().into(), + root_address.clone().into(), + lifetime.clone().into(), + None, + false, + )? + } else { + self.lowerer + .encode_snapshot_to_bytes_function(pointer_type)?; + let size_of = self + .lowerer + .encode_type_size_expression2(self.ty, self.type_decl)?; + let compute_address = ty!(Address); + let compute_address_expression = expr! { + ComputeAddress::compute_address( + [place.clone().into()], + [root_address.clone().into()] + ) + }; + let bytes = self + .lowerer + .encode_memory_block_bytes_expression(compute_address_expression, size_of)?; + let from_bytes = pointer_type.to_snapshot(self.lowerer)?; + expr! { + Snap::from_bytes([bytes]) + } + }; + let target_address = self.lowerer.pointer_address( + pointer_type, + pointer_snapshot.clone(), + self.position, + )?; + // .obtain_constant_value(address_type, pointer_snapshot, self.position)? + + let target_len = if pointer_type.is_pointer_to_slice() { + Some(self.lowerer.pointer_slice_len( + pointer_type, + pointer_snapshot, + self.position, + )?) + } else { + None + }; + (target_address, target_len) + }; + // let current_snapshot = self.lowerer.reference_target_current_snapshot( + // self.ty, + // snapshot.clone().into(), + // self.position, + // )?; + // let final_snapshot = self.lowerer.reference_target_final_snapshot( + // self.ty, + // snapshot.clone().into(), + // self.position, + // )?; + let current_snapshot = true.into(); // FIXME + let final_snapshot = self.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + target_type, + target_type, + deref_place.clone(), + target_address.clone(), + lifetime.clone().into(), + target_len.clone(), + true, + )?; + let expression = self.lowerer.unique_ref_predicate( CallContext::BuiltinMethod, target_type, target_type, @@ -127,34 +248,78 @@ impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { current_snapshot, final_snapshot, lifetime.into(), + target_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let expression = builder.build(); self.add_conjunct(expr! { [lifetime_alive.into()] ==> [expression] }) } + // FIXME: Code duplication with `add_unique_ref_target_predicate`. pub(in super::super) fn add_frac_ref_target_predicate( &mut self, target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, - place: &vir_low::VariableDecl, - snapshot: &vir_low::VariableDecl, + place: vir_low::Expression, + root_address: vir_low::VariableDecl, + // snapshot: &vir_low::VariableDecl, ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let lifetime_alive = self + .lowerer + .encode_lifetime_const_into_pure_is_alive_variable(lifetime)?; let deref_place = self .lowerer .reference_deref_place(place.clone().into(), self.position)?; + let pointer_type = &self.lowerer.reference_address_type(self.ty)?; + let pointer_snapshot = { + self.lowerer + .encode_snapshot_to_bytes_function(pointer_type)?; + let size_of = self + .lowerer + .encode_type_size_expression2(self.ty, self.type_decl)?; + let compute_address = ty!(Address); + let compute_address_expression = expr! { + ComputeAddress::compute_address( + [place.clone().into()], + [root_address.clone().into()] + ) + }; + let bytes = self + .lowerer + .encode_memory_block_bytes_expression(compute_address_expression, size_of)?; + let from_bytes = pointer_type.to_snapshot(self.lowerer)?; + expr! { + Snap::from_bytes([bytes]) + } + }; let target_address = self.lowerer - .reference_address(self.ty, snapshot.clone().into(), self.position)?; - let current_snapshot = self.lowerer.reference_target_current_snapshot( - self.ty, - snapshot.clone().into(), - self.position, - )?; + .pointer_address(pointer_type, pointer_snapshot.clone(), self.position)?; + // let target_address = + // self.lowerer + // .reference_address(self.ty, snapshot.clone().into(), self.position)?; + // let current_snapshot = self.lowerer.reference_target_current_snapshot( + // self.ty, + // snapshot.clone().into(), + // self.position, + // )?; + let current_snapshot = true.into(); // FIXME let lifetime = lifetime.to_pure_snapshot(self.lowerer)?; - let mut builder = FracRefUseBuilder::new( - self.lowerer, + // let mut builder = FracRefUseBuilder::new( + // self.lowerer, + // CallContext::BuiltinMethod, + // target_type, + // target_type, + // deref_place, + // target_address, + // // current_snapshot, + // lifetime.into(), + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // let expression = builder.build(); + // self.add_conjunct(expression); + let target_len = None; // FIXME + let expression = self.lowerer.frac_ref_predicate( CallContext::BuiltinMethod, target_type, target_type, @@ -162,11 +327,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { target_address, current_snapshot, lifetime.into(), + target_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let expression = builder.build(); - self.add_conjunct(expression) + self.add_conjunct(expr! { [lifetime_alive.into()] ==> [expression] }) } pub(in super::super) fn array_length_int( @@ -194,6 +357,28 @@ impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { }; self.add_conjunct(expression) } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + predicate_kind: PredicateKind, + ) -> SpannedEncodingResult> { + if let Some(invariant) = &decl.structural_invariant { + let mut encoder = SelfFramingAssertionToSnapshot::for_predicate_body( + self.place.clone(), + self.root_address.clone(), + predicate_kind, + ); + for assertion in invariant { + let low_assertion = + encoder.expression_to_snapshot(self.lowerer, assertion, true)?; + self.add_conjunct(low_assertion)?; + } + Ok(encoder.into_created_predicate_types()) + } else { + Ok(Vec::new()) + } + } } pub(in super::super::super) trait PredicateDeclBuilderMethods<'l, 'p, 'v, 'tcx> diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_use.rs index 6eaa9dc1e77..82455754b07 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_use.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/common/predicate_use.rs @@ -53,7 +53,7 @@ where pub(in super::super) fn build(self) -> vir_low::Expression { vir_low::Expression::predicate_access_predicate( - format!("{}${}", self.predicate_name, self.ty.get_identifier()), + self.predicate_name(), self.arguments, self.permission_amount .unwrap_or_else(vir_low::Expression::full_permission), @@ -61,6 +61,10 @@ where ) } + pub(in super::super) fn predicate_name(&self) -> String { + format!("{}${}", self.predicate_name, self.ty.get_identifier()) + } + pub(in super::super) fn add_lifetime_arguments(&mut self) -> SpannedEncodingResult<()> { self.arguments.extend( self.lowerer diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/function_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/function_decl.rs new file mode 100644 index 00000000000..0e0e244457e --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/function_decl.rs @@ -0,0 +1,305 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + footprint::FootprintInterface, + lifetimes::LifetimesInterface, + lowerer::Lowerer, + places::PlacesInterface, + predicates::{ + owned::builders::common::function_decl::FunctionDeclBuilder, PredicatesOwnedInterface, + }, + snapshots::{ + AssertionToSnapshotConstructor, IntoPureSnapshot, PredicateKind, + SnapshotValuesInterface, + }, + type_layouts::TypeLayoutsInterface, + }, +}; + +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator}, + identifier::WithIdentifier, + }, + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +use super::predicate_use::FracRefUseBuilder; + +pub(in super::super::super) struct FracRefSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + inner: FunctionDeclBuilder<'l, 'p, 'v, 'tcx>, + // place: vir_low::VariableDecl, + root_address: vir_low::VariableDecl, + reference_lifetime: vir_low::VariableDecl, + slice_len: Option, +} + +impl<'l, 'p, 'v, 'tcx> FracRefSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + ) -> SpannedEncodingResult { + let slice_len = if ty.is_slice() { + Some(vir_mid::VariableDecl::new( + "slice_len", + lowerer.size_type_mid()?, + )) + } else { + None + }; + Ok(Self { + // place: vir_low::VariableDecl::new("place", lowerer.place_type()?), + root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), + reference_lifetime: vir_low::VariableDecl::new( + "reference_lifetime", + lowerer.lifetime_type()?, + ), + slice_len, + inner: FunctionDeclBuilder::new( + lowerer, + "snap_current_frac_ref", + ty, + type_decl, + Default::default(), + )?, + }) + } + + pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.inner.place.clone()); + self.inner.parameters.push(self.root_address.clone()); + self.inner.parameters.push(self.reference_lifetime.clone()); + self.inner.create_lifetime_parameters()?; + if let Some(slice_len) = self.slice_len()? { + self.inner.parameters.push(slice_len); + } + self.inner.create_const_parameters()?; + Ok(()) + } + + pub(in super::super::super) fn add_frac_ref_precondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + self.inner.add_precondition(predicate) + } + + // FIXME: Code duplication. + fn slice_len(&mut self) -> SpannedEncodingResult> { + self.slice_len + .as_ref() + .map(|slice_len_mid| slice_len_mid.to_pure_snapshot(self.inner.lowerer)) + .transpose() + } + + // FIXME: Code duplication. + fn slice_len_expression(&mut self) -> SpannedEncodingResult> { + Ok(self.slice_len()?.map(|slice_len| slice_len.into())) + } + + fn precondition_predicate(&mut self) -> SpannedEncodingResult { + self.frac_ref_predicate( + self.inner.ty, + self.inner.type_decl, + self.inner.place.clone().into(), + self.root_address.clone().into(), + self.reference_lifetime.clone().into(), + ) + } + + fn frac_ref_predicate( + &mut self, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + reference_lifetime: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + // let slice_len = if let Some(slice_len_mid) = &self.slice_len { + // let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; + // Some(slice_len.into()) + // } else { + // None + // }; + let mut builder = FracRefUseBuilder::new( + self.inner.lowerer, + CallContext::BuiltinMethod, + ty, + generics, + place, + root_address, + reference_lifetime, + // slice_len, + )?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + self.inner.build() + } + + // // FIXME: Code duplication. + // fn create_field_snap_call( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // let field_place = self.inner.lowerer.encode_field_place( + // self.inner.ty, + // field, + // self.inner.place.clone().into(), + // self.inner.position, + // )?; + // let target_slice_len = self.slice_len_expression()?; + // self.inner.lowerer.frac_ref_snap( + // CallContext::BuiltinMethod, + // &field.ty, + // &field.ty, + // field_place, + // self.root_address.clone().into(), + // self.reference_lifetime.clone().into(), + // target_slice_len, + // ) + // } + + // FIXME: Code duplication. + pub(in super::super::super) fn create_field_snapshot_equality( + &mut self, + field: &vir_mid::FieldDecl, + ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let result = self.inner.result()?; + // let field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( + // self.inner.ty, + // field, + // result.into(), + // self.inner.position, + // )?; + // let snap_call = self.create_field_snap_call(&field)?; + // Ok(expr! { + // [field_snapshot] == [snap_call] + // }) + // self.inner.create_field_snap_call(field, |builder, field, field_place| { + // let target_slice_len = self.slice_len_expression()?; + // self.inner.lowerer.frac_ref_snap( + // CallContext::BuiltinMethod, + // &field.ty, + // &field.ty, + // field_place, + // self.root_address.clone().into(), + // self.reference_lifetime.clone().into(), + // target_slice_len, + // ) + // }) + let frac_ref_call = self.field_frac_ref_snap()?; + self.inner + .create_field_snapshot_equality(field, frac_ref_call) + } + + fn field_frac_ref_snap( + &mut self, + ) -> SpannedEncodingResult< + impl Fn( + &mut FunctionDeclBuilder, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + > { + let target_slice_len = self.slice_len_expression()?; + let root_address: vir_low::Expression = self.root_address.clone().into(); + let root_address = std::rc::Rc::new(root_address); + let lifetime: vir_low::Expression = self.reference_lifetime.clone().into(); + let lifetime = std::rc::Rc::new(lifetime); + Ok( + move |builder: &mut FunctionDeclBuilder, field: &vir_mid::FieldDecl, field_place| { + builder.lowerer.frac_ref_snap( + CallContext::BuiltinMethod, + &field.ty, + &field.ty, + field_place, + (*root_address).clone(), + (*lifetime).clone(), + target_slice_len.clone(), + ) + }, + ) + } + + // FIXME: Code duplication. + pub(in super::super::super) fn add_unfolding_postcondition( + &mut self, + body: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + let unfolding = predicate.into_unfolding(body); + self.inner.add_postcondition(unfolding) + } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult<()> { + let precondition_predicate = self.precondition_predicate()?; + let predicate_kind = PredicateKind::FracRef { + lifetime: self.reference_lifetime.clone().into(), + }; + let snap_call = self.field_frac_ref_snap()?; + self.inner.add_structural_invariant( + decl, + Some(precondition_predicate), + predicate_kind, + &snap_call, + ) + } + + // // FIXME: Code duplication. + // pub(in super::super::super) fn add_structural_invariant2( + // &mut self, + // decl: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult<()> { + // if let Some(invariant) = decl.structural_invariant.clone() { + // let mut regular_field_arguments = Vec::new(); + // for field in &decl.fields { + // let frac_ref_call = self.field_frac_ref_snap()?; + // let snap_call = self.inner.create_field_snap_call(field, frac_ref_call)?; + // regular_field_arguments.push(snap_call); + // // regular_field_arguments.push(self.create_field_snap_call(field)?); + // } + // let result = self.inner.result()?; + // let deref_fields = self + // .inner + // .lowerer + // .structural_invariant_to_deref_fields(&invariant)?; + // let mut constructor_encoder = AssertionToSnapshotConstructor::for_function_body( + // PredicateKind::FracRef { + // lifetime: self.reference_lifetime.clone().into(), + // }, + // self.inner.ty, + // regular_field_arguments, + // decl.fields.clone(), + // deref_fields, + // self.inner.position, + // ); + // let invariant_expression = invariant.into_iter().conjoin(); + // let permission_expression = invariant_expression.convert_into_permission_expression(); + // let constructor = constructor_encoder + // .expression_to_snapshot_constructor(self.inner.lowerer, &permission_expression)?; + // self.add_unfolding_postcondition(vir_low::Expression::equals( + // result.into(), + // constructor, + // ))?; + // } + // Ok(()) + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/function_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/function_use.rs new file mode 100644 index 00000000000..9d511e04e05 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/function_use.rs @@ -0,0 +1,65 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, lowerer::Lowerer, + predicates::owned::builders::common::function_use::FunctionCallBuilder, + }, +}; +use vir_crate::{ + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +pub(in super::super::super) struct FracRefSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + inner: FunctionCallBuilder<'l, 'p, 'v, 'tcx, G>, +} + +impl<'l, 'p, 'v, 'tcx, G> FracRefSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + place: vir_low::Expression, + root_address: vir_low::Expression, + reference_lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult { + let mut arguments = vec![place, root_address, reference_lifetime]; + if let Some(len) = target_slice_len { + arguments.push(len); + } + let name = "snap_current_frac_ref"; + let inner = FunctionCallBuilder::new( + lowerer, + name, + context, + ty, + generics, + arguments, + Default::default(), + )?; + Ok(Self { inner }) + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + self.inner.build() + } + + pub(in super::super::super) fn add_lifetime_arguments(&mut self) -> SpannedEncodingResult<()> { + self.inner.add_lifetime_arguments() + } + + pub(in super::super::super) fn add_const_arguments(&mut self) -> SpannedEncodingResult<()> { + self.inner.add_const_arguments() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/mod.rs index ef427252419..7714b35a296 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/mod.rs @@ -1,2 +1,4 @@ pub(super) mod predicate_decl; pub(super) mod predicate_use; +pub(super) mod function_decl; +pub(super) mod function_use; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_decl.rs index 1a6745c0f26..45ff43669ff 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_decl.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_decl.rs @@ -6,11 +6,15 @@ use crate::encoder::{ lifetimes::LifetimesInterface, lowerer::Lowerer, places::PlacesInterface, - predicates::owned::builders::{ - common::predicate_decl::PredicateDeclBuilder, PredicateDeclBuilderMethods, + predicates::{ + owned::builders::{ + common::predicate_decl::PredicateDeclBuilder, PredicateDeclBuilderMethods, + }, + PredicatesOwnedInterface, }, snapshots::{ - IntoPureSnapshot, IntoSnapshot, SnapshotValidityInterface, SnapshotValuesInterface, + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, PredicateKind, + SelfFramingAssertionToSnapshot, SnapshotValidityInterface, SnapshotValuesInterface, }, type_layouts::TypeLayoutsInterface, }, @@ -25,8 +29,8 @@ use super::predicate_use::FracRefUseBuilder; pub(in super::super::super) struct FracRefBuilder<'l, 'p, 'v, 'tcx> { inner: PredicateDeclBuilder<'l, 'p, 'v, 'tcx>, - place: vir_low::VariableDecl, - root_address: vir_low::VariableDecl, + // place: vir_low::VariableDecl, + // root_address: vir_low::VariableDecl, current_snapshot: vir_low::VariableDecl, reference_lifetime: vir_low::VariableDecl, slice_len: Option, @@ -55,8 +59,8 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { None }; Ok(Self { - place: vir_low::VariableDecl::new("place", lowerer.place_type()?), - root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), + // place: vir_low::VariableDecl::new("place", lowerer.place_type()?), + // root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), current_snapshot: vir_low::VariableDecl::new( "current_snapshot", ty.to_snapshot(lowerer)?, @@ -81,9 +85,9 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { } pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { - self.inner.parameters.push(self.place.clone()); - self.inner.parameters.push(self.root_address.clone()); - self.inner.parameters.push(self.current_snapshot.clone()); + self.inner.parameters.push(self.inner.place.clone()); + self.inner.parameters.push(self.inner.root_address.clone()); + // self.inner.parameters.push(self.current_snapshot.clone()); self.inner.parameters.push(self.reference_lifetime.clone()); self.inner.create_lifetime_parameters()?; if let Some(slice_len_mid) = &self.slice_len { @@ -105,7 +109,7 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { let field_place = self.inner.lowerer.encode_field_place( self.inner.ty, field, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; let current_field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( @@ -114,19 +118,30 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { self.current_snapshot.clone().into(), Default::default(), )?; - let mut builder = FracRefUseBuilder::new( - self.inner.lowerer, + // let mut builder = FracRefUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // &field.ty, + // &field.ty, + // field_place, + // self.inner.root_address.clone().into(), + // // current_field_snapshot, + // self.reference_lifetime.clone().into(), + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // let expression = builder.build(); + let TODO_target_slice_len = None; + let expression = self.inner.lowerer.frac_ref_predicate( CallContext::BuiltinMethod, &field.ty, &field.ty, field_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_field_snapshot, self.reference_lifetime.clone().into(), + TODO_target_slice_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let expression = builder.build(); self.inner.add_conjunct(expression) } @@ -138,7 +153,7 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { let discriminant_place = self.inner.lowerer.encode_field_place( self.inner.ty, &discriminant_field, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; let current_discriminant_call = self.inner.lowerer.obtain_enum_discriminant( @@ -151,17 +166,27 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { current_discriminant_call, self.inner.position, )?; - let builder = FracRefUseBuilder::new( - self.inner.lowerer, + // let builder = FracRefUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // &decl.discriminant_type, + // &decl.discriminant_type, + // discriminant_place, + // self.inner.root_address.clone().into(), + // // current_discriminant_snapshot, + // self.reference_lifetime.clone().into(), + // )?; + // let expression = builder.build(); + let expression = self.inner.lowerer.frac_ref_predicate( CallContext::BuiltinMethod, &decl.discriminant_type, &decl.discriminant_type, discriminant_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_discriminant_snapshot, self.reference_lifetime.clone().into(), + None, )?; - let expression = builder.build(); self.inner.add_conjunct(expression) } @@ -170,11 +195,14 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, ) -> SpannedEncodingResult<()> { + let place = self.inner.place.clone(); + let root_address = self.inner.root_address.clone(); self.inner.add_frac_ref_target_predicate( target_type, lifetime, - &self.place, - &self.current_snapshot, + place.into(), + root_address, + // &self.current_snapshot, ) } @@ -216,7 +244,7 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { let array_length_int = self.inner.array_length_int(array_length_mid)?; let element_place = self.inner.lowerer.encode_index_place( self.inner.ty, - self.place.clone().into(), + self.inner.place.clone().into(), index.clone().into(), self.inner.position, )?; @@ -225,19 +253,30 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { index_int.clone(), self.inner.position, )?; - let mut builder = FracRefUseBuilder::new( - self.inner.lowerer, + // let mut builder = FracRefUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // element_type, + // element_type, + // element_place, + // self.inner.root_address.clone().into(), + // // current_element_snapshot, + // self.reference_lifetime.clone().into(), + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // let element_predicate_acc = builder.build(); + let TODO_target_slice_len = None; + let element_predicate_acc = self.inner.lowerer.frac_ref_predicate( CallContext::BuiltinMethod, element_type, element_type, element_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_element_snapshot, self.reference_lifetime.clone().into(), + TODO_target_slice_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let element_predicate_acc = builder.build(); let elements = vir_low::Expression::forall( vec![index], vec![vir_low::Trigger::new(vec![element_predicate_acc.clone()])], @@ -268,7 +307,7 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { let variant_place = self.inner.lowerer.encode_enum_variant_place( self.inner.ty, &variant_index, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; let current_variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( @@ -277,19 +316,30 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { self.current_snapshot.clone().into(), self.inner.position, )?; - let mut builder = FracRefUseBuilder::new( - self.inner.lowerer, + // let mut builder = FracRefUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // variant_type, + // variant_type, + // variant_place, + // self.inner.root_address.clone().into(), + // // current_variant_snapshot, + // self.reference_lifetime.clone().into(), + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // let predicate = builder.build(); + let TODO_target_slice_len = None; + let predicate = self.inner.lowerer.frac_ref_predicate( CallContext::BuiltinMethod, variant_type, variant_type, variant_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_variant_snapshot, self.reference_lifetime.clone().into(), + TODO_target_slice_len, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let predicate = builder.build(); Ok((guard, predicate)) } @@ -300,4 +350,39 @@ impl<'l, 'p, 'v, 'tcx> FracRefBuilder<'l, 'p, 'v, 'tcx> { self.inner .add_conjunct(variant_predicates.into_iter().create_match()) } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult> { + self.inner.add_structural_invariant( + decl, + PredicateKind::FracRef { + lifetime: self.reference_lifetime.clone().into(), + }, + ) + } + + // pub(in super::super::super) fn add_structural_invariant( + // &mut self, + // decl: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult> { + // if let Some(invariant) = &decl.structural_invariant { + // let mut encoder = SelfFramingAssertionToSnapshot::for_predicate_body( + // self.inner.place.clone(), + // self.inner.root_address.clone(), + // PredicateKind::FracRef { + // lifetime: self.reference_lifetime.clone().into(), + // }, + // ); + // for assertion in invariant { + // let low_assertion = + // encoder.expression_to_snapshot(self.inner.lowerer, assertion, true)?; + // self.inner.add_conjunct(low_assertion)?; + // } + // Ok(encoder.into_created_predicate_types()) + // } else { + // Ok(Vec::new()) + // } + // } } diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_use.rs index 04f8bb3dd27..1fbadd257f8 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_use.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/frac_ref/predicate_use.rs @@ -1,12 +1,17 @@ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ - builtin_methods::CallContext, lowerer::Lowerer, - predicates::owned::builders::common::predicate_use::PredicateUseBuilder, - snapshots::SnapshotValuesInterface, type_layouts::TypeLayoutsInterface, + builtin_methods::CallContext, + lowerer::Lowerer, + predicates::{ + owned::builders::common::predicate_use::PredicateUseBuilder, PredicatesOwnedInterface, + }, + snapshots::SnapshotValuesInterface, + type_layouts::TypeLayoutsInterface, }, }; use vir_crate::{ + common::expression::BinaryOperationHelpers, low::{self as vir_low}, middle::{ self as vir_mid, @@ -19,7 +24,7 @@ where G: WithLifetimes + WithConstArguments, { inner: PredicateUseBuilder<'l, 'p, 'v, 'tcx, G>, - current_snapshot: vir_low::Expression, + current_snapshot: Option, } impl<'l, 'p, 'v, 'tcx, G> FracRefUseBuilder<'l, 'p, 'v, 'tcx, G> @@ -34,7 +39,7 @@ where generics: &'l G, place: vir_low::Expression, root_address: vir_low::Expression, - current_snapshot: vir_low::Expression, + // current_snapshot: vir_low::Expression, lifetime: vir_low::Expression, ) -> SpannedEncodingResult { let inner = PredicateUseBuilder::new( @@ -43,17 +48,49 @@ where context, ty, generics, - vec![place, root_address, current_snapshot.clone(), lifetime], + vec![ + place, + root_address, // current_snapshot.clone(), + lifetime, + ], Default::default(), )?; Ok(Self { inner, - current_snapshot, + current_snapshot: None, }) } - pub(in super::super::super::super::super) fn build(self) -> vir_low::Expression { - self.inner.build() + pub(in super::super::super::super::super) fn build( + mut self, + ) -> SpannedEncodingResult { + let expression = if let Some(snapshot) = self.current_snapshot.take() { + let TODO_target_slice_len = None; + let snap_call = self.inner.lowerer.frac_ref_snap( + self.inner.context, + self.inner.ty, + self.inner.generics, + self.inner.arguments[0].clone(), + self.inner.arguments[1].clone(), + self.inner.arguments[2].clone(), + TODO_target_slice_len, + )?; + vir_low::Expression::and( + self.inner.build(), + vir_low::Expression::equals(snapshot, snap_call), + ) + } else { + self.inner.build() + }; + Ok(expression) + } + + pub(in super::super::super::super::super) fn add_snapshot_argument( + &mut self, + snapshot: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + self.current_snapshot = Some(snapshot); + Ok(()) } pub(in super::super::super::super::super) fn add_lifetime_arguments( @@ -66,17 +103,18 @@ where &mut self, ) -> SpannedEncodingResult<()> { if self.inner.ty.is_slice() { - let snapshot_length = self - .inner - .lowerer - .obtain_array_len_snapshot(self.current_snapshot.clone(), self.inner.position)?; - let size_type = self.inner.lowerer.size_type_mid()?; - let argument = self.inner.lowerer.construct_constant_snapshot( - &size_type, - snapshot_length, - self.inner.position, - )?; - self.inner.arguments.push(argument); + unimplemented!(); + // let snapshot_length = self + // .inner + // .lowerer + // .obtain_array_len_snapshot(self.current_snapshot.clone(), self.inner.position)?; + // let size_type = self.inner.lowerer.size_type_mid()?; + // let argument = self.inner.lowerer.construct_constant_snapshot( + // &size_type, + // snapshot_length, + // self.inner.position, + // )?; + // self.inner.arguments.push(argument); } self.inner.add_const_arguments() } diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/mod.rs index f271de25200..1a5d3b7f667 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/mod.rs @@ -1,15 +1,34 @@ mod common; mod frac_ref; mod owned_non_aliased; +mod owned_aliased; mod unique_ref; pub(super) use self::{ - common::predicate_decl::PredicateDeclBuilderMethods, frac_ref::predicate_decl::FracRefBuilder, - owned_non_aliased::predicate_decl::OwnedNonAliasedBuilder, - unique_ref::predicate_decl::UniqueRefBuilder, + common::predicate_decl::PredicateDeclBuilderMethods, + frac_ref::{ + function_decl::FracRefSnapFunctionBuilder, function_use::FracRefSnapCallBuilder, + predicate_decl::FracRefBuilder, + }, + owned_aliased::{ + function_decl::OwnedAliasedSnapFunctionBuilder, predicate_decl::OwnedAliasedBuilder, + }, + owned_non_aliased::{ + function_decl::OwnedNonAliasedSnapFunctionBuilder, predicate_decl::OwnedNonAliasedBuilder, + }, + unique_ref::{ + function_decl::UniqueRefSnapFunctionBuilder, function_use::UniqueRefSnapCallBuilder, + predicate_decl::UniqueRefBuilder, + }, }; pub(in super::super::super) use self::{ frac_ref::predicate_use::FracRefUseBuilder, - owned_non_aliased::predicate_use::OwnedNonAliasedUseBuilder, + owned_aliased::{ + function_use::OwnedAliasedSnapCallBuilder, + predicate_range_use::OwnedAliasedRangeUseBuilder, predicate_use::OwnedAliasedUseBuilder, + }, + owned_non_aliased::{ + function_use::OwnedNonAliasedSnapCallBuilder, predicate_use::OwnedNonAliasedUseBuilder, + }, unique_ref::predicate_use::UniqueRefUseBuilder, }; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/function_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/function_decl.rs new file mode 100644 index 00000000000..dc69d316548 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/function_decl.rs @@ -0,0 +1,527 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + footprint::FootprintInterface, + lifetimes::LifetimesInterface, + lowerer::Lowerer, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{ + owned::builders::common::function_decl::FunctionDeclBuilder, OwnedNonAliasedUseBuilder, + PredicatesMemoryBlockInterface, PredicatesOwnedInterface, + }, + references::ReferencesInterface, + snapshots::{ + AssertionToSnapshotConstructor, IntoPureSnapshot, IntoSnapshotLowerer, PredicateKind, + SnapshotBytesInterface, SnapshotValidityInterface, SnapshotValuesInterface, + }, + type_layouts::TypeLayoutsInterface, + }, +}; + +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers}, + position::Positioned, + }, + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes, ty::Typed}, + }, +}; + +pub(in super::super::super) struct OwnedAliasedSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + inner: FunctionDeclBuilder<'l, 'p, 'v, 'tcx>, + address: vir_low::VariableDecl, + slice_len: Option, +} + +impl<'l, 'p, 'v, 'tcx> OwnedAliasedSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + ) -> SpannedEncodingResult { + let slice_len = if ty.is_slice() { + Some(vir_mid::VariableDecl::new( + "slice_len", + lowerer.size_type_mid()?, + )) + } else { + None + }; + Ok(Self { + address: vir_low::VariableDecl::new("address", lowerer.address_type()?), + slice_len, + inner: FunctionDeclBuilder::new( + lowerer, + "snap_owned_aliased", + ty, + type_decl, + Default::default(), + )?, + }) + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + self.inner.build() + } + + pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.address.clone()); + self.inner.create_lifetime_parameters()?; + if let Some(slice_len_mid) = &self.slice_len { + let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; + self.inner.parameters.push(slice_len); + } + self.inner.create_const_parameters()?; + Ok(()) + } + + // FIXME: Code duplication. + pub(in super::super::super) fn get_slice_len( + &self, + ) -> SpannedEncodingResult { + Ok(self.slice_len.as_ref().unwrap().clone()) + } + + // fn owned_predicate( + // &mut self, + // ty: &vir_mid::Type, + // generics: &G, + // address: vir_low::Expression, + // ) -> SpannedEncodingResult + // where + // G: WithLifetimes + WithConstArguments, + // { + // let mut builder = OwnedNonAliasedUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // ty, + // generics, + // place, + // root_address, + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // builder.build() + // } + + // FIXME: Code duplication with add_quantified_permission. + pub(in super::super::super) fn add_quantifiers( + &mut self, + array_length_mid: &vir_mid::VariableDecl, + element_type: &vir_mid::Type, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let size_type_mid = self.inner.lowerer.size_type_mid()?; + var_decls! { + index_int: Int + }; + let index = self.inner.lowerer.construct_constant_snapshot( + &size_type_mid, + index_int.clone().into(), + self.inner.position, + )?; + let index_validity = self + .inner + .lowerer + .encode_snapshot_valid_call_for_type(index.clone(), &size_type_mid)?; + let array_length_int = self.inner.array_length_int(array_length_mid)?; + let element_address = self.inner.lowerer.encode_index_address( + self.inner.ty, + self.address.clone().into(), + index, + self.inner.position, + )?; + let element_predicate_acc = { + self.inner.lowerer.owned_aliased( + CallContext::BuiltinMethod, + element_type, + element_type, + element_address.clone(), + None, + )? + }; + let result = self.inner.result()?.into(); + let element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( + result, + index_int.clone().into(), + self.inner.position, + )?; + let element_snap_call = self.inner.lowerer.owned_aliased_snap( + CallContext::BuiltinMethod, + element_type, + element_type, + element_address, + self.inner.position, + )?; + let elements = vir_low::Expression::forall( + vec![index_int.clone()], + vec![vir_low::Trigger::new(vec![element_predicate_acc])], + expr! { + ([index_validity] && (index_int < [array_length_int])) ==> + ([element_snapshot] == [element_snap_call]) + }, + ); + self.add_unfolding_postcondition(elements) + } + + pub(in super::super::super) fn add_unfolding_postcondition( + &mut self, + body: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + let unfolding = predicate.into_unfolding(body); + self.inner.add_postcondition(unfolding) + } + + pub(in super::super::super) fn add_validity_postcondition( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_validity_postcondition() + } + + pub(in super::super::super) fn add_snapshot_len_equal_to_postcondition( + &mut self, + array_length_mid: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult<()> { + self.inner + .add_snapshot_len_equal_to_postcondition(array_length_mid) + } + + pub(in super::super::super) fn add_owned_precondition(&mut self) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + self.inner.add_precondition(predicate) + } + + fn precondition_predicate(&mut self) -> SpannedEncodingResult { + self.inner.lowerer.owned_aliased( + CallContext::BuiltinMethod, + self.inner.ty, + self.inner.type_decl, + self.address.clone().into(), + None, + ) + } + + // fn compute_address(&self) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let compute_address = ty!(Address); + // let expression = expr! { + // ComputeAddress::compute_address( + // [self.place.clone().into()], + // [self.root_address.clone().into()] + // ) + // }; + // Ok(expression) + // } + + fn size_of(&mut self) -> SpannedEncodingResult { + self.inner + .lowerer + .encode_type_size_expression2(self.inner.ty, self.inner.type_decl) + } + + // FIXME: Code duplication. + fn add_bytes_snapshot_equality_with( + &mut self, + snap_ty: &vir_mid::Type, + snapshot: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let size_of = self.size_of()?; + let bytes = self + .inner + .lowerer + .encode_memory_block_bytes_expression(self.address.clone().into(), size_of)?; + let to_bytes = ty! { Bytes }; + let expression = expr! { + [bytes] == (Snap::to_bytes([snapshot])) + }; + self.add_unfolding_postcondition(expression) + } + + pub(in super::super::super) fn add_bytes_snapshot_equality( + &mut self, + ) -> SpannedEncodingResult<()> { + let result = self.inner.result()?.into(); + self.add_bytes_snapshot_equality_with(self.inner.ty, result) + } + + pub(in super::super::super) fn add_bytes_address_snapshot_equality( + &mut self, + ) -> SpannedEncodingResult<()> { + let result = self.inner.result()?.into(); + let address_type = self.inner.lowerer.reference_address_type(self.inner.ty)?; + self.inner + .lowerer + .encode_snapshot_to_bytes_function(&address_type)?; + let target_address_snapshot = self.inner.lowerer.reference_address_snapshot( + self.inner.ty, + result, + self.inner.position, + )?; + self.add_bytes_snapshot_equality_with(&address_type, target_address_snapshot) + } + + // // fn create_field_snap_call( + // // &mut self, + // // field: &vir_mid::FieldDecl, + // // ) -> SpannedEncodingResult { + // // let field_place = self.inner.lowerer.encode_field_place( + // // self.inner.ty, + // // field, + // // self.place.clone().into(), + // // self.inner.position, + // // )?; + // // self.inner.lowerer.owned_non_aliased_snap( + // // CallContext::BuiltinMethod, + // // &field.ty, + // // &field.ty, + // // field_place, + // // self.root_address.clone().into(), + // // self.inner.position, + // // ) + // // } + + // // pub(in super::super::super) fn create_field_snapshot_equality( + // // &mut self, + // // field: &vir_mid::FieldDecl, + // // ) -> SpannedEncodingResult { + // // use vir_low::macros::*; + // // let result = self.inner.result()?; + // // let field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( + // // self.inner.ty, + // // field, + // // result.into(), + // // self.inner.position, + // // )?; + // // let snap_call = self.create_field_snap_call(&field)?; + // // Ok(expr! { + // // [field_snapshot] == [snap_call] + // // }) + // // } + + pub(in super::super::super) fn create_field_snapshot_equality( + &mut self, + field: &vir_mid::FieldDecl, + ) -> SpannedEncodingResult { + let owned_call = self.field_owned_snap()?; + self.inner.create_field_snapshot_equality(field, owned_call) + } + + fn field_owned_snap( + &mut self, + ) -> SpannedEncodingResult< + impl Fn( + &mut FunctionDeclBuilder, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + > { + let address: vir_low::Expression = self.address.clone().into(); + let address = std::rc::Rc::new(address); + Ok( + move |builder: &mut FunctionDeclBuilder, field: &vir_mid::FieldDecl, _| { + let field_address = builder.lowerer.encode_field_address( + builder.ty, + field, + (*address).clone(), + builder.position, + )?; + builder.lowerer.owned_aliased_snap( + CallContext::BuiltinMethod, + &field.ty, + &field.ty, + field_address, + builder.position, + ) + }, + ) + } + + pub(in super::super::super) fn create_discriminant_snapshot_equality( + &mut self, + decl: &vir_mid::type_decl::Enum, + ) -> SpannedEncodingResult { + use vir_low::macros::*; + let result = self.inner.result()?; + let discriminant_snapshot = self.inner.lowerer.obtain_enum_discriminant( + result.into(), + self.inner.ty, + self.inner.position, + )?; + let discriminant_field = decl.discriminant_field(); + let discriminant_address = self.inner.lowerer.encode_field_address( + self.inner.ty, + &discriminant_field, + self.address.clone().into(), + self.inner.position, + )?; + let snap_call = self.inner.lowerer.owned_aliased_snap( + CallContext::BuiltinMethod, + &decl.discriminant_type, + &decl.discriminant_type, + discriminant_address, + self.inner.position, + )?; + let snap_call_int = self.inner.lowerer.obtain_constant_value( + &decl.discriminant_type, + snap_call, + self.inner.position, + )?; + Ok(expr! { + [discriminant_snapshot] == [snap_call_int] + }) + } + + pub(in super::super::super) fn create_variant_snapshot_equality( + &mut self, + discriminant_value: vir_mid::DiscriminantValue, + variant: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { + use vir_low::macros::*; + let result = self.inner.result()?; + let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( + result.clone().into(), + self.inner.ty, + self.inner.position, + )?; + let guard = expr! { + [ discriminant_call ] == [ discriminant_value.into() ] + }; + let variant_index = variant.name.clone().into(); + let variant_address = self.inner.lowerer.encode_enum_variant_address( + self.inner.ty, + &variant_index, + self.address.clone().into(), + self.inner.position, + )?; + let variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( + self.inner.ty, + &variant_index, + result.into(), + self.inner.position, + )?; + let ty = self.inner.ty.clone(); + let variant_type = ty.variant(variant_index); + let snap_call = self.inner.lowerer.owned_aliased_snap( + CallContext::BuiltinMethod, + &variant_type, + // Enum variant and enum have the same set of lifetime parameters, + // so we use type_decl here. We cannot use `variant_type` because + // `ty` is normalized. + self.inner.type_decl, + variant_address, + self.inner.position, + )?; + let equality = expr! { + [variant_snapshot] == [snap_call] + }; + Ok((guard, equality)) + } + + pub(in super::super::super) fn add_reference_snapshot_equalities( + &mut self, + decl: &vir_mid::type_decl::Reference, + lifetime: &vir_mid::ty::LifetimeConst, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let result = self.inner.result()?; + let guard = self + .inner + .lowerer + .encode_lifetime_const_into_pure_is_alive_variable(lifetime)?; + let lifetime = lifetime.to_pure_snapshot(self.inner.lowerer)?; + let place = self + .inner + .lowerer + .encode_aliased_place_root(self.inner.position)?; + let deref_place = self + .inner + .lowerer + .reference_deref_place(place, self.inner.position)?; + let current_snapshot = self.inner.lowerer.reference_target_current_snapshot( + self.inner.ty, + result.clone().into(), + self.inner.position, + )?; + let final_snapshot = self.inner.lowerer.reference_target_final_snapshot( + self.inner.ty, + result.clone().into(), + self.inner.position, + )?; + let address = self.inner.lowerer.reference_address( + self.inner.ty, + result.clone().into(), + self.inner.position, + )?; + let slice_len = self.inner.lowerer.reference_slice_len( + self.inner.ty, + result.into(), + self.inner.position, + )?; + let equalities = if decl.uniqueness.is_unique() { + let current_snap_call = self.inner.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + &decl.target_type, + &decl.target_type, + deref_place.clone(), + address.clone(), + lifetime.clone().into(), + slice_len.clone(), + false, + )?; + let final_snap_call = self.inner.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + &decl.target_type, + &decl.target_type, + deref_place, + address, + lifetime.into(), + slice_len, + true, + )?; + expr! { + ([current_snapshot] == [current_snap_call]) && + ([final_snapshot] == [final_snap_call]) + } + } else { + let snap_call = self.inner.lowerer.frac_ref_snap( + CallContext::BuiltinMethod, + &decl.target_type, + &decl.target_type, + deref_place.clone(), + address.clone(), + lifetime.clone().into(), + slice_len.clone(), + )?; + expr! { + [current_snapshot] == [snap_call] + } + }; + let expression = expr! { + guard ==> [equalities] + }; + self.add_unfolding_postcondition(expression) + } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult<()> { + let precondition_predicate = self.precondition_predicate()?; + let predicate_kind = PredicateKind::Owned; + let snap_call = self.field_owned_snap()?; + self.inner.add_structural_invariant( + decl, + Some(precondition_predicate), + predicate_kind, + &snap_call, + ) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/function_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/function_use.rs new file mode 100644 index 00000000000..ba8337a387d --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/function_use.rs @@ -0,0 +1,73 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, lowerer::Lowerer, + predicates::owned::builders::common::function_use::FunctionCallBuilder, + }, +}; +use vir_crate::{ + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +pub(in super::super::super::super::super) struct OwnedAliasedSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + inner: FunctionCallBuilder<'l, 'p, 'v, 'tcx, G>, +} + +impl<'l, 'p, 'v, 'tcx, G> OwnedAliasedSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let arguments = vec![address]; + let inner = FunctionCallBuilder::new( + lowerer, + "snap_owned_aliased", + context, + ty, + generics, + arguments, + position, + )?; + Ok(Self { inner }) + } + + pub(in super::super::super::super::super) fn build( + self, + ) -> SpannedEncodingResult { + self.inner.build() + } + + // pub(in super::super::super::super::super) fn add_custom_argument( + // &mut self, + // argument: vir_low::Expression, + // ) -> SpannedEncodingResult<()> { + // self.inner.arguments.push(argument); + // Ok(()) + // } + + pub(in super::super::super::super::super) fn add_lifetime_arguments( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_lifetime_arguments() + } + + pub(in super::super::super::super::super) fn add_const_arguments( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_const_arguments() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/mod.rs new file mode 100644 index 00000000000..2dcae474d30 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/mod.rs @@ -0,0 +1,5 @@ +pub(super) mod function_decl; +pub(super) mod function_use; +pub(super) mod predicate_decl; +pub(super) mod predicate_use; +pub(super) mod predicate_range_use; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_decl.rs new file mode 100644 index 00000000000..68b512aff6d --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_decl.rs @@ -0,0 +1,313 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lowerer::Lowerer, + places::PlacesInterface, + predicates::{ + owned::builders::{ + common::predicate_decl::PredicateDeclBuilder, PredicateDeclBuilderMethods, + }, + PredicatesMemoryBlockInterface, PredicatesOwnedInterface, + }, + references::ReferencesInterface, + snapshots::{ + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, PredicateKind, + SelfFramingAssertionToSnapshot, SnapshotBytesInterface, SnapshotValidityInterface, + SnapshotValuesInterface, + }, + type_layouts::TypeLayoutsInterface, + }, +}; +use prusti_common::config; +use vir_crate::{ + common::{ + expression::{GuardedExpressionIterator, QuantifierHelpers}, + position::Positioned, + }, + low::{self as vir_low}, + middle::{self as vir_mid}, +}; + +pub(in super::super::super) struct OwnedAliasedBuilder<'l, 'p, 'v, 'tcx> { + inner: PredicateDeclBuilder<'l, 'p, 'v, 'tcx>, + slice_len: Option, +} + +impl<'l, 'p, 'v, 'tcx> PredicateDeclBuilderMethods<'l, 'p, 'v, 'tcx> + for OwnedAliasedBuilder<'l, 'p, 'v, 'tcx> +{ + fn inner(&mut self) -> &mut PredicateDeclBuilder<'l, 'p, 'v, 'tcx> { + &mut self.inner + } +} + +impl<'l, 'p, 'v, 'tcx> OwnedAliasedBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + ) -> SpannedEncodingResult { + let slice_len = if ty.is_slice() { + Some(vir_mid::VariableDecl::new( + "slice_len", + lowerer.size_type_mid()?, + )) + } else { + None + }; + let position = type_decl.position(); + Ok(Self { + slice_len, + inner: PredicateDeclBuilder::new(lowerer, "OwnedAliased", ty, type_decl, position)?, + }) + } + + pub(in super::super::super) fn build(self) -> vir_low::PredicateDecl { + self.inner.build() + } + + pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.inner.address.clone()); + self.inner.create_lifetime_parameters()?; + if let Some(slice_len_mid) = &self.slice_len { + let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; + self.inner.parameters.push(slice_len); + } + self.inner.create_const_parameters()?; + Ok(()) + } + + fn size_of(&mut self) -> SpannedEncodingResult { + self.inner + .lowerer + .encode_type_size_expression2(self.inner.ty, self.inner.type_decl) + } + + fn padding_size(&mut self) -> SpannedEncodingResult { + self.inner + .lowerer + .encode_type_padding_size_expression(self.inner.ty) + } + + pub(in super::super::super) fn add_base_memory_block(&mut self) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let size_of = self.size_of()?; + let address = &self.inner.address; + let expression = expr! { + acc(MemoryBlock(address, [size_of])) + }; + self.inner.add_conjunct(expression) + } + + pub(in super::super::super) fn add_padding_memory_block( + &mut self, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let padding_size = self.padding_size()?; + let address = &self.inner.address; + let expression = expr! { + acc(MemoryBlock(address, [padding_size])) + }; + self.inner.add_conjunct(expression) + } + + pub(in super::super::super) fn add_field_predicate( + &mut self, + field: &vir_mid::FieldDecl, + ) -> SpannedEncodingResult<()> { + let field_address = self.inner.lowerer.encode_field_address( + self.inner.ty, + field, + self.inner.address.clone().into(), + self.inner.position, + )?; + let expression = self.inner.lowerer.owned_aliased( + CallContext::BuiltinMethod, + &field.ty, + &field.ty, + field_address, + None, + )?; + self.inner.add_conjunct(expression) + } + + pub(in super::super::super) fn add_discriminant_predicate( + &mut self, + decl: &vir_mid::type_decl::Enum, + ) -> SpannedEncodingResult<()> { + let discriminant_field = decl.discriminant_field(); + let discriminant_address = self.inner.lowerer.encode_field_address( + self.inner.ty, + &discriminant_field, + self.inner.address.clone().into(), + self.inner.position, + )?; + let expression = self.inner.lowerer.owned_aliased( + CallContext::BuiltinMethod, + &decl.discriminant_type, + &decl.discriminant_type, + discriminant_address, + None, + )?; + self.inner.add_conjunct(expression) + } + + pub(in super::super::super) fn add_unique_ref_target_predicate( + &mut self, + target_type: &vir_mid::Type, + lifetime: &vir_mid::ty::LifetimeConst, + ) -> SpannedEncodingResult<()> { + let place = self + .inner + .lowerer + .encode_aliased_place_root(self.inner.position)?; + let root_address = self.inner.address.clone(); + self.inner.add_unique_ref_target_predicate( + target_type, + lifetime, + place, + root_address, + false, + ) + } + + pub(in super::super::super) fn add_frac_ref_target_predicate( + &mut self, + target_type: &vir_mid::Type, + lifetime: &vir_mid::ty::LifetimeConst, + ) -> SpannedEncodingResult<()> { + let place = self + .inner + .lowerer + .encode_aliased_place_root(self.inner.position)?; + let root_address = self.inner.address.clone(); + self.inner + .add_frac_ref_target_predicate(target_type, lifetime, place, root_address) + } + + // FIXME: Code duplication. + pub(in super::super::super) fn get_slice_len( + &self, + ) -> SpannedEncodingResult { + Ok(self.slice_len.as_ref().unwrap().clone()) + } + + // pub(in super::super::super) fn add_quantified_permission( + // &mut self, + // array_length_mid: &vir_mid::VariableDecl, + // element_type: &vir_mid::Type, + // ) -> SpannedEncodingResult<()> { + // use vir_low::macros::*; + // let size_type = self.inner.lowerer.size_type()?; + // let size_type_mid = self.inner.lowerer.size_type_mid()?; + // var_decls! { + // index: {size_type} + // }; + // let index_validity = self + // .inner + // .lowerer + // .encode_snapshot_valid_call_for_type(index.clone().into(), &size_type_mid)?; + // let index_int = self.inner.lowerer.obtain_constant_value( + // &size_type_mid, + // index.clone().into(), + // self.inner.position, + // )?; + // let array_length_int = self.inner.array_length_int(array_length_mid)?; + // let element_place = self.inner.lowerer.encode_index_place( + // self.inner.ty, + // self.inner.place.clone().into(), + // index.clone().into(), + // self.inner.position, + // )?; + // let element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( + // self.snapshot.clone().into(), + // index_int.clone(), + // self.inner.position, + // )?; + // let element_predicate_acc = self.inner.lowerer.owned_non_aliased_predicate( + // CallContext::BuiltinMethod, + // element_type, + // element_type, + // element_place, + // self.inner.root_address.clone().into(), + // element_snapshot, + // None, + // )?; + // let elements = vir_low::Expression::forall( + // vec![index], + // vec![vir_low::Trigger::new(vec![element_predicate_acc.clone()])], + // expr! { + // ([index_validity] && ([index_int] < [array_length_int])) ==> + // [element_predicate_acc] + // }, + // ); + // self.inner.add_conjunct(elements) + // } + + pub(in super::super::super) fn create_variant_predicate( + &mut self, + decl: &vir_mid::type_decl::Enum, + discriminant_value: vir_mid::DiscriminantValue, + variant: &vir_mid::type_decl::Struct, + variant_type: &vir_mid::Type, + ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { + use vir_low::macros::*; + let discriminant_call = { + let discriminant_field = decl.discriminant_field(); + let discriminant_address = self.inner.lowerer.encode_field_address( + self.inner.ty, + &discriminant_field, + self.inner.place.clone().into(), + self.inner.position, + )?; + let discriminant_snapshot = self.inner.lowerer.owned_aliased_snap( + CallContext::BuiltinMethod, + &decl.discriminant_type, + &decl.discriminant_type, + discriminant_address, + self.inner.position, + )?; + self.inner.lowerer.obtain_constant_value( + &decl.discriminant_type, + discriminant_snapshot, + self.inner.position, + )? + }; + let guard = expr! { + [ discriminant_call ] == [ discriminant_value.into() ] + }; + let variant_index = variant.name.clone().into(); + let variant_address = self.inner.lowerer.encode_enum_variant_address( + self.inner.ty, + &variant_index, + self.inner.place.clone().into(), + self.inner.position, + )?; + let predicate = self.inner.lowerer.owned_aliased( + CallContext::BuiltinMethod, + variant_type, + variant_type, + variant_address, + None, + )?; + Ok((guard, predicate)) + } + + pub(in super::super::super) fn add_variant_predicates( + &mut self, + variant_predicates: Vec<(vir_low::Expression, vir_low::Expression)>, + ) -> SpannedEncodingResult<()> { + self.inner + .add_conjunct(variant_predicates.into_iter().create_match()) + } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult> { + self.inner + .add_structural_invariant(decl, PredicateKind::Owned) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_range_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_range_use.rs new file mode 100644 index 00000000000..0c2790cc1ea --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_range_use.rs @@ -0,0 +1,116 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lowerer::{DomainsLowererInterface, Lowerer}, + pointers::PointersInterface, + predicates::{ + owned::builders::common::predicate_use::PredicateUseBuilder, PredicatesOwnedInterface, + }, + snapshots::{SnapshotValidityInterface, SnapshotValuesInterface}, + type_layouts::TypeLayoutsInterface, + }, +}; +use prusti_common::config; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, QuantifierHelpers}, + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +pub(in super::super::super::super::super) struct OwnedAliasedRangeUseBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + address: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + permission_amount: Option, + position: vir_low::Position, +} + +impl<'l, 'p, 'v, 'tcx, G> OwnedAliasedRangeUseBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + address: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult { + Ok(Self { + lowerer, + context, + ty, + generics, + address, + start_index, + end_index, + permission_amount, + position: Default::default(), + }) + } + + pub(in super::super::super::super::super) fn build( + self, + ) -> SpannedEncodingResult { + use vir_low::macros::*; + let size_type = self.lowerer.size_type_mid()?; + var_decls! { + index: Int + } + let vir_mid::Type::Pointer(ty) = self.ty else { + unreachable!() + }; + let initial_address = self + .lowerer + .pointer_address(self.ty, self.address, self.position)?; + let vir_mid::Type::Pointer(pointer_type) = self.ty else { + unreachable!() + }; + let size = self + .lowerer + .encode_type_size_expression2(&*pointer_type.target_type, &*pointer_type.target_type)?; + let element_address = self.lowerer.address_offset( + size, + initial_address, + index.clone().into(), + self.position, + )?; + let predicate = self.lowerer.owned_aliased( + self.context, + &ty.target_type, + self.generics, + element_address.clone(), + self.permission_amount, + )?; + let start_index = + self.lowerer + .obtain_constant_value(&size_type, self.start_index, self.position)?; + let end_index = + self.lowerer + .obtain_constant_value(&size_type, self.end_index, self.position)?; + let body = expr!( + (([start_index] <= index) && (index < [end_index])) ==> [predicate] + ); + let expression = vir_low::Expression::forall( + vec![index], + vec![vir_low::Trigger::new(vec![element_address])], + body, + ); + Ok(expression) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_use.rs new file mode 100644 index 00000000000..0e074839964 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_aliased/predicate_use.rs @@ -0,0 +1,76 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, + lowerer::Lowerer, + predicates::{ + owned::builders::common::predicate_use::PredicateUseBuilder, PredicatesOwnedInterface, + }, + }, +}; +use prusti_common::config; +use vir_crate::{ + common::expression::BinaryOperationHelpers, + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +pub(in super::super::super::super::super) struct OwnedAliasedUseBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + inner: PredicateUseBuilder<'l, 'p, 'v, 'tcx, G>, +} + +impl<'l, 'p, 'v, 'tcx, G> OwnedAliasedUseBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + address: vir_low::Expression, + ) -> SpannedEncodingResult { + let arguments = vec![address]; + let inner = PredicateUseBuilder::new( + lowerer, + "OwnedAliased", + context, + ty, + generics, + arguments, + Default::default(), + )?; + Ok(Self { inner }) + } + + pub(in super::super::super::super::super) fn build( + self, + ) -> SpannedEncodingResult { + Ok(self.inner.build()) + } + + pub(in super::super::super::super::super) fn add_lifetime_arguments( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_lifetime_arguments() + } + + pub(in super::super::super::super::super) fn add_const_arguments( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_const_arguments() + } + + pub(in super::super::super::super::super) fn set_maybe_permission_amount( + &mut self, + permission_amount: Option, + ) -> SpannedEncodingResult<()> { + self.inner.set_maybe_permission_amount(permission_amount) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/function_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/function_decl.rs new file mode 100644 index 00000000000..cb794d5e5f4 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/function_decl.rs @@ -0,0 +1,622 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + footprint::FootprintInterface, + lifetimes::LifetimesInterface, + lowerer::Lowerer, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{ + owned::builders::common::function_decl::FunctionDeclBuilder, OwnedNonAliasedUseBuilder, + PredicatesMemoryBlockInterface, PredicatesOwnedInterface, + }, + references::ReferencesInterface, + snapshots::{ + AssertionToSnapshotConstructor, IntoPureSnapshot, IntoSnapshotLowerer, PredicateKind, + SnapshotBytesInterface, SnapshotValidityInterface, SnapshotValuesInterface, + }, + type_layouts::TypeLayoutsInterface, + }, +}; + +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers}, + position::Positioned, + }, + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes, ty::Typed}, + }, +}; + +pub(in super::super::super) struct OwnedNonAliasedSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + inner: FunctionDeclBuilder<'l, 'p, 'v, 'tcx>, + place: vir_low::VariableDecl, + root_address: vir_low::VariableDecl, + slice_len: Option, +} + +impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + ) -> SpannedEncodingResult { + let slice_len = if ty.is_slice() { + Some(vir_mid::VariableDecl::new( + "slice_len", + lowerer.size_type_mid()?, + )) + } else { + None + }; + Ok(Self { + place: vir_low::VariableDecl::new("place", lowerer.place_type()?), + root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), + slice_len, + inner: FunctionDeclBuilder::new( + lowerer, + "snap_owned_non_aliased", + ty, + type_decl, + Default::default(), + )?, + }) + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + self.inner.build() + } + + pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.place.clone()); + self.inner.parameters.push(self.root_address.clone()); + self.inner.create_lifetime_parameters()?; + if let Some(slice_len_mid) = &self.slice_len { + let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; + self.inner.parameters.push(slice_len); + } + self.inner.create_const_parameters()?; + Ok(()) + } + + // FIXME: Code duplication. + pub(in super::super::super) fn get_slice_len( + &self, + ) -> SpannedEncodingResult { + Ok(self.slice_len.as_ref().unwrap().clone()) + } + + fn owned_predicate( + &mut self, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = OwnedNonAliasedUseBuilder::new( + self.inner.lowerer, + CallContext::BuiltinMethod, + ty, + generics, + place, + root_address, + )?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.set_maybe_permission_amount(Some(vir_low::Expression::wildcard_permission()))?; + builder.build() + } + + // FIXME: Code duplication with add_quantified_permission. + pub(in super::super::super) fn add_quantifiers( + &mut self, + array_length_mid: &vir_mid::VariableDecl, + element_type: &vir_mid::Type, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let size_type_mid = self.inner.lowerer.size_type_mid()?; + var_decls! { + index_int: Int + }; + let index = self.inner.lowerer.construct_constant_snapshot( + &size_type_mid, + index_int.clone().into(), + self.inner.position, + )?; + let index_validity = self + .inner + .lowerer + .encode_snapshot_valid_call_for_type(index.clone(), &size_type_mid)?; + let array_length_int = self.inner.array_length_int(array_length_mid)?; + let element_place = self.inner.lowerer.encode_index_place( + self.inner.ty, + self.place.clone().into(), + index, + self.inner.position, + )?; + let element_predicate_acc = { + self.owned_predicate( + element_type, + element_type, + element_place.clone(), + self.root_address.clone().into(), + )? + }; + let result = self.inner.result()?.into(); + let element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( + result, + index_int.clone().into(), + self.inner.position, + )?; + let element_snap_call = self.inner.lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + element_type, + element_type, + element_place, + self.root_address.clone().into(), + self.inner.position, + )?; + let elements = vir_low::Expression::forall( + vec![index_int.clone()], + vec![vir_low::Trigger::new(vec![element_predicate_acc])], + expr! { + ([index_validity] && (index_int < [array_length_int])) ==> + ([element_snapshot] == [element_snap_call]) + }, + ); + self.add_unfolding_postcondition(elements) + } + + pub(in super::super::super) fn add_unfolding_postcondition( + &mut self, + body: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + let unfolding = predicate.into_unfolding(body); + self.inner.add_postcondition(unfolding) + } + + pub(in super::super::super) fn add_validity_postcondition( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_validity_postcondition() + } + + pub(in super::super::super) fn add_snapshot_len_equal_to_postcondition( + &mut self, + array_length_mid: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult<()> { + self.inner + .add_snapshot_len_equal_to_postcondition(array_length_mid) + } + + pub(in super::super::super) fn add_owned_precondition(&mut self) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + self.inner.add_precondition(predicate) + } + + fn precondition_predicate(&mut self) -> SpannedEncodingResult { + self.owned_predicate( + self.inner.ty, + self.inner.type_decl, + self.place.clone().into(), + self.root_address.clone().into(), + ) + } + + fn compute_address(&self) -> SpannedEncodingResult { + use vir_low::macros::*; + let compute_address = ty!(Address); + let expression = expr! { + ComputeAddress::compute_address( + [self.place.clone().into()], + [self.root_address.clone().into()] + ) + }; + Ok(expression) + } + + fn size_of(&mut self) -> SpannedEncodingResult { + self.inner + .lowerer + .encode_type_size_expression2(self.inner.ty, self.inner.type_decl) + } + + fn add_bytes_snapshot_equality_with( + &mut self, + snap_ty: &vir_mid::Type, + snapshot: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let size_of = self.size_of()?; + let bytes = self + .inner + .lowerer + .encode_memory_block_bytes_expression(self.compute_address()?, size_of)?; + let to_bytes = ty! { Bytes }; + let expression = expr! { + [bytes] == (Snap::to_bytes([snapshot])) + }; + self.add_unfolding_postcondition(expression) + } + + pub(in super::super::super) fn add_bytes_snapshot_equality( + &mut self, + ) -> SpannedEncodingResult<()> { + let result = self.inner.result()?.into(); + self.add_bytes_snapshot_equality_with(self.inner.ty, result) + } + + pub(in super::super::super) fn add_bytes_address_snapshot_equality( + &mut self, + ) -> SpannedEncodingResult<()> { + let result = self.inner.result()?.into(); + let address_type = self.inner.lowerer.reference_address_type(self.inner.ty)?; + self.inner + .lowerer + .encode_snapshot_to_bytes_function(&address_type)?; + let target_address_snapshot = self.inner.lowerer.reference_address_snapshot( + self.inner.ty, + result, + self.inner.position, + )?; + self.add_bytes_snapshot_equality_with(&address_type, target_address_snapshot) + } + + // fn create_field_snap_call( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // let field_place = self.inner.lowerer.encode_field_place( + // self.inner.ty, + // field, + // self.place.clone().into(), + // self.inner.position, + // )?; + // self.inner.lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // &field.ty, + // &field.ty, + // field_place, + // self.root_address.clone().into(), + // self.inner.position, + // ) + // } + + // pub(in super::super::super) fn create_field_snapshot_equality( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let result = self.inner.result()?; + // let field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( + // self.inner.ty, + // field, + // result.into(), + // self.inner.position, + // )?; + // let snap_call = self.create_field_snap_call(&field)?; + // Ok(expr! { + // [field_snapshot] == [snap_call] + // }) + // } + + pub(in super::super::super) fn create_field_snapshot_equality( + &mut self, + field: &vir_mid::FieldDecl, + ) -> SpannedEncodingResult { + let owned_call = self.field_owned_snap()?; + self.inner.create_field_snapshot_equality(field, owned_call) + } + + fn field_owned_snap( + &mut self, + ) -> SpannedEncodingResult< + impl Fn( + &mut FunctionDeclBuilder, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + > { + let root_address: vir_low::Expression = self.root_address.clone().into(); + let root_address = std::rc::Rc::new(root_address); + Ok( + move |builder: &mut FunctionDeclBuilder, field: &vir_mid::FieldDecl, field_place| { + builder.lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + &field.ty, + &field.ty, + field_place, + (*root_address).clone(), + builder.position, + ) + }, + ) + } + + pub(in super::super::super) fn create_discriminant_snapshot_equality( + &mut self, + decl: &vir_mid::type_decl::Enum, + ) -> SpannedEncodingResult { + use vir_low::macros::*; + let result = self.inner.result()?; + let discriminant_snapshot = self.inner.lowerer.obtain_enum_discriminant( + result.into(), + self.inner.ty, + self.inner.position, + )?; + let discriminant_field = decl.discriminant_field(); + let discriminant_place = self.inner.lowerer.encode_field_place( + self.inner.ty, + &discriminant_field, + self.place.clone().into(), + self.inner.position, + )?; + let snap_call = self.inner.lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + &decl.discriminant_type, + &decl.discriminant_type, + discriminant_place, + self.root_address.clone().into(), + self.inner.position, + )?; + let snap_call_int = self.inner.lowerer.obtain_constant_value( + &decl.discriminant_type, + snap_call, + self.inner.position, + )?; + Ok(expr! { + [discriminant_snapshot] == [snap_call_int] + }) + } + + pub(in super::super::super) fn create_variant_snapshot_equality( + &mut self, + discriminant_value: vir_mid::DiscriminantValue, + variant: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { + use vir_low::macros::*; + let result = self.inner.result()?; + let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( + result.clone().into(), + self.inner.ty, + self.inner.position, + )?; + let guard = expr! { + [ discriminant_call ] == [ discriminant_value.into() ] + }; + let variant_index = variant.name.clone().into(); + let variant_place = self.inner.lowerer.encode_enum_variant_place( + self.inner.ty, + &variant_index, + self.place.clone().into(), + self.inner.position, + )?; + let variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( + self.inner.ty, + &variant_index, + result.into(), + self.inner.position, + )?; + let ty = self.inner.ty.clone(); + // let mut enum_ty = ty.unwrap_enum(); + // enum_ty.lifetimes = variant.lifetimes.clone(); + // enum_ty.variant = Some(variant_index); + // let variant_type = vir_mid::Type::Enum(enum_ty); + let variant_type = ty.variant(variant_index); + let snap_call = self.inner.lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + &variant_type, + // Enum variant and enum have the same set of lifetime parameters, + // so we use type_decl here. We cannot use `variant_type` because + // `ty` is normalized. + self.inner.type_decl, + variant_place, + self.root_address.clone().into(), + self.inner.position, + )?; + let equality = expr! { + [variant_snapshot] == [snap_call] + }; + Ok((guard, equality)) + } + + pub(in super::super::super) fn add_reference_snapshot_equalities( + &mut self, + decl: &vir_mid::type_decl::Reference, + lifetime: &vir_mid::ty::LifetimeConst, + ) -> SpannedEncodingResult<()> { + use vir_low::macros::*; + let result = self.inner.result()?; + let guard = self + .inner + .lowerer + .encode_lifetime_const_into_pure_is_alive_variable(lifetime)?; + let lifetime = lifetime.to_pure_snapshot(self.inner.lowerer)?; + let deref_place = self + .inner + .lowerer + .reference_deref_place(self.place.clone().into(), self.inner.position)?; + let current_snapshot = self.inner.lowerer.reference_target_current_snapshot( + self.inner.ty, + result.clone().into(), + self.inner.position, + )?; + let final_snapshot = self.inner.lowerer.reference_target_final_snapshot( + self.inner.ty, + result.clone().into(), + self.inner.position, + )?; + let address = self.inner.lowerer.reference_address( + self.inner.ty, + result.clone().into(), + self.inner.position, + )?; + let slice_len = self.inner.lowerer.reference_slice_len( + self.inner.ty, + result.into(), + self.inner.position, + )?; + let equalities = if decl.uniqueness.is_unique() { + let current_snap_call = self.inner.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + &decl.target_type, + &decl.target_type, + deref_place.clone(), + address.clone(), + lifetime.clone().into(), + slice_len.clone(), + false, + )?; + let final_snap_call = self.inner.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + &decl.target_type, + &decl.target_type, + deref_place, + address, + lifetime.into(), + slice_len, + true, + )?; + expr! { + ([current_snapshot] == [current_snap_call]) && + ([final_snapshot] == [final_snap_call]) + } + } else { + let snap_call = self.inner.lowerer.frac_ref_snap( + CallContext::BuiltinMethod, + &decl.target_type, + &decl.target_type, + deref_place.clone(), + address.clone(), + lifetime.clone().into(), + slice_len.clone(), + )?; + expr! { + [current_snapshot] == [snap_call] + } + }; + let expression = expr! { + guard ==> [equalities] + }; + self.add_unfolding_postcondition(expression) + } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult<()> { + let precondition_predicate = self.precondition_predicate()?; + let predicate_kind = PredicateKind::Owned; + let snap_call = self.field_owned_snap()?; + self.inner.add_structural_invariant( + decl, + Some(precondition_predicate), + predicate_kind, + &snap_call, + ) + } + + // // FIXME: Code duplication. + // pub(in super::super::super) fn add_structural_invariant( + // &mut self, + // decl: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult<()> { + // if let Some(invariant) = decl.structural_invariant.clone() { + // let mut regular_field_arguments = Vec::new(); + // for field in &decl.fields { + // let owned_call = self.field_owned_snap()?; + // let snap_call = self.inner.create_field_snap_call(field, owned_call)?; + // regular_field_arguments.push(snap_call); + // // regular_field_arguments.push(self.create_field_snap_call(field)?); + // } + // let result = self.inner.result()?; + // let deref_fields = self + // .inner + // .lowerer + // .structural_invariant_to_deref_fields(&invariant)?; + // let mut constructor_encoder = AssertionToSnapshotConstructor::for_function_body( + // PredicateKind::Owned, + // self.inner.ty, + // regular_field_arguments, + // decl.fields.clone(), + // deref_fields, + // self.inner.position, + // ); + // let invariant_expression = invariant.into_iter().conjoin(); + // let permission_expression = invariant_expression.convert_into_permission_expression(); + // let constructor = constructor_encoder + // .expression_to_snapshot_constructor(self.inner.lowerer, &permission_expression)?; + // self.add_unfolding_postcondition(vir_low::Expression::equals( + // result.into(), + // constructor, + // ))?; + // // let mut equalities = Vec::new(); + // // for assertion in invariant { + // // for (guard, place) in assertion.collect_guarded_owned_places() { + // // let parameter = self.inner.lowerer.compute_deref_parameter(&place)?; + // // let deref_result_snapshot = self.inner.lowerer.obtain_parameter_snapshot( + // // self.inner.ty, + // // ¶meter.name, + // // parameter.ty, + // // result.clone().into(), + // // self.inner.position, + // // )?; + // // let ty = place.get_type(); + // // let place_low = self.inner.lowerer.encode_expression_as_place(&place)?; + // // let root_address_low = { + // // // Code duplication with pointer_deref_into_address + // // let deref_place = place.get_last_dereferenced_pointer().unwrap(); + // // // TODO: replace self in deref_place with result. + // // let base_snapshot = deref_place.to_pure_snapshot(self.inner.lowerer)?; + // // let ty = deref_place.get_type(); + // // self.inner + // // .lowerer + // // .pointer_address(ty, base_snapshot, place.position())? + // // }; + // // let snap_call = self.inner.lowerer.owned_non_aliased_snap( + // // CallContext::BuiltinMethod, + // // ty, + // // ty, + // // place_low, + // // root_address_low, + // // self.inner.position, + // // )?; + // // equalities.push(expr! { + // // [deref_result_snapshot] == [snap_call] + // // }); + // // } + // // } + // // self.add_unfolding_postcondition(equalities.into_iter().conjoin())?; + // } + + // // // FIXME: Code duplication with encode_assign_method_rvalue + // // if let Some(invariant) = &decl.structural_invariant { + // // let mut assertion_encoder = + // // crate::encoder::middle::core_proof::builtin_methods::AssertionEncoder::new( + // // &decl, + // // Vec::new(), + // // &None, + // // ); + // // assertion_encoder.set_result_value(self.inner.result()?.clone()); + // // assertion_encoder.set_in_function(); + // // for assertion in invariant { + // // let low_assertion = assertion_encoder.expression_to_snapshot( + // // self.inner.lowerer, + // // assertion, + // // true, + // // )?; + // // self.add_unfolding_postcondition(low_assertion)?; + // // } + // // } + // Ok(()) + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/function_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/function_use.rs new file mode 100644 index 00000000000..eb6d0791a89 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/function_use.rs @@ -0,0 +1,74 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, lowerer::Lowerer, + predicates::owned::builders::common::function_use::FunctionCallBuilder, + }, +}; +use vir_crate::{ + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +pub(in super::super::super::super::super) struct OwnedNonAliasedSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + inner: FunctionCallBuilder<'l, 'p, 'v, 'tcx, G>, +} + +impl<'l, 'p, 'v, 'tcx, G> OwnedNonAliasedSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super::super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + place: vir_low::Expression, + root_address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let arguments = vec![place, root_address]; + let inner = FunctionCallBuilder::new( + lowerer, + "snap_owned_non_aliased", + context, + ty, + generics, + arguments, + position, + )?; + Ok(Self { inner }) + } + + pub(in super::super::super::super::super) fn build( + self, + ) -> SpannedEncodingResult { + self.inner.build() + } + + pub(in super::super::super::super::super) fn add_custom_argument( + &mut self, + argument: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + self.inner.arguments.push(argument); + Ok(()) + } + + pub(in super::super::super::super::super) fn add_lifetime_arguments( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_lifetime_arguments() + } + + pub(in super::super::super::super::super) fn add_const_arguments( + &mut self, + ) -> SpannedEncodingResult<()> { + self.inner.add_const_arguments() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/mod.rs index ef427252419..6bcb70532ad 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/mod.rs @@ -1,2 +1,4 @@ +pub(super) mod function_decl; +pub(super) mod function_use; pub(super) mod predicate_decl; pub(super) mod predicate_use; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_decl.rs index 768d0b4dd55..89b7a0a6cb4 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_decl.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_decl.rs @@ -1,4 +1,3 @@ -use super::predicate_use::OwnedNonAliasedUseBuilder; use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ @@ -10,26 +9,29 @@ use crate::encoder::{ owned::builders::{ common::predicate_decl::PredicateDeclBuilder, PredicateDeclBuilderMethods, }, - PredicatesMemoryBlockInterface, + PredicatesMemoryBlockInterface, PredicatesOwnedInterface, }, references::ReferencesInterface, snapshots::{ - IntoPureSnapshot, IntoSnapshot, SnapshotBytesInterface, SnapshotValidityInterface, + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, PredicateKind, + SelfFramingAssertionToSnapshot, SnapshotBytesInterface, SnapshotValidityInterface, SnapshotValuesInterface, }, type_layouts::TypeLayoutsInterface, }, }; +use prusti_common::config; use vir_crate::{ - common::expression::{GuardedExpressionIterator, QuantifierHelpers}, + common::{ + expression::{GuardedExpressionIterator, QuantifierHelpers}, + position::Positioned, + }, low::{self as vir_low}, - middle as vir_mid, + middle::{self as vir_mid}, }; pub(in super::super::super) struct OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { inner: PredicateDeclBuilder<'l, 'p, 'v, 'tcx>, - place: vir_low::VariableDecl, - root_address: vir_low::VariableDecl, snapshot: vir_low::VariableDecl, slice_len: Option, } @@ -56,18 +58,11 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { } else { None }; + let position = type_decl.position(); Ok(Self { - place: vir_low::VariableDecl::new("place", lowerer.place_type()?), - root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), snapshot: vir_low::VariableDecl::new("snapshot", ty.to_snapshot(lowerer)?), slice_len, - inner: PredicateDeclBuilder::new( - lowerer, - "OwnedNonAliased", - ty, - type_decl, - Default::default(), - )?, + inner: PredicateDeclBuilder::new(lowerer, "OwnedNonAliased", ty, type_decl, position)?, }) } @@ -76,9 +71,11 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { } pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { - self.inner.parameters.push(self.place.clone()); - self.inner.parameters.push(self.root_address.clone()); - self.inner.parameters.push(self.snapshot.clone()); + self.inner.parameters.push(self.inner.place.clone()); + self.inner.parameters.push(self.inner.root_address.clone()); + if config::use_snapshot_parameters_in_predicates() { + self.inner.parameters.push(self.snapshot.clone()); + } self.inner.create_lifetime_parameters()?; if let Some(slice_len_mid) = &self.slice_len { let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; @@ -97,8 +94,8 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { let compute_address = ty!(Address); let expression = expr! { ComputeAddress::compute_address( - [self.place.clone().into()], - [self.root_address.clone().into()] + [self.inner.place.clone().into()], + [self.inner.root_address.clone().into()] ) }; Ok(expression) @@ -184,7 +181,7 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { let field_place = self.inner.lowerer.encode_field_place( self.inner.ty, field, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; let field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( @@ -193,18 +190,15 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { self.snapshot.clone().into(), Default::default(), )?; - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + let expression = self.inner.lowerer.owned_non_aliased_predicate( CallContext::BuiltinMethod, &field.ty, &field.ty, field_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), field_snapshot, + None, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let expression = builder.build(); self.inner.add_conjunct(expression) } @@ -216,7 +210,7 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { let discriminant_place = self.inner.lowerer.encode_field_place( self.inner.ty, &discriminant_field, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( @@ -229,16 +223,15 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { discriminant_call, self.inner.position, )?; - let builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + let expression = self.inner.lowerer.owned_non_aliased_predicate( CallContext::BuiltinMethod, &decl.discriminant_type, &decl.discriminant_type, discriminant_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), discriminant_snapshot, + None, )?; - let expression = builder.build(); self.inner.add_conjunct(expression) } @@ -247,11 +240,15 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, ) -> SpannedEncodingResult<()> { + let place = self.inner.place.clone(); + let root_address = self.inner.root_address.clone(); self.inner.add_unique_ref_target_predicate( target_type, lifetime, - &self.place, - &self.snapshot, + place.into(), + root_address, + // &self.snapshot, + false, ) } @@ -260,10 +257,18 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, ) -> SpannedEncodingResult<()> { - self.inner - .add_frac_ref_target_predicate(target_type, lifetime, &self.place, &self.snapshot) + let place = self.inner.place.clone(); + let root_address = self.inner.root_address.clone(); + self.inner.add_frac_ref_target_predicate( + target_type, + lifetime, + place.into(), + root_address, + // &self.snapshot, + ) } + // FIXME: Code duplication. pub(in super::super::super) fn get_slice_len( &self, ) -> SpannedEncodingResult { @@ -301,7 +306,7 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { let array_length_int = self.inner.array_length_int(array_length_mid)?; let element_place = self.inner.lowerer.encode_index_place( self.inner.ty, - self.place.clone().into(), + self.inner.place.clone().into(), index.clone().into(), self.inner.position, )?; @@ -310,18 +315,15 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { index_int.clone(), self.inner.position, )?; - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + let element_predicate_acc = self.inner.lowerer.owned_non_aliased_predicate( CallContext::BuiltinMethod, element_type, element_type, element_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), element_snapshot, + None, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let element_predicate_acc = builder.build(); let elements = vir_low::Expression::forall( vec![index], vec![vir_low::Trigger::new(vec![element_predicate_acc.clone()])], @@ -335,16 +337,40 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { pub(in super::super::super) fn create_variant_predicate( &mut self, + decl: &vir_mid::type_decl::Enum, discriminant_value: vir_mid::DiscriminantValue, variant: &vir_mid::type_decl::Struct, variant_type: &vir_mid::Type, ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { use vir_low::macros::*; - let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( - self.snapshot.clone().into(), - self.inner.ty, - self.inner.position, - )?; + let discriminant_call = if config::use_snapshot_parameters_in_predicates() { + self.inner.lowerer.obtain_enum_discriminant( + self.snapshot.clone().into(), + self.inner.ty, + self.inner.position, + )? + } else { + let discriminant_field = decl.discriminant_field(); + let discriminant_place = self.inner.lowerer.encode_field_place( + self.inner.ty, + &discriminant_field, + self.inner.place.clone().into(), + self.inner.position, + )?; + let discriminant_snapshot = self.inner.lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + &decl.discriminant_type, + &decl.discriminant_type, + discriminant_place, + self.inner.root_address.clone().into(), + self.inner.position, + )?; + self.inner.lowerer.obtain_constant_value( + &decl.discriminant_type, + discriminant_snapshot, + self.inner.position, + )? + }; let guard = expr! { [ discriminant_call ] == [ discriminant_value.into() ] }; @@ -352,7 +378,7 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { let variant_place = self.inner.lowerer.encode_enum_variant_place( self.inner.ty, &variant_index, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; let variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( @@ -361,18 +387,15 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { self.snapshot.clone().into(), self.inner.position, )?; - let mut builder = OwnedNonAliasedUseBuilder::new( - self.inner.lowerer, + let predicate = self.inner.lowerer.owned_non_aliased_predicate( CallContext::BuiltinMethod, variant_type, variant_type, variant_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), variant_snapshot, + None, )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let predicate = builder.build(); Ok((guard, predicate)) } @@ -383,4 +406,313 @@ impl<'l, 'p, 'v, 'tcx> OwnedNonAliasedBuilder<'l, 'p, 'v, 'tcx> { self.inner .add_conjunct(variant_predicates.into_iter().create_match()) } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult> { + self.inner + .add_structural_invariant(decl, PredicateKind::Owned) + } + + // pub(in super::super::super) fn add_structural_invariant( + // &mut self, + // decl: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult<()> { + // if let Some(invariant) = &decl.structural_invariant { + // let mut encoder = SelfFramingAssertionToSnapshot::for_predicate_body( + // self.inner.place.clone(), + // self.inner.root_address.clone(), + // PredicateKind::Owned, + // ); + // // let mut encoder = PredicateAssertionEncoder { + // // place: &self.inner.place, + // // root_address: &self.inner.root_address, + // // snap_calls: Default::default(), + // // }; + // for assertion in invariant { + // let low_assertion = + // encoder.expression_to_snapshot(self.inner.lowerer, assertion, true)?; + // self.inner.add_conjunct(low_assertion)?; + // } + // } + // Ok(()) + // } } + +// // FIXME: Move this to its own module. +// FIXME: This should be replaced by prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/self_framing.rs +// struct PredicateAssertionEncoder<'a> { +// place: &'a vir_low::VariableDecl, +// root_address: &'a vir_low::VariableDecl, +// /// Mapping from place to snapshot. We use a vector because we need to know +// /// the insertion order. +// snap_calls: Vec<(vir_mid::Expression, vir_low::Expression)>, +// } + +// impl<'a> PredicateAssertionEncoder<'a> { +// // FIXME: Code duplication. +// fn pointer_deref_into_address<'p, 'v, 'tcx>( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// place: &vir_mid::Expression, +// ) -> SpannedEncodingResult { +// if let Some(deref_place) = place.get_last_dereferenced_pointer() { +// let base_snapshot = self.expression_to_snapshot(lowerer, deref_place, true)?; +// let ty = deref_place.get_type(); +// lowerer.pointer_address(ty, base_snapshot, place.position()) +// } else { +// unreachable!() +// } +// // match place { +// // vir_mid::Expression::Deref(deref) => { +// // let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, true)?; +// // let ty = deref.base.get_type(); +// // assert!(ty.is_pointer()); +// // lowerer.pointer_address(ty, base_snapshot, place.position()) +// // } +// // _ => unreachable!(), +// // } +// } +// } + +// impl<'a, 'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for PredicateAssertionEncoder<'a> { +// fn expression_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// expression: &vir_mid::Expression, +// expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// for (place, call) in &self.snap_calls { +// if place == expression { +// return Ok(call.clone()); +// } +// } +// self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) +// } + +// fn variable_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// variable: &vir_mid::VariableDecl, +// ) -> SpannedEncodingResult { +// assert!(variable.is_self_variable(), "{} must be self", variable); +// Ok(vir_low::VariableDecl { +// name: variable.name.clone(), +// ty: self.type_to_snapshot(lowerer, &variable.ty)?, +// }) +// } + +// fn labelled_old_to_snapshot( +// &mut self, +// _lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// _old: &vir_mid::LabelledOld, +// _expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// unreachable!("Old expression are not allowed in predicates"); +// } + +// fn func_app_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// app: &vir_mid::FuncApp, +// expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// todo!() +// } + +// fn binary_op_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// op: &vir_mid::BinaryOp, +// expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// // TODO: Create impl versions of each method so that I can override +// // without copying. +// let mut introduced_snap = false; +// if op.op_kind == vir_mid::BinaryOpKind::And { +// if let box vir_mid::Expression::AccPredicate(expression) = &op.left { +// if expression.predicate.is_owned_non_aliased() { +// introduced_snap = true; +// } +// } +// } +// let expression = self.binary_op_to_snapshot_impl(lowerer, op, expect_math_bool)?; +// if introduced_snap { +// // TODO: Use the snap calls from this vector instead of generating +// // on demand. This must always succeed because we require +// // expressions to be framed. +// self.snap_calls.pop(); +// } +// Ok(expression) +// } + +// fn acc_predicate_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// acc_predicate: &vir_mid::AccPredicate, +// expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// assert!(expect_math_bool); +// let expression = match &*acc_predicate.predicate { +// vir_mid::Predicate::OwnedNonAliased(predicate) => { +// let ty = predicate.place.get_type(); +// let place = lowerer.encode_expression_as_place(&predicate.place)?; +// let root_address = self.pointer_deref_into_address(lowerer, &predicate.place)?; +// let snapshot = true.into(); +// let acc = lowerer.owned_non_aliased_predicate( +// CallContext::Procedure, +// ty, +// ty, +// place.clone(), +// root_address.clone(), +// snapshot, +// None, +// )?; +// let snap_call = lowerer.owned_non_aliased_snap( +// CallContext::BuiltinMethod, +// ty, +// ty, +// place, +// root_address, +// predicate.place.position(), +// )?; +// self.snap_calls.push((predicate.place.clone(), snap_call)); +// acc +// } +// vir_mid::Predicate::MemoryBlockHeap(predicate) => { +// let place = lowerer.encode_expression_as_place(&predicate.address)?; +// let root_address = self.pointer_deref_into_address(lowerer, &predicate.address)?; +// use vir_low::macros::*; +// let compute_address = ty!(Address); +// let address = expr! { +// ComputeAddress::compute_address([place], [root_address]) +// }; +// let size = +// self.expression_to_snapshot(lowerer, &predicate.size, expect_math_bool)?; +// lowerer.encode_memory_block_stack_acc(address, size, acc_predicate.position)? +// } +// vir_mid::Predicate::MemoryBlockHeapDrop(predicate) => { +// let place = self.pointer_deref_into_address(lowerer, &predicate.address)?; +// let size = +// self.expression_to_snapshot(lowerer, &predicate.size, expect_math_bool)?; +// lowerer.encode_memory_block_heap_drop_acc(place, size, acc_predicate.position)? +// } +// _ => unimplemented!("{acc_predicate}"), +// }; +// Ok(expression) +// } + +// fn field_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// field: &vir_mid::Field, +// expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// match &*field.base { +// vir_mid::Expression::Local(local) => { +// assert!(local.variable.is_self_variable()); +// let field_place = lowerer.encode_field_place( +// &local.variable.ty, +// &field.field, +// self.inner.place.clone().into(), +// field.position, +// )?; +// lowerer.owned_non_aliased_snap( +// CallContext::BuiltinMethod, +// &field.field.ty, +// &field.field.ty, +// field_place, +// self.inner.root_address.clone().into(), +// local.position, +// ) +// } +// _ => { +// // FIXME: Code duplication because Rust does not have syntax for calling +// // overriden methods. +// let base_snapshot = +// self.expression_to_snapshot(lowerer, &field.base, expect_math_bool)?; +// let result = if field.field.is_discriminant() { +// let ty = field.base.get_type(); +// // FIXME: Create a method for obtainging the discriminant type. +// let type_decl = lowerer.encoder.get_type_decl_mid(ty)?; +// let enum_decl = type_decl.unwrap_enum(); +// let discriminant_call = +// lowerer.obtain_enum_discriminant(base_snapshot, ty, field.position)?; +// lowerer.construct_constant_snapshot( +// &enum_decl.discriminant_type, +// discriminant_call, +// field.position, +// )? +// } else { +// lowerer.obtain_struct_field_snapshot( +// field.base.get_type(), +// &field.field, +// base_snapshot, +// field.position, +// )? +// }; +// self.ensure_bool_expression(lowerer, field.get_type(), result, expect_math_bool) +// } +// } +// } + +// // FIXME: Code duplication. +// fn deref_to_snapshot( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// deref: &vir_mid::Deref, +// expect_math_bool: bool, +// ) -> SpannedEncodingResult { +// let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; +// let ty = deref.base.get_type(); +// let result = if ty.is_reference() { +// lowerer.reference_target_current_snapshot(ty, base_snapshot, Default::default())? +// } else { +// let aliased_root_place = lowerer.encode_aliased_place_root(deref.position)?; +// let root_address = lowerer.pointer_address(ty, base_snapshot, deref.position)?; +// lowerer.owned_non_aliased_snap( +// CallContext::BuiltinMethod, +// &deref.ty, +// &deref.ty, +// aliased_root_place, +// root_address, +// deref.position, +// )? +// // snap_owned_non_aliased$I32(aliased_place_root(), destructor$Snap$ptr$I32$$address(snap_owned_non_aliased$ptr$I32(field_place$$struct$m_T5$$$f$p2(place), root_address))) + +// // FIXME: This should be unreachable. Most likely, in predicates we should use snap +// // functions. +// // let heap = vir_low::VariableDecl::new("predicate_heap$", lowerer.heap_type()?); +// // lowerer.pointer_target_snapshot_in_heap( +// // deref.base.get_type(), +// // heap, +// // base_snapshot, +// // deref.position, +// // )? +// }; +// self.ensure_bool_expression(lowerer, deref.get_type(), result, expect_math_bool) +// } + +// fn owned_non_aliased_snap( +// &mut self, +// lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// ty: &vir_mid::Type, +// pointer_snapshot: &vir_mid::Expression, +// ) -> SpannedEncodingResult { +// unimplemented!() +// } + +// fn call_context(&self) -> CallContext { +// CallContext::BuiltinMethod +// } + +// // fn unfolding_to_snapshot( +// // &mut self, +// // lowerer: &mut Lowerer<'p, 'v, 'tcx>, +// // unfolding: &vir_mid::Unfolding, +// // expect_math_bool: bool, +// // ) -> SpannedEncodingResult { +// // todo!() +// // } +// } diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_use.rs index f40e59eee36..789c24a0dac 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_use.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_use.rs @@ -1,11 +1,16 @@ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ - builtin_methods::CallContext, lowerer::Lowerer, - predicates::owned::builders::common::predicate_use::PredicateUseBuilder, + builtin_methods::CallContext, + lowerer::Lowerer, + predicates::{ + owned::builders::common::predicate_use::PredicateUseBuilder, PredicatesOwnedInterface, + }, }, }; +use prusti_common::config; use vir_crate::{ + common::expression::BinaryOperationHelpers, low::{self as vir_low}, middle::{ self as vir_mid, @@ -18,6 +23,7 @@ where G: WithLifetimes + WithConstArguments, { inner: PredicateUseBuilder<'l, 'p, 'v, 'tcx, G>, + snapshot: Option, } impl<'l, 'p, 'v, 'tcx, G> OwnedNonAliasedUseBuilder<'l, 'p, 'v, 'tcx, G> @@ -31,22 +37,55 @@ where generics: &'l G, place: vir_low::Expression, root_address: vir_low::Expression, - snapshot: vir_low::Expression, ) -> SpannedEncodingResult { + let arguments = vec![place, root_address]; let inner = PredicateUseBuilder::new( lowerer, "OwnedNonAliased", context, ty, generics, - vec![place, root_address, snapshot], + arguments, Default::default(), )?; - Ok(Self { inner }) + Ok(Self { + inner, + snapshot: None, + }) + } + + pub(in super::super::super::super::super) fn build( + mut self, + ) -> SpannedEncodingResult { + let expression = if let Some(snapshot) = self.snapshot.take() { + let snap_call = self.inner.lowerer.owned_non_aliased_snap( + self.inner.context, + self.inner.ty, + self.inner.generics, + self.inner.arguments[0].clone(), + self.inner.arguments[1].clone(), + self.inner.position, + )?; + vir_low::Expression::and( + self.inner.build(), + vir_low::Expression::equals(snapshot, snap_call), + ) + } else { + self.inner.build() + }; + Ok(expression) } - pub(in super::super::super::super::super) fn build(self) -> vir_low::Expression { - self.inner.build() + pub(in super::super::super::super::super) fn add_snapshot_argument( + &mut self, + snapshot: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + if config::use_snapshot_parameters_in_predicates() { + self.inner.arguments.push(snapshot); + } else { + self.snapshot = Some(snapshot); + } + Ok(()) } pub(in super::super::super::super::super) fn add_lifetime_arguments( diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/function_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/function_decl.rs new file mode 100644 index 00000000000..191e0194734 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/function_decl.rs @@ -0,0 +1,496 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lifetimes::LifetimesInterface, + lowerer::Lowerer, + places::PlacesInterface, + predicates::{ + owned::builders::common::function_decl::FunctionDeclBuilder, PredicatesOwnedInterface, + }, + snapshots::{IntoPureSnapshot, PredicateKind}, + type_layouts::TypeLayoutsInterface, + }, +}; + +use vir_crate::{ + common::identifier::WithIdentifier, + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +use super::predicate_use::UniqueRefUseBuilder; + +pub(in super::super::super) struct UniqueRefSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + inner: FunctionDeclBuilder<'l, 'p, 'v, 'tcx>, + // place: vir_low::VariableDecl, + root_address: vir_low::VariableDecl, + reference_lifetime: vir_low::VariableDecl, + slice_len: Option, + is_final: bool, +} + +impl<'l, 'p, 'v, 'tcx> UniqueRefSnapFunctionBuilder<'l, 'p, 'v, 'tcx> { + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + ty: &'l vir_mid::Type, + type_decl: &'l vir_mid::TypeDecl, + is_final: bool, + ) -> SpannedEncodingResult { + let slice_len = if ty.is_slice() { + Some(vir_mid::VariableDecl::new( + "slice_len", + lowerer.size_type_mid()?, + )) + } else { + None + }; + let function_name = if is_final { + "snap_final_unique_ref" + } else { + "snap_current_unique_ref" + }; + // let place = vir_low::VariableDecl::new("place", lowerer.place_type()?); + Ok(Self { + root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), + reference_lifetime: vir_low::VariableDecl::new( + "reference_lifetime", + lowerer.lifetime_type()?, + ), + slice_len, + inner: FunctionDeclBuilder::new( + lowerer, + function_name, + ty, + type_decl, + Default::default(), + // place, + )?, + is_final, + }) + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + // let snap_final_name = format!("snap_final_unique_ref${}", self.inner.ty.get_identifier()); + // let snap_current = self.inner.build()?; + // let mut snap_final = snap_current.clone(); + // snap_final.name = snap_final_name; + // Ok((snap_current, snap_final)) + self.inner.build() + } + + pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { + self.inner.parameters.push(self.inner.place.clone()); + self.inner.parameters.push(self.root_address.clone()); + self.inner.parameters.push(self.reference_lifetime.clone()); + self.inner.create_lifetime_parameters()?; + if let Some(slice_len_mid) = &self.slice_len { + let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; + self.inner.parameters.push(slice_len); + } + self.inner.create_const_parameters()?; + Ok(()) + } + + // // FIXME: Code duplication. + // pub(in super::super::super) fn get_slice_len( + // &self, + // ) -> SpannedEncodingResult { + // Ok(self.slice_len.as_ref().unwrap().clone()) + // } + + fn unique_ref_predicate( + &mut self, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + reference_lifetime: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let slice_len = if let Some(slice_len_mid) = &self.slice_len { + let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; + Some(slice_len.into()) + } else { + None + }; + // let final_snapshot = self.inner.lowerer.unique_ref_snap( + // CallContext::BuiltinMethod, + // ty, + // generics, + // place.clone(), + // root_address.clone(), + // reference_lifetime.clone(), + // slice_len.clone(), + // true, + // )?; + let mut builder = UniqueRefUseBuilder::new( + self.inner.lowerer, + CallContext::BuiltinMethod, + ty, + generics, + place, + root_address, + // final_snapshot, + reference_lifetime, + slice_len, + )?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + // // FIXME: Code duplication with add_quantified_permission. + // pub(in super::super::super) fn add_quantifiers( + // &mut self, + // array_length_mid: &vir_mid::VariableDecl, + // element_type: &vir_mid::Type, + // ) -> SpannedEncodingResult<()> { + // use vir_low::macros::*; + // let size_type_mid = self.inner.lowerer.size_type_mid()?; + // var_decls! { + // index_int: Int + // }; + // let index = self.inner.lowerer.construct_constant_snapshot( + // &size_type_mid, + // index_int.clone().into(), + // self.inner.position, + // )?; + // let index_validity = self + // .inner + // .lowerer + // .encode_snapshot_valid_call_for_type(index.clone(), &size_type_mid)?; + // let array_length_int = self.inner.array_length_int(array_length_mid)?; + // let element_place = self.inner.lowerer.encode_index_place( + // self.inner.ty, + // self.place.clone().into(), + // index, + // self.inner.position, + // )?; + // let element_predicate_acc = { + // self.owned_predicate( + // element_type, + // element_type, + // element_place.clone(), + // self.root_address.clone().into(), + // )? + // }; + // let result = self.inner.result()?.into(); + // let element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( + // result, + // index_int.clone().into(), + // self.inner.position, + // )?; + // let element_snap_call = self.inner.lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // element_type, + // element_type, + // element_place, + // self.root_address.clone().into(), + // )?; + // let elements = vir_low::Expression::forall( + // vec![index_int.clone()], + // vec![vir_low::Trigger::new(vec![element_predicate_acc])], + // expr! { + // ([index_validity] && (index_int < [array_length_int])) ==> + // ([element_snapshot] == [element_snap_call]) + // }, + // ); + // self.add_unfolding_postcondition(elements) + // } + + pub(in super::super::super) fn add_unfolding_postcondition( + &mut self, + body: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + let unfolding = predicate.into_unfolding(body); + self.inner.add_postcondition(unfolding) + } + + pub(in super::super::super) fn add_postcondition( + &mut self, + expression: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + self.inner.add_postcondition(expression) + } + + // pub(in super::super::super) fn add_validity_postcondition( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // self.inner.add_validity_postcondition() + // } + + // pub(in super::super::super) fn add_snapshot_len_equal_to_postcondition( + // &mut self, + // array_length_mid: &vir_mid::VariableDecl, + // ) -> SpannedEncodingResult<()> { + // self.inner + // .add_snapshot_len_equal_to_postcondition(array_length_mid) + // } + + pub(in super::super::super) fn add_unique_ref_precondition( + &mut self, + ) -> SpannedEncodingResult<()> { + let predicate = self.precondition_predicate()?; + self.inner.add_precondition(predicate) + } + + fn precondition_predicate(&mut self) -> SpannedEncodingResult { + assert!(!self.is_final); + self.unique_ref_predicate( + self.inner.ty, + self.inner.type_decl, + self.inner.place.clone().into(), + self.root_address.clone().into(), + self.reference_lifetime.clone().into(), + ) + } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult<()> { + let precondition_predicate = if self.is_final { + None + } else { + Some(self.precondition_predicate()?) + }; + let predicate_kind = PredicateKind::UniqueRef { + lifetime: self.reference_lifetime.clone().into(), + is_final: self.is_final, + }; + let snap_call = self.field_unique_ref_snap()?; + self.inner.add_structural_invariant( + decl, + precondition_predicate, + predicate_kind, + &snap_call, + ) + } + + // fn compute_address(&self) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let compute_address = ty!(Address); + // let expression = expr! { + // ComputeAddress::compute_address( + // [self.place.clone().into()], + // [self.root_address.clone().into()] + // ) + // }; + // Ok(expression) + // } + + // fn size_of(&mut self) -> SpannedEncodingResult { + // self.inner + // .lowerer + // .encode_type_size_expression2(self.inner.ty, self.inner.type_decl) + // } + + // fn add_bytes_snapshot_equality_with( + // &mut self, + // snap_ty: &vir_mid::Type, + // snapshot: vir_low::Expression, + // ) -> SpannedEncodingResult<()> { + // use vir_low::macros::*; + // let size_of = self.size_of()?; + // let bytes = self + // .inner + // .lowerer + // .encode_memory_block_bytes_expression(self.compute_address()?, size_of)?; + // let to_bytes = ty! { Bytes }; + // let expression = expr! { + // [bytes] == (Snap::to_bytes([snapshot])) + // }; + // self.add_unfolding_postcondition(expression) + // } + + // pub(in super::super::super) fn add_bytes_snapshot_equality( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // let result = self.inner.result()?.into(); + // self.add_bytes_snapshot_equality_with(self.inner.ty, result) + // } + + // pub(in super::super::super) fn add_bytes_address_snapshot_equality( + // &mut self, + // ) -> SpannedEncodingResult<()> { + // let result = self.inner.result()?.into(); + // let address_type = self.inner.lowerer.reference_address_type(self.inner.ty)?; + // self.inner + // .lowerer + // .encode_snapshot_to_bytes_function(&address_type)?; + // let target_address_snapshot = self.inner.lowerer.reference_address_snapshot( + // self.inner.ty, + // result, + // self.inner.position, + // )?; + // self.add_bytes_snapshot_equality_with(&address_type, target_address_snapshot) + // } + + // pub(in super::super::super) fn create_field_snapshot_equality( + // &mut self, + // field: &vir_mid::FieldDecl, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let result = self.inner.result()?; + // let field_place = self.inner.lowerer.encode_field_place( + // self.inner.ty, + // field, + // self.place.clone().into(), + // self.inner.position, + // )?; + // let field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( + // self.inner.ty, + // field, + // result.into(), + // self.inner.position, + // )?; + // let snap_call = self.inner.lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // &field.ty, + // &field.ty, + // field_place, + // self.root_address.clone().into(), + // )?; + // Ok(expr! { + // [field_snapshot] == [snap_call] + // }) + // } + + // pub(in super::super::super) fn create_discriminant_snapshot_equality( + // &mut self, + // decl: &vir_mid::type_decl::Enum, + // ) -> SpannedEncodingResult { + // use vir_low::macros::*; + // let result = self.inner.result()?; + // let discriminant_snapshot = self.inner.lowerer.obtain_enum_discriminant( + // result.into(), + // self.inner.ty, + // self.inner.position, + // )?; + // let discriminant_field = decl.discriminant_field(); + // let discriminant_place = self.inner.lowerer.encode_field_place( + // self.inner.ty, + // &discriminant_field, + // self.place.clone().into(), + // self.inner.position, + // )?; + // let snap_call = self.inner.lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // &decl.discriminant_type, + // &decl.discriminant_type, + // discriminant_place, + // self.root_address.clone().into(), + // )?; + // let snap_call_int = self.inner.lowerer.obtain_constant_value( + // &decl.discriminant_type, + // snap_call, + // self.inner.position, + // )?; + // Ok(expr! { + // [discriminant_snapshot] == [snap_call_int] + // }) + // } + + // pub(in super::super::super) fn create_variant_snapshot_equality( + // &mut self, + // discriminant_value: vir_mid::DiscriminantValue, + // variant: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { + // use vir_low::macros::*; + // let result = self.inner.result()?; + // let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( + // result.clone().into(), + // self.inner.ty, + // self.inner.position, + // )?; + // let guard = expr! { + // [ discriminant_call ] == [ discriminant_value.into() ] + // }; + // let variant_index = variant.name.clone().into(); + // let variant_place = self.inner.lowerer.encode_enum_variant_place( + // self.inner.ty, + // &variant_index, + // self.place.clone().into(), + // self.inner.position, + // )?; + // let variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( + // self.inner.ty, + // &variant_index, + // result.into(), + // self.inner.position, + // )?; + // let variant_type = self.inner.ty.clone().variant(variant_index); + // let snap_call = self.inner.lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // &variant_type, + // &variant_type, + // variant_place, + // self.root_address.clone().into(), + // )?; + // let equality = expr! { + // [variant_snapshot] == [snap_call] + // }; + // Ok((guard, equality)) + // } + + // FIXME: Code duplication. + fn slice_len(&mut self) -> SpannedEncodingResult> { + self.slice_len + .as_ref() + .map(|slice_len_mid| slice_len_mid.to_pure_snapshot(self.inner.lowerer)) + .transpose() + } + + // FIXME: Code duplication. + fn slice_len_expression(&mut self) -> SpannedEncodingResult> { + Ok(self.slice_len()?.map(|slice_len| slice_len.into())) + } + + pub(in super::super::super) fn create_field_snapshot_equality( + &mut self, + field: &vir_mid::FieldDecl, + ) -> SpannedEncodingResult { + let unique_ref_call = self.field_unique_ref_snap()?; + self.inner + .create_field_snapshot_equality(field, unique_ref_call) + } + + fn field_unique_ref_snap( + &mut self, + ) -> SpannedEncodingResult< + impl Fn( + &mut FunctionDeclBuilder, + &vir_mid::FieldDecl, + vir_low::Expression, + ) -> SpannedEncodingResult, + > { + let target_slice_len = self.slice_len_expression()?; + let root_address: vir_low::Expression = self.root_address.clone().into(); + let root_address = std::rc::Rc::new(root_address); + let lifetime: vir_low::Expression = self.reference_lifetime.clone().into(); + let lifetime = std::rc::Rc::new(lifetime); + let is_final = self.is_final; + Ok( + move |builder: &mut FunctionDeclBuilder, field: &vir_mid::FieldDecl, field_place| { + builder.lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + &field.ty, + &field.ty, + field_place, + (*root_address).clone(), + (*lifetime).clone(), + target_slice_len.clone(), + is_final, + ) + }, + ) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/function_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/function_use.rs new file mode 100644 index 00000000000..7278aed0380 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/function_use.rs @@ -0,0 +1,70 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, lowerer::Lowerer, + predicates::owned::builders::common::function_use::FunctionCallBuilder, + }, +}; +use vir_crate::{ + low::{self as vir_low}, + middle::{ + self as vir_mid, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + }, +}; + +pub(in super::super::super) struct UniqueRefSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + inner: FunctionCallBuilder<'l, 'p, 'v, 'tcx, G>, +} + +impl<'l, 'p, 'v, 'tcx, G> UniqueRefSnapCallBuilder<'l, 'p, 'v, 'tcx, G> +where + G: WithLifetimes + WithConstArguments, +{ + pub(in super::super::super) fn new( + lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + context: CallContext, + ty: &'l vir_mid::Type, + generics: &'l G, + place: vir_low::Expression, + root_address: vir_low::Expression, + reference_lifetime: vir_low::Expression, + target_slice_len: Option, + is_final: bool, + ) -> SpannedEncodingResult { + let mut arguments = vec![place, root_address, reference_lifetime]; + if let Some(len) = target_slice_len { + arguments.push(len); + } + let name = if is_final { + "snap_final_unique_ref" + } else { + "snap_current_unique_ref" + }; + let inner = FunctionCallBuilder::new( + lowerer, + name, + context, + ty, + generics, + arguments, + Default::default(), + )?; + Ok(Self { inner }) + } + + pub(in super::super::super) fn build(self) -> SpannedEncodingResult { + self.inner.build() + } + + pub(in super::super::super) fn add_lifetime_arguments(&mut self) -> SpannedEncodingResult<()> { + self.inner.add_lifetime_arguments() + } + + pub(in super::super::super) fn add_const_arguments(&mut self) -> SpannedEncodingResult<()> { + self.inner.add_const_arguments() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/mod.rs index ef427252419..6bcb70532ad 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/mod.rs @@ -1,2 +1,4 @@ +pub(super) mod function_decl; +pub(super) mod function_use; pub(super) mod predicate_decl; pub(super) mod predicate_use; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_decl.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_decl.rs index 280520840fa..ae09604f441 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_decl.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_decl.rs @@ -1,4 +1,3 @@ -use super::predicate_use::UniqueRefUseBuilder; use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ @@ -7,15 +6,20 @@ use crate::encoder::{ lifetimes::LifetimesInterface, lowerer::Lowerer, places::PlacesInterface, - predicates::owned::builders::{ - common::predicate_decl::PredicateDeclBuilder, PredicateDeclBuilderMethods, + predicates::{ + owned::builders::{ + common::predicate_decl::PredicateDeclBuilder, PredicateDeclBuilderMethods, + }, + PredicatesOwnedInterface, }, snapshots::{ - IntoPureSnapshot, IntoSnapshot, SnapshotValidityInterface, SnapshotValuesInterface, + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, PredicateKind, + SelfFramingAssertionToSnapshot, SnapshotValidityInterface, SnapshotValuesInterface, }, type_layouts::TypeLayoutsInterface, }, }; +use prusti_common::config; use vir_crate::{ common::expression::{GuardedExpressionIterator, QuantifierHelpers}, low::{self as vir_low}, @@ -24,10 +28,8 @@ use vir_crate::{ pub(in super::super::super) struct UniqueRefBuilder<'l, 'p, 'v, 'tcx> { inner: PredicateDeclBuilder<'l, 'p, 'v, 'tcx>, - place: vir_low::VariableDecl, - root_address: vir_low::VariableDecl, - current_snapshot: vir_low::VariableDecl, - final_snapshot: vir_low::VariableDecl, + // current_snapshot: vir_low::VariableDecl, + // final_snapshot: vir_low::VariableDecl, reference_lifetime: vir_low::VariableDecl, slice_len: Option, } @@ -55,13 +57,11 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { None }; Ok(Self { - place: vir_low::VariableDecl::new("place", lowerer.place_type()?), - root_address: vir_low::VariableDecl::new("root_address", lowerer.address_type()?), - current_snapshot: vir_low::VariableDecl::new( - "current_snapshot", - ty.to_snapshot(lowerer)?, - ), - final_snapshot: vir_low::VariableDecl::new("final_snapshot", ty.to_snapshot(lowerer)?), + // current_snapshot: vir_low::VariableDecl::new( + // "current_snapshot", + // ty.to_snapshot(lowerer)?, + // ), + // final_snapshot: vir_low::VariableDecl::new("final_snapshot", ty.to_snapshot(lowerer)?), reference_lifetime: vir_low::VariableDecl::new( "reference_lifetime", lowerer.lifetime_type()?, @@ -82,11 +82,13 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { } pub(in super::super::super) fn create_parameters(&mut self) -> SpannedEncodingResult<()> { - self.inner.parameters.push(self.place.clone()); - self.inner.parameters.push(self.root_address.clone()); - self.inner.parameters.push(self.current_snapshot.clone()); - self.inner.parameters.push(self.final_snapshot.clone()); + self.inner.parameters.push(self.inner.place.clone()); + self.inner.parameters.push(self.inner.root_address.clone()); self.inner.parameters.push(self.reference_lifetime.clone()); + // if config::use_snapshot_parameters_in_predicates() { + // self.inner.parameters.push(self.current_snapshot.clone()); + // self.inner.parameters.push(self.final_snapshot.clone()); + // } self.inner.create_lifetime_parameters()?; if let Some(slice_len_mid) = &self.slice_len { let slice_len = slice_len_mid.to_pure_snapshot(self.inner.lowerer)?; @@ -96,9 +98,9 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { Ok(()) } - pub(in super::super::super) fn add_validity(&mut self) -> SpannedEncodingResult<()> { - self.inner.add_validity(&self.current_snapshot) - } + // pub(in super::super::super) fn add_validity(&mut self) -> SpannedEncodingResult<()> { + // self.inner.add_validity(&self.current_snapshot) + // } pub(in super::super::super) fn add_field_predicate( &mut self, @@ -107,35 +109,34 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { let field_place = self.inner.lowerer.encode_field_place( self.inner.ty, field, - self.place.clone().into(), - self.inner.position, - )?; - let current_field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( - self.inner.ty, - field, - self.current_snapshot.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; - let final_field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( - self.inner.ty, - field, - self.final_snapshot.clone().into(), - self.inner.position, - )?; - let mut builder = UniqueRefUseBuilder::new( - self.inner.lowerer, + // let current_field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( + // self.inner.ty, + // field, + // self.current_snapshot.clone().into(), + // self.inner.position, + // )?; + // let final_field_snapshot = self.inner.lowerer.obtain_struct_field_snapshot( + // self.inner.ty, + // field, + // self.final_snapshot.clone().into(), + // self.inner.position, + // )?; + let current_field_snapshot = true.into(); + let final_field_snapshot = true.into(); + let expression = self.inner.lowerer.unique_ref_predicate( CallContext::BuiltinMethod, &field.ty, &field.ty, field_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_field_snapshot, final_field_snapshot, self.reference_lifetime.clone().into(), + None, // FIXME: This should be a proper value )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let expression = builder.build(); self.inner.add_conjunct(expression) } @@ -147,54 +148,85 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { let discriminant_place = self.inner.lowerer.encode_field_place( self.inner.ty, &discriminant_field, - self.place.clone().into(), + self.inner.place.clone().into(), self.inner.position, )?; - let current_discriminant_call = self.inner.lowerer.obtain_enum_discriminant( - self.current_snapshot.clone().into(), - self.inner.ty, - self.inner.position, - )?; - let current_discriminant_snapshot = self.inner.lowerer.construct_constant_snapshot( - &decl.discriminant_type, - current_discriminant_call, - self.inner.position, - )?; - let final_discriminant_call = self.inner.lowerer.obtain_enum_discriminant( - self.final_snapshot.clone().into(), - self.inner.ty, - self.inner.position, - )?; - let final_discriminant_snapshot = self.inner.lowerer.construct_constant_snapshot( - &decl.discriminant_type, - final_discriminant_call, - self.inner.position, - )?; - let builder = UniqueRefUseBuilder::new( - self.inner.lowerer, + // let current_discriminant_call = self.inner.lowerer.obtain_enum_discriminant( + // self.current_snapshot.clone().into(), + // self.inner.ty, + // self.inner.position, + // )?; + // let current_discriminant_snapshot = self.inner.lowerer.construct_constant_snapshot( + // &decl.discriminant_type, + // current_discriminant_call, + // self.inner.position, + // )?; + // let final_discriminant_call = self.inner.lowerer.obtain_enum_discriminant( + // self.final_snapshot.clone().into(), + // self.inner.ty, + // self.inner.position, + // )?; + // let final_discriminant_snapshot = self.inner.lowerer.construct_constant_snapshot( + // &decl.discriminant_type, + // final_discriminant_call, + // self.inner.position, + // )?; + // let builder = UniqueRefUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // &decl.discriminant_type, + // &decl.discriminant_type, + // discriminant_place, + // self.inner.root_address.clone().into(), + // current_discriminant_snapshot, + // final_discriminant_snapshot, + // self.reference_lifetime.clone().into(), + // )?; + // let expression = builder.build(); + let current_discriminant_snapshot = true.into(); + let final_discriminant_snapshot = true.into(); + let expression = self.inner.lowerer.unique_ref_predicate( CallContext::BuiltinMethod, &decl.discriminant_type, &decl.discriminant_type, discriminant_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_discriminant_snapshot, final_discriminant_snapshot, self.reference_lifetime.clone().into(), + None, // FIXME: This should be a proper value )?; - let expression = builder.build(); self.inner.add_conjunct(expression) } + pub(in super::super::super) fn add_unique_ref_pointer_predicate( + &mut self, + lifetime: &vir_mid::ty::LifetimeConst, + ) -> SpannedEncodingResult { + let place = self.inner.place.clone(); + let root_address = self.inner.root_address.clone(); + self.inner.add_unique_ref_pointer_predicate( + lifetime, + place, + root_address, + // &self.current_snapshot, + ) + } + pub(in super::super::super) fn add_unique_ref_target_predicate( &mut self, target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, ) -> SpannedEncodingResult<()> { + let place = self.inner.place.clone(); + let root_address = self.inner.root_address.clone(); self.inner.add_unique_ref_target_predicate( target_type, lifetime, - &self.place, - &self.current_snapshot, + place.into(), + root_address, + // &self.current_snapshot, + true, ) } @@ -203,11 +235,14 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { target_type: &vir_mid::Type, lifetime: &vir_mid::ty::LifetimeConst, ) -> SpannedEncodingResult<()> { + let place = self.inner.place.clone(); + let root_address = self.inner.root_address.clone(); self.inner.add_frac_ref_target_predicate( target_type, lifetime, - &self.place, - &self.current_snapshot, + place.into(), + root_address, + // &self.current_snapshot, ) } @@ -221,10 +256,11 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { &mut self, array_length_mid: &vir_mid::VariableDecl, ) -> SpannedEncodingResult<()> { - self.inner - .add_snapshot_len_equal_to(&self.current_snapshot, array_length_mid)?; - self.inner - .add_snapshot_len_equal_to(&self.final_snapshot, array_length_mid)?; + unimplemented!(); + // self.inner + // .add_snapshot_len_equal_to(&self.current_snapshot, array_length_mid)?; + // self.inner + // .add_snapshot_len_equal_to(&self.final_snapshot, array_length_mid)?; Ok(()) } @@ -251,34 +287,47 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { let array_length_int = self.inner.array_length_int(array_length_mid)?; let element_place = self.inner.lowerer.encode_index_place( self.inner.ty, - self.place.clone().into(), + self.inner.place.clone().into(), index.clone().into(), self.inner.position, )?; - let current_element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( - self.current_snapshot.clone().into(), - index_int.clone(), - self.inner.position, - )?; - let final_element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( - self.final_snapshot.clone().into(), - index_int.clone(), - self.inner.position, - )?; - let mut builder = UniqueRefUseBuilder::new( - self.inner.lowerer, + // let current_element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( + // self.current_snapshot.clone().into(), + // index_int.clone(), + // self.inner.position, + // )?; + // let final_element_snapshot = self.inner.lowerer.obtain_array_element_snapshot( + // self.final_snapshot.clone().into(), + // index_int.clone(), + // self.inner.position, + // )?; + let current_element_snapshot = true.into(); // FIXME + let final_element_snapshot = true.into(); // FIXME + // let mut builder = UniqueRefUseBuilder::new( + // self.inner.lowerer, + // CallContext::BuiltinMethod, + // element_type, + // element_type, + // element_place, + // self.inner.root_address.clone().into(), + // current_element_snapshot, + // final_element_snapshot, + // self.reference_lifetime.clone().into(), + // )?; + // builder.add_lifetime_arguments()?; + // builder.add_const_arguments()?; + // let element_predicate_acc = builder.build(); + let element_predicate_acc = self.inner.lowerer.unique_ref_predicate( CallContext::BuiltinMethod, element_type, element_type, element_place, - self.root_address.clone().into(), + self.inner.root_address.clone().into(), current_element_snapshot, final_element_snapshot, self.reference_lifetime.clone().into(), + None, // FIXME: This should be a proper value )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let element_predicate_acc = builder.build(); let elements = vir_low::Expression::forall( vec![index], vec![vir_low::Trigger::new(vec![element_predicate_acc.clone()])], @@ -296,49 +345,61 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { variant: &vir_mid::type_decl::Struct, variant_type: &vir_mid::Type, ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { - use vir_low::macros::*; - let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( - self.current_snapshot.clone().into(), - self.inner.ty, - self.inner.position, - )?; - let guard = expr! { - [ discriminant_call ] == [ discriminant_value.into() ] - }; - let variant_index = variant.name.clone().into(); - let variant_place = self.inner.lowerer.encode_enum_variant_place( - self.inner.ty, - &variant_index, - self.place.clone().into(), - self.inner.position, - )?; - let current_variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( - self.inner.ty, - &variant_index, - self.current_snapshot.clone().into(), - self.inner.position, - )?; - let final_variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( - self.inner.ty, - &variant_index, - self.final_snapshot.clone().into(), - self.inner.position, - )?; - let mut builder = UniqueRefUseBuilder::new( - self.inner.lowerer, - CallContext::BuiltinMethod, - variant_type, - variant_type, - variant_place, - self.root_address.clone().into(), - current_variant_snapshot, - final_variant_snapshot, - self.reference_lifetime.clone().into(), - )?; - builder.add_lifetime_arguments()?; - builder.add_const_arguments()?; - let predicate = builder.build(); - Ok((guard, predicate)) + unimplemented!(); + // use vir_low::macros::*; + // let discriminant_call = self.inner.lowerer.obtain_enum_discriminant( + // self.current_snapshot.clone().into(), + // self.inner.ty, + // self.inner.position, + // )?; + // let guard = expr! { + // [ discriminant_call ] == [ discriminant_value.into() ] + // }; + // let variant_index = variant.name.clone().into(); + // let variant_place = self.inner.lowerer.encode_enum_variant_place( + // self.inner.ty, + // &variant_index, + // self.inner.place.clone().into(), + // self.inner.position, + // )?; + // let current_variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( + // self.inner.ty, + // &variant_index, + // self.current_snapshot.clone().into(), + // self.inner.position, + // )?; + // let final_variant_snapshot = self.inner.lowerer.obtain_enum_variant_snapshot( + // self.inner.ty, + // &variant_index, + // self.final_snapshot.clone().into(), + // self.inner.position, + // )?; + // // let mut builder = UniqueRefUseBuilder::new( + // // self.inner.lowerer, + // // CallContext::BuiltinMethod, + // // variant_type, + // // variant_type, + // // variant_place, + // // self.inner.root_address.clone().into(), + // // current_variant_snapshot, + // // final_variant_snapshot, + // // self.reference_lifetime.clone().into(), + // // )?; + // // builder.add_lifetime_arguments()?; + // // builder.add_const_arguments()?; + // // let predicate = builder.build(); + // let predicate = self.inner.lowerer.unique_ref( + // CallContext::BuiltinMethod, + // variant_type, + // variant_type, + // variant_place, + // self.inner.root_address.clone().into(), + // current_variant_snapshot, + // final_variant_snapshot, + // self.reference_lifetime.clone().into(), + // None, // FIXME: This should be a proper value + // )?; + // Ok((guard, predicate)) } pub(in super::super::super) fn add_variant_predicates( @@ -348,4 +409,41 @@ impl<'l, 'p, 'v, 'tcx> UniqueRefBuilder<'l, 'p, 'v, 'tcx> { self.inner .add_conjunct(variant_predicates.into_iter().create_match()) } + + pub(in super::super::super) fn add_structural_invariant( + &mut self, + decl: &vir_mid::type_decl::Struct, + ) -> SpannedEncodingResult> { + self.inner.add_structural_invariant( + decl, + PredicateKind::UniqueRef { + lifetime: self.reference_lifetime.clone().into(), + is_final: false, + }, + ) + } + + // /// FIXME: Code duplication. + // pub(in super::super::super) fn add_structural_invariant( + // &mut self, + // decl: &vir_mid::type_decl::Struct, + // ) -> SpannedEncodingResult> { + // if let Some(invariant) = &decl.structural_invariant { + // let mut encoder = SelfFramingAssertionToSnapshot::for_predicate_body( + // self.inner.place.clone(), + // self.inner.root_address.clone(), + // PredicateKind::UniqueRef { + // lifetime: self.reference_lifetime.clone().into(), + // }, + // ); + // for assertion in invariant { + // let low_assertion = + // encoder.expression_to_snapshot(self.inner.lowerer, assertion, true)?; + // self.inner.add_conjunct(low_assertion)?; + // } + // Ok(encoder.into_created_predicate_types()) + // } else { + // Ok(Vec::new()) + // } + // } } diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_use.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_use.rs index 799f81227d1..8be59ba1e60 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_use.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/unique_ref/predicate_use.rs @@ -1,12 +1,16 @@ use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ - builtin_methods::CallContext, lowerer::Lowerer, - predicates::owned::builders::common::predicate_use::PredicateUseBuilder, - snapshots::SnapshotValuesInterface, type_layouts::TypeLayoutsInterface, + builtin_methods::CallContext, + lowerer::Lowerer, + predicates::{ + owned::builders::common::predicate_use::PredicateUseBuilder, PredicatesOwnedInterface, + }, }, }; +use prusti_common::config; use vir_crate::{ + common::expression::BinaryOperationHelpers, low::{self as vir_low}, middle::{ self as vir_mid, @@ -19,7 +23,8 @@ where G: WithLifetimes + WithConstArguments, { inner: PredicateUseBuilder<'l, 'p, 'v, 'tcx, G>, - current_snapshot: vir_low::Expression, + current_snapshot: Option, + target_slice_len: Option, } impl<'l, 'p, 'v, 'tcx, G> UniqueRefUseBuilder<'l, 'p, 'v, 'tcx, G> @@ -34,33 +39,86 @@ where generics: &'l G, place: vir_low::Expression, root_address: vir_low::Expression, - current_snapshot: vir_low::Expression, - final_snapshot: vir_low::Expression, + // current_snapshot: vir_low::Expression, + // final_snapshot: vir_low::Expression, lifetime: vir_low::Expression, + target_slice_len: Option, ) -> SpannedEncodingResult { + let mut arguments = vec![ + place, + root_address, + // current_snapshot.clone(), + lifetime, + // final_snapshot, + ]; + if let Some(len) = target_slice_len.clone() { + arguments.push(len); + } let inner = PredicateUseBuilder::new( lowerer, "UniqueRef2", context, ty, generics, - vec![ - place, - root_address, - current_snapshot.clone(), - final_snapshot, - lifetime, - ], + arguments, Default::default(), )?; Ok(Self { inner, - current_snapshot, + target_slice_len, + current_snapshot: None, + // current_snapshot, }) } - pub(in super::super::super::super::super) fn build(self) -> vir_low::Expression { - self.inner.build() + pub(in super::super::super::super::super) fn build( + mut self, + ) -> SpannedEncodingResult { + let expression = if let Some(current_snapshot) = self.current_snapshot.take() { + let snap_current_call = self.inner.lowerer.unique_ref_snap( + self.inner.context, + self.inner.ty, + self.inner.generics, + self.inner.arguments[0].clone(), + self.inner.arguments[1].clone(), + self.inner.arguments[2].clone(), + self.target_slice_len.clone(), + false, + )?; + // let snap_final_call = self.inner.lowerer.unique_ref_snap( + // self.inner.context, + // self.inner.ty, + // self.inner.generics, + // self.inner.arguments[0].clone(), + // self.inner.arguments[1].clone(), + // self.inner.arguments[2].clone(), + // self.target_slice_len, + // true, + // )?; + vir_low::Expression::and( + self.inner.build(), + // vir_low::Expression::and( + vir_low::Expression::equals(current_snapshot, snap_current_call), + // vir_low::Expression::equals(final_snapshot, snap_final_call), + // ), + ) + } else { + self.inner.build() + }; + Ok(expression) + } + + pub(in super::super::super::super::super) fn add_current_snapshot_argument( + &mut self, + current_snapshot: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + // if config::use_snapshot_parameters_in_predicates() { + // self.inner.arguments.push(current_snapshot); + // self.inner.arguments.push(final_snapshot); + // } else { + self.current_snapshot = Some(current_snapshot); + // } + Ok(()) } pub(in super::super::super::super::super) fn add_lifetime_arguments( @@ -73,17 +131,19 @@ where &mut self, ) -> SpannedEncodingResult<()> { if self.inner.ty.is_slice() { - let snapshot_length = self - .inner - .lowerer - .obtain_array_len_snapshot(self.current_snapshot.clone(), self.inner.position)?; - let size_type = self.inner.lowerer.size_type_mid()?; - let argument = self.inner.lowerer.construct_constant_snapshot( - &size_type, - snapshot_length, - self.inner.position, - )?; - self.inner.arguments.push(argument); + // FIXME + eprintln!("FIXME!!!"); + // let snapshot_length = self + // .inner + // .lowerer + // .obtain_array_len_snapshot(self.current_snapshot.clone(), self.inner.position)?; + // let size_type = self.inner.lowerer.size_type_mid()?; + // let argument = self.inner.lowerer.construct_constant_snapshot( + // &size_type, + // snapshot_length, + // self.inner.position, + // )?; + // self.inner.arguments.push(argument); } self.inner.add_const_arguments() } diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/encoder.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/encoder.rs index 2d9e138997f..96ee7563a82 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/encoder.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/encoder.rs @@ -1,9 +1,11 @@ +use std::collections::BTreeMap; + use crate::encoder::{ errors::SpannedEncodingResult, high::types::HighTypeEncoderInterface, middle::core_proof::{ compute_address::ComputeAddressInterface, - lowerer::Lowerer, + lowerer::{FunctionsLowererInterface, Lowerer}, places::PlacesInterface, predicates::{ owned::builders::{ @@ -16,34 +18,46 @@ use crate::encoder::{ types::TypesInterface, }, }; +use prusti_common::config; use rustc_hash::FxHashSet; use vir_crate::{ - common::identifier::WithIdentifier, + common::{ + expression::{ExpressionIterator, GuardedExpressionIterator}, + identifier::WithIdentifier, + }, low::{self as vir_low}, middle as vir_mid, }; +use super::{ + builders::{ + FracRefSnapFunctionBuilder, OwnedAliasedBuilder, OwnedAliasedSnapFunctionBuilder, + OwnedNonAliasedSnapFunctionBuilder, UniqueRefSnapFunctionBuilder, + }, + interface::PredicateInfo, +}; + pub(super) struct PredicateEncoder<'l, 'p, 'v, 'tcx> { lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, - unfolded_owned_non_aliased_predicates: &'l FxHashSet, - encoded_owned_predicates: FxHashSet, + encoded_owned_non_aliased_predicates: FxHashSet, + encoded_owned_aliased_predicates: FxHashSet, encoded_mut_borrow_predicates: FxHashSet, encoded_frac_borrow_predicates: FxHashSet, predicates: Vec, + /// A map from predicate names to snapshot function names and snapshot types. + predicate_info: BTreeMap, } impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { - pub(super) fn new( - lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, - unfolded_owned_non_aliased_predicates: &'l FxHashSet, - ) -> Self { + pub(super) fn new(lowerer: &'l mut Lowerer<'p, 'v, 'tcx>) -> Self { Self { lowerer, - unfolded_owned_non_aliased_predicates, - encoded_owned_predicates: Default::default(), + encoded_owned_non_aliased_predicates: Default::default(), + encoded_owned_aliased_predicates: Default::default(), encoded_mut_borrow_predicates: Default::default(), encoded_frac_borrow_predicates: Default::default(), predicates: Default::default(), + predicate_info: Default::default(), } } @@ -51,22 +65,284 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { self.predicates } + pub(super) fn take_predicate_info(&mut self) -> BTreeMap { + std::mem::take(&mut self.predicate_info) + } + + pub(super) fn encode_owned_non_aliased_snapshot( + &mut self, + normalized_type: &vir_mid::Type, + type_decl: &vir_mid::TypeDecl, + ) -> SpannedEncodingResult<(String, vir_low::Type)> { + let mut builder = + OwnedNonAliasedSnapFunctionBuilder::new(self.lowerer, normalized_type, type_decl)?; + builder.create_parameters()?; + builder.add_owned_precondition()?; + builder.add_validity_postcondition()?; + match type_decl { + vir_mid::TypeDecl::Bool + | vir_mid::TypeDecl::Int(_) + | vir_mid::TypeDecl::Float(_) + | vir_mid::TypeDecl::Pointer(_) + | vir_mid::TypeDecl::Sequence(_) + | vir_mid::TypeDecl::Map(_) => { + builder.add_bytes_snapshot_equality()?; + } + vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} + vir_mid::TypeDecl::Struct(decl) => { + let mut equalities = Vec::new(); + for field in &decl.fields { + equalities.push(builder.create_field_snapshot_equality(field)?); + } + builder.add_unfolding_postcondition(equalities.into_iter().conjoin())?; + builder.add_structural_invariant(decl)?; + } + vir_mid::TypeDecl::Enum(decl) => { + let mut equalities = Vec::new(); + if decl.safety.is_enum() { + let discriminant_equality = + builder.create_discriminant_snapshot_equality(decl)?; + builder.add_unfolding_postcondition(discriminant_equality)?; + } + for (discriminant, variant) in decl.iter_discriminant_variants() { + equalities + .push(builder.create_variant_snapshot_equality(discriminant, variant)?); + } + builder.add_unfolding_postcondition(equalities.into_iter().create_match())?; + } + vir_mid::TypeDecl::Reference(decl) => { + builder.add_bytes_address_snapshot_equality()?; + // FIXME: Have a getter for the first lifetime. + let lifetime = &decl.lifetimes[0]; + builder.add_reference_snapshot_equalities(decl, lifetime)?; + } + vir_mid::TypeDecl::Array(decl) => { + let length = if normalized_type.is_slice() { + builder.get_slice_len()? + } else { + decl.const_parameters[0].clone() + }; + builder.add_snapshot_len_equal_to_postcondition(&length)?; + builder.add_quantifiers(&length, &decl.element_type)?; + } + _ => { + unimplemented!("{}", type_decl); + } + } + let function = builder.build()?; + let function_name = function.name.clone(); + let snapshot_type = function.return_type.clone(); + self.lowerer.declare_function(function)?; + Ok((function_name, snapshot_type)) + } + + pub(super) fn encode_owned_aliased_snapshot( + &mut self, + normalized_type: &vir_mid::Type, + type_decl: &vir_mid::TypeDecl, + ) -> SpannedEncodingResult<()> { + let mut builder = + OwnedAliasedSnapFunctionBuilder::new(self.lowerer, normalized_type, type_decl)?; + builder.create_parameters()?; + builder.add_owned_precondition()?; + builder.add_validity_postcondition()?; + match type_decl { + vir_mid::TypeDecl::Bool + | vir_mid::TypeDecl::Int(_) + | vir_mid::TypeDecl::Float(_) + | vir_mid::TypeDecl::Pointer(_) + | vir_mid::TypeDecl::Sequence(_) + | vir_mid::TypeDecl::Map(_) => { + builder.add_bytes_snapshot_equality()?; + } + vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} + vir_mid::TypeDecl::Struct(decl) => { + let mut equalities = Vec::new(); + for field in &decl.fields { + equalities.push(builder.create_field_snapshot_equality(field)?); + } + builder.add_unfolding_postcondition(equalities.into_iter().conjoin())?; + builder.add_structural_invariant(decl)?; + } + vir_mid::TypeDecl::Enum(decl) => { + let mut equalities = Vec::new(); + if decl.safety.is_enum() { + let discriminant_equality = + builder.create_discriminant_snapshot_equality(decl)?; + builder.add_unfolding_postcondition(discriminant_equality)?; + } + for (discriminant, variant) in decl.iter_discriminant_variants() { + equalities + .push(builder.create_variant_snapshot_equality(discriminant, variant)?); + } + builder.add_unfolding_postcondition(equalities.into_iter().create_match())?; + } + vir_mid::TypeDecl::Reference(decl) => { + builder.add_bytes_address_snapshot_equality()?; + // FIXME: Have a getter for the first lifetime. + let lifetime = &decl.lifetimes[0]; + builder.add_reference_snapshot_equalities(decl, lifetime)?; + } + vir_mid::TypeDecl::Array(decl) => { + let length = if normalized_type.is_slice() { + builder.get_slice_len()? + } else { + decl.const_parameters[0].clone() + }; + builder.add_snapshot_len_equal_to_postcondition(&length)?; + builder.add_quantifiers(&length, &decl.element_type)?; + } + _ => { + unimplemented!("{}", type_decl); + } + } + let function = builder.build()?; + self.lowerer.declare_function(function)?; + Ok(()) + } + + pub(super) fn encode_unique_ref_current_snapshot( + &mut self, + normalized_type: &vir_mid::Type, + type_decl: &vir_mid::TypeDecl, + ) -> SpannedEncodingResult<()> { + let mut builder = + UniqueRefSnapFunctionBuilder::new(self.lowerer, normalized_type, type_decl, false)?; + builder.create_parameters()?; + builder.add_unique_ref_precondition()?; + match &type_decl { + vir_mid::TypeDecl::Bool + | vir_mid::TypeDecl::Int(_) + | vir_mid::TypeDecl::Float(_) + | vir_mid::TypeDecl::Pointer(_) + | vir_mid::TypeDecl::Sequence(_) + | vir_mid::TypeDecl::Map(_) => { + // For these types the unique ref predicate is abstract. + } + vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} + vir_mid::TypeDecl::Struct(decl) => { + let mut equalities = Vec::new(); + for field in &decl.fields { + equalities.push(builder.create_field_snapshot_equality(field)?); + } + builder.add_unfolding_postcondition(equalities.into_iter().conjoin())?; + builder.add_structural_invariant(decl)?; + } + vir_mid::TypeDecl::Enum(_decl) => {} + vir_mid::TypeDecl::Reference(_decl) => {} + vir_mid::TypeDecl::Array(_decl) => {} + _ => { + unimplemented!("{}", type_decl); + } + } + let function = builder.build()?; + self.lowerer.declare_function(function)?; + Ok(()) + } + + pub(super) fn encode_unique_ref_final_snapshot( + &mut self, + normalized_type: &vir_mid::Type, + type_decl: &vir_mid::TypeDecl, + ) -> SpannedEncodingResult<()> { + let mut builder = + UniqueRefSnapFunctionBuilder::new(self.lowerer, normalized_type, type_decl, true)?; + builder.create_parameters()?; + match &type_decl { + vir_mid::TypeDecl::Bool + | vir_mid::TypeDecl::Int(_) + | vir_mid::TypeDecl::Float(_) + | vir_mid::TypeDecl::Pointer(_) + | vir_mid::TypeDecl::Sequence(_) + | vir_mid::TypeDecl::Map(_) => { + // For these types the unique ref predicate is abstract. + } + vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} + vir_mid::TypeDecl::Struct(decl) => { + let mut equalities = Vec::new(); + for field in &decl.fields { + equalities.push(builder.create_field_snapshot_equality(field)?); + } + builder.add_postcondition(equalities.into_iter().conjoin())?; + builder.add_structural_invariant(decl)?; + } + vir_mid::TypeDecl::Enum(_decl) => {} + vir_mid::TypeDecl::Reference(_decl) => { + // For references, the final snapshot is abstract. + } + vir_mid::TypeDecl::Array(_decl) => {} + _ => { + unimplemented!("{}", type_decl); + } + } + let function = builder.build()?; + self.lowerer.declare_function(function)?; + Ok(()) + } + + pub(super) fn encode_frac_ref_snapshot( + &mut self, + normalized_type: &vir_mid::Type, + type_decl: &vir_mid::TypeDecl, + ) -> SpannedEncodingResult<()> { + let mut builder = + FracRefSnapFunctionBuilder::new(self.lowerer, normalized_type, type_decl)?; + builder.create_parameters()?; + builder.add_frac_ref_precondition()?; + match &type_decl { + vir_mid::TypeDecl::Bool + | vir_mid::TypeDecl::Int(_) + | vir_mid::TypeDecl::Float(_) + | vir_mid::TypeDecl::Pointer(_) + | vir_mid::TypeDecl::Sequence(_) + | vir_mid::TypeDecl::Map(_) => { + // For these types the unique ref predicate is abstract. + } + vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} + vir_mid::TypeDecl::Struct(decl) => { + let mut equalities = Vec::new(); + for field in &decl.fields { + equalities.push(builder.create_field_snapshot_equality(field)?); + } + builder.add_unfolding_postcondition(equalities.into_iter().conjoin())?; + builder.add_structural_invariant(decl)?; + } + vir_mid::TypeDecl::Enum(_decl) => {} + vir_mid::TypeDecl::Reference(_decl) => {} + vir_mid::TypeDecl::Array(_decl) => {} + _ => { + unimplemented!("{}", type_decl); + } + } + let function = builder.build()?; + self.lowerer.declare_function(function)?; + Ok(()) + } + pub(super) fn encode_owned_non_aliased( &mut self, ty: &vir_mid::Type, ) -> SpannedEncodingResult<()> { let ty_identifier = ty.get_identifier(); - if self.encoded_owned_predicates.contains(&ty_identifier) { + if self + .encoded_owned_non_aliased_predicates + .contains(&ty_identifier) + { return Ok(()); } - self.encoded_owned_predicates.insert(ty_identifier); + self.encoded_owned_non_aliased_predicates + .insert(ty_identifier); self.lowerer.encode_compute_address(ty)?; let type_decl = self.lowerer.encoder.get_type_decl_mid(ty)?; let normalized_type = ty.normalize_type(); self.lowerer .encode_snapshot_to_bytes_function(&normalized_type)?; + // if !config::use_snapshot_parameters_in_predicates() { + let (snap_function_name, snap_type) = + self.encode_owned_non_aliased_snapshot(&normalized_type, &type_decl)?; + // } let mut owned_predicates_to_encode = Vec::new(); let mut unique_ref_predicates_to_encode = Vec::new(); let mut frac_ref_predicates_to_encode = Vec::new(); @@ -75,7 +351,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { builder.create_parameters()?; if !(type_decl.is_type_var() || type_decl.is_trusted()) { builder.create_body(); - builder.add_validity()?; + if config::use_snapshot_parameters_in_predicates() { + builder.add_validity()?; + } } // Build the body. match &type_decl { @@ -86,20 +364,26 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { | vir_mid::TypeDecl::Sequence(_) | vir_mid::TypeDecl::Map(_) => { builder.add_base_memory_block()?; - builder.add_bytes_snapshot_equality()?; + // if config::use_snapshot_parameters_in_predicates() { + // builder.add_bytes_snapshot_equality()?; + // } + if let vir_mid::TypeDecl::Pointer(decl) = &type_decl { + owned_predicates_to_encode.push(decl.target_type.clone()); + } } vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} vir_mid::TypeDecl::Struct(decl) => { builder.add_padding_memory_block()?; for field in &decl.fields { builder.add_field_predicate(field)?; - if !self - .unfolded_owned_non_aliased_predicates - .contains(&field.ty) - { - owned_predicates_to_encode.push(field.ty.clone()); - } + // if !self + // .unfolded_owned_non_aliased_predicates + // .contains(&field.ty) + // { + owned_predicates_to_encode.push(field.ty.clone()); + // } } + owned_predicates_to_encode.extend(builder.add_structural_invariant(decl)?); } vir_mid::TypeDecl::Enum(decl) => { builder.add_padding_memory_block()?; @@ -114,24 +398,25 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { decl.lifetimes.clone(), ); variant_predicates.push(builder.create_variant_predicate( + decl, discriminant, variant, &variant_type, )?); let variant_type = ty.clone().variant(variant_index); - if !self - .unfolded_owned_non_aliased_predicates - .contains(&variant_type) - { - owned_predicates_to_encode.push(variant_type); - } - } - if !self - .unfolded_owned_non_aliased_predicates - .contains(&decl.discriminant_type) - { - owned_predicates_to_encode.push(decl.discriminant_type.clone()); + // if !self + // .unfolded_owned_non_aliased_predicates + // .contains(&variant_type) + // { + owned_predicates_to_encode.push(variant_type); + // } } + // if !self + // .unfolded_owned_non_aliased_predicates + // .contains(&decl.discriminant_type) + // { + owned_predicates_to_encode.push(decl.discriminant_type.clone()); + // } if decl.safety.is_enum() { builder.add_discriminant_predicate(decl)?; } @@ -139,7 +424,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { } vir_mid::TypeDecl::Reference(decl) => { builder.add_base_memory_block()?; - builder.add_bytes_address_snapshot_equality()?; + if config::use_snapshot_parameters_in_predicates() { + builder.add_bytes_address_snapshot_equality()?; + } // FIXME: Have a getter for the first lifetime. let lifetime = &decl.lifetimes[0]; if decl.uniqueness.is_unique() { @@ -163,7 +450,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { } else { decl.const_parameters[0].clone() }; - builder.add_snapshot_len_equal_to(&length)?; + if config::use_snapshot_parameters_in_predicates() { + builder.add_snapshot_len_equal_to(&length)?; + } builder.add_quantified_permission(&length, &decl.element_type)?; } _ => { @@ -172,7 +461,15 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { unimplemented!("{}", type_decl); } } - self.predicates.push(builder.build()); + let predicate = builder.build(); + self.predicate_info.insert( + predicate.name.clone(), + PredicateInfo { + snapshot_function_name: snap_function_name, + snapshot_type: snap_type, + }, + ); + self.predicates.push(predicate); for ty in owned_predicates_to_encode { // TODO: Optimization: This variant is never unfolded, // encode it as abstract predicate. @@ -191,6 +488,133 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { Ok(()) } + pub(super) fn encode_owned_aliased(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()> { + let ty_identifier = ty.get_identifier(); + if self + .encoded_owned_aliased_predicates + .contains(&ty_identifier) + { + return Ok(()); + } + + self.encoded_owned_aliased_predicates.insert(ty_identifier); + self.lowerer.encode_compute_address(ty)?; + let type_decl = self.lowerer.encoder.get_type_decl_mid(ty)?; + + let normalized_type = ty.normalize_type(); + self.lowerer + .encode_snapshot_to_bytes_function(&normalized_type)?; + self.encode_owned_aliased_snapshot(&normalized_type, &type_decl)?; + let mut owned_predicates_to_encode = Vec::new(); + let mut unique_ref_predicates_to_encode = Vec::new(); + let mut frac_ref_predicates_to_encode = Vec::new(); + self.lowerer.encode_memory_block_predicate()?; + let mut builder = OwnedAliasedBuilder::new(self.lowerer, &normalized_type, &type_decl)?; + builder.create_parameters()?; + if !(type_decl.is_type_var() || type_decl.is_trusted()) { + builder.create_body(); + } + // Build the body. + match &type_decl { + vir_mid::TypeDecl::Bool + | vir_mid::TypeDecl::Int(_) + | vir_mid::TypeDecl::Float(_) + | vir_mid::TypeDecl::Pointer(_) + | vir_mid::TypeDecl::Sequence(_) + | vir_mid::TypeDecl::Map(_) => { + builder.add_base_memory_block()?; + if let vir_mid::TypeDecl::Pointer(decl) = &type_decl { + owned_predicates_to_encode.push(decl.target_type.clone()); + } + } + vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => {} + vir_mid::TypeDecl::Struct(decl) => { + builder.add_padding_memory_block()?; + for field in &decl.fields { + builder.add_field_predicate(field)?; + owned_predicates_to_encode.push(field.ty.clone()); + } + owned_predicates_to_encode.extend(builder.add_structural_invariant(decl)?); + } + vir_mid::TypeDecl::Enum(decl) => { + builder.add_padding_memory_block()?; + let mut variant_predicates = Vec::new(); + for (discriminant, variant) in decl.iter_discriminant_variants() { + let variant_index: vir_mid::ty::VariantIndex = variant.name.clone().into(); + let variant_type = vir_mid::Type::enum_( + decl.name.clone(), + decl.safety, + decl.arguments.clone(), + Some(variant_index.clone()), + decl.lifetimes.clone(), + ); + variant_predicates.push(builder.create_variant_predicate( + decl, + discriminant, + variant, + &variant_type, + )?); + let variant_type = ty.clone().variant(variant_index); + owned_predicates_to_encode.push(variant_type); + } + owned_predicates_to_encode.push(decl.discriminant_type.clone()); + if decl.safety.is_enum() { + builder.add_discriminant_predicate(decl)?; + } + builder.add_variant_predicates(variant_predicates)?; + } + vir_mid::TypeDecl::Reference(decl) => { + builder.add_base_memory_block()?; + // FIXME: Have a getter for the first lifetime. + let lifetime = &decl.lifetimes[0]; + if decl.uniqueness.is_unique() { + builder.add_unique_ref_target_predicate(&decl.target_type, lifetime)?; + unique_ref_predicates_to_encode.push(decl.target_type.clone()); + } else { + builder.add_frac_ref_target_predicate(&decl.target_type, lifetime)?; + frac_ref_predicates_to_encode.push(decl.target_type.clone()); + } + } + vir_mid::TypeDecl::Array(decl) => { + builder.lowerer().encode_place_array_index_axioms(ty)?; + builder + .lowerer() + .ensure_type_definition(&decl.element_type)?; + owned_predicates_to_encode.push(decl.element_type.clone()); + builder.add_const_parameters_validity()?; + // FIXME: Have a getter for the first const parameter. + let length = if normalized_type.is_slice() { + builder.get_slice_len()? + } else { + decl.const_parameters[0].clone() + }; + unimplemented!(); + // builder.add_quantified_permission(&length, &decl.element_type)?; + } + _ => { + builder.add_base_memory_block()?; + unimplemented!("{}", type_decl); + } + } + self.predicates.push(builder.build()); + for ty in owned_predicates_to_encode { + // TODO: Optimization: This variant is never unfolded, + // encode it as abstract predicate. + self.encode_owned_aliased(&ty)?; + } + for ty in unique_ref_predicates_to_encode { + // TODO: Optimization: This variant is never unfolded, + // encode it as abstract predicate. + self.encode_unique_ref(&ty)?; + } + for ty in frac_ref_predicates_to_encode { + // TODO: Optimization: This variant is never unfolded, + // encode it as abstract predicate. + self.encode_frac_ref(&ty)?; + } + Ok(()) + } + fn encode_frac_ref(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()> { let ty_identifier = ty.get_identifier(); if self.encoded_frac_borrow_predicates.contains(&ty_identifier) { @@ -198,10 +622,12 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { } self.encoded_frac_borrow_predicates.insert(ty_identifier); self.lowerer.encode_compute_address(ty)?; - let type_decl = self.lowerer.encoder.get_type_decl_mid(ty)?; let normalized_type = ty.normalize_type(); + if !config::use_snapshot_parameters_in_predicates() { + self.encode_frac_ref_snapshot(&normalized_type, &type_decl)?; + } let mut predicates_to_encode = Vec::new(); let mut builder = FracRefBuilder::new(self.lowerer, &normalized_type, &type_decl)?; builder.create_parameters()?; @@ -217,7 +643,7 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { | vir_mid::TypeDecl::TypeVar(_) ) { builder.create_body(); - builder.add_validity()?; + // builder.add_validity()?; } // Build the body. match &type_decl { @@ -234,6 +660,7 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { builder.add_field_predicate(field)?; predicates_to_encode.push(field.ty.clone()); } + predicates_to_encode.extend(builder.add_structural_invariant(decl)?); } vir_mid::TypeDecl::Enum(decl) => { let mut variant_predicates = Vec::new(); @@ -300,6 +727,10 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { // FIXME: Make get_type_decl_mid to return the erased ty for which it // returned type_decl. let normalized_type = ty.normalize_type(); + // if !config::use_snapshot_parameters_in_predicates() { + self.encode_unique_ref_current_snapshot(&normalized_type, &type_decl)?; + self.encode_unique_ref_final_snapshot(&normalized_type, &type_decl)?; + // } let mut unique_ref_predicates_to_encode = Vec::new(); let mut frac_ref_predicates_to_encode = Vec::new(); let mut builder = UniqueRefBuilder::new(self.lowerer, &normalized_type, &type_decl)?; @@ -316,7 +747,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { | vir_mid::TypeDecl::TypeVar(_) ) { builder.create_body(); - builder.add_validity()?; + // if config::use_snapshot_parameters_in_predicates() { + // builder.add_validity()?; + // } } // Build the body. match &type_decl { @@ -333,6 +766,7 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { builder.add_field_predicate(field)?; unique_ref_predicates_to_encode.push(field.ty.clone()); } + unique_ref_predicates_to_encode.extend(builder.add_structural_invariant(decl)?); } vir_mid::TypeDecl::Enum(decl) => { let mut variant_predicates = Vec::new(); @@ -355,12 +789,15 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { vir_mid::TypeDecl::Reference(decl) => { // FIXME: Have a getter for the first lifetime. let lifetime = &decl.lifetimes[0]; + let pointer_type = builder.add_unique_ref_pointer_predicate(lifetime)?; if decl.uniqueness.is_unique() { builder.add_unique_ref_target_predicate(&decl.target_type, lifetime)?; unique_ref_predicates_to_encode.push(decl.target_type.clone()); + unique_ref_predicates_to_encode.push(pointer_type); } else { builder.add_frac_ref_target_predicate(&decl.target_type, lifetime)?; frac_ref_predicates_to_encode.push(decl.target_type.clone()); + frac_ref_predicates_to_encode.push(pointer_type); } } vir_mid::TypeDecl::Array(decl) => { @@ -376,7 +813,9 @@ impl<'l, 'p, 'v, 'tcx> PredicateEncoder<'l, 'p, 'v, 'tcx> { } else { decl.const_parameters[0].clone() }; - builder.add_snapshot_len_equal_to(&length)?; + if config::use_snapshot_parameters_in_predicates() { + builder.add_snapshot_len_equal_to(&length)?; + } builder.add_quantified_permission(&length, &decl.element_type)?; } _ => { diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/interface.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/interface.rs index 16dc6b0981b..d384230bd33 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/interface.rs @@ -1,23 +1,46 @@ +use std::collections::BTreeMap; + use super::{ - encoder::PredicateEncoder, FracRefUseBuilder, OwnedNonAliasedUseBuilder, UniqueRefUseBuilder, + builders::{ + FracRefSnapCallBuilder, OwnedAliasedRangeUseBuilder, OwnedAliasedUseBuilder, + OwnedNonAliasedSnapCallBuilder, UniqueRefSnapCallBuilder, + }, + encoder::PredicateEncoder, + FracRefUseBuilder, OwnedAliasedSnapCallBuilder, OwnedNonAliasedUseBuilder, UniqueRefUseBuilder, }; use crate::encoder::{ errors::SpannedEncodingResult, - middle::core_proof::{builtin_methods::CallContext, lowerer::Lowerer, types::TypesInterface}, + middle::core_proof::{ + addresses::AddressesInterface, builtin_methods::CallContext, lowerer::Lowerer, + places::PlacesInterface, types::TypesInterface, + }, }; +use prusti_common::config; use rustc_hash::FxHashSet; use vir_crate::{ low::{self as vir_low}, middle::{ self as vir_mid, - operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes}, + operations::{const_generics::WithConstArguments, lifetimes::WithLifetimes, ty::Typed}, }, }; #[derive(Default)] pub(in super::super) struct PredicatesOwnedState { unfolded_owned_non_aliased_predicates: FxHashSet, + unfolded_owned_aliased_predicates: FxHashSet, used_unique_ref_predicates: FxHashSet, + snap_wrappers: FxHashSet, +} + +/// Addidional information about the predicate used by purification +/// optimizations. +#[derive(Clone, Debug)] +pub(in super::super::super) struct PredicateInfo { + /// The name of the snapshot function. + pub(in super::super::super) snapshot_function_name: String, + /// The snapshot type. + pub(in super::super::super) snapshot_type: vir_low::Type, } pub(in super::super::super) trait PredicatesOwnedInterface { @@ -27,12 +50,16 @@ pub(in super::super::super) trait PredicatesOwnedInterface { &mut self, ty: &vir_mid::Type, ) -> SpannedEncodingResult<()>; + /// Marks that `OwnedAliased` was unfolded in the program and we need to + /// provide its body. + fn mark_owned_aliased_as_unfolded(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; /// Marks that `UniqueRef` was used in the program. fn mark_unique_ref_as_used(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()>; fn collect_owned_predicate_decls( &mut self, - ) -> SpannedEncodingResult>; + ) -> SpannedEncodingResult<(Vec, BTreeMap)>; /// A version of `owned_non_aliased` for the most common case. + #[allow(clippy::too_many_arguments)] fn owned_non_aliased_full_vars( &mut self, context: CallContext, @@ -41,6 +68,7 @@ pub(in super::super::super) trait PredicatesOwnedInterface { place: &vir_low::VariableDecl, root_address: &vir_low::VariableDecl, snapshot: &vir_low::VariableDecl, + must_be_predicate: bool, ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments; @@ -55,6 +83,72 @@ pub(in super::super::super) trait PredicatesOwnedInterface { snapshot: vir_low::Expression, permission_amount: Option, ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + /// The result is guaranteed to be a `acc(predicate)`, which is needed + /// for fold/unfold operations. + #[allow(clippy::too_many_arguments)] + fn owned_non_aliased_predicate( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + snapshot: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] + fn owned_non_aliased_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] + fn owned_aliased_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + // fn wrap_snap_into_bool( + // &mut self, + // ty: &vir_mid::Type, + // expression: vir_low::Expression, + // ) -> SpannedEncodingResult; + #[allow(clippy::too_many_arguments)] + fn owned_aliased( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + address: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + fn owned_aliased_range( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + address: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments; #[allow(clippy::too_many_arguments)] @@ -68,6 +162,7 @@ pub(in super::super::super) trait PredicatesOwnedInterface { current_snapshot: &vir_low::VariableDecl, final_snapshot: &vir_low::VariableDecl, lifetime: &vir_low::VariableDecl, + target_slice_len: Option, ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments; @@ -82,6 +177,36 @@ pub(in super::super::super) trait PredicatesOwnedInterface { current_snapshot: vir_low::Expression, final_snapshot: vir_low::Expression, lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] + fn unique_ref_predicate( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + current_snapshot: vir_low::Expression, + final_snapshot: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] + fn unique_ref_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + is_final: bool, ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments; @@ -99,6 +224,19 @@ pub(in super::super::super) trait PredicatesOwnedInterface { where G: WithLifetimes + WithConstArguments; #[allow(clippy::too_many_arguments)] + fn frac_ref_full_vars_opt( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: &vir_low::VariableDecl, + root_address: &vir_low::VariableDecl, + current_snapshot: &Option, + lifetime: &vir_low::VariableDecl, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] fn frac_ref( &mut self, context: CallContext, @@ -111,6 +249,58 @@ pub(in super::super::super) trait PredicatesOwnedInterface { ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] + fn frac_ref_opt( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + current_snapshot: Option, + lifetime: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + fn frac_ref_predicate( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + current_snapshot: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + #[allow(clippy::too_many_arguments)] + fn frac_ref_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments; + // /// A snap function that is chosen based on whether the place is behind + // /// a reference or not. + // fn place_snap( + // &mut self, + // context: CallContext, + // ty: &vir_mid::Type, + // generics: &G, + // place: &vir_mid::Expression, + // position: vir_low::Position, + // deref_to_final: bool, + // ) -> SpannedEncodingResult + // where + // G: WithLifetimes + WithConstArguments; } impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { @@ -133,6 +323,22 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { Ok(()) } + fn mark_owned_aliased_as_unfolded(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()> { + if !self + .predicates_encoding_state + .owned + .unfolded_owned_aliased_predicates + .contains(ty) + { + self.ensure_type_definition(ty)?; + self.predicates_encoding_state + .owned + .unfolded_owned_aliased_predicates + .insert(ty.clone()); + } + Ok(()) + } + fn mark_unique_ref_as_used(&mut self, ty: &vir_mid::Type) -> SpannedEncodingResult<()> { if !self .predicates_encoding_state @@ -150,27 +356,37 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { fn collect_owned_predicate_decls( &mut self, - ) -> SpannedEncodingResult> { - let unfolded_predicates = std::mem::take( + ) -> SpannedEncodingResult<(Vec, BTreeMap)> { + let unfolded_owned_non_aliased_predicates = std::mem::take( &mut self .predicates_encoding_state .owned .unfolded_owned_non_aliased_predicates, ); + let unfolded_owned_aliased_predicates = std::mem::take( + &mut self + .predicates_encoding_state + .owned + .unfolded_owned_aliased_predicates, + ); let used_unique_ref_predicates = std::mem::take( &mut self .predicates_encoding_state .owned .used_unique_ref_predicates, ); - let mut predicate_encoder = PredicateEncoder::new(self, &unfolded_predicates); - for ty in &unfolded_predicates { + let mut predicate_encoder = PredicateEncoder::new(self); + for ty in &unfolded_owned_non_aliased_predicates { predicate_encoder.encode_owned_non_aliased(ty)?; } + for ty in &unfolded_owned_aliased_predicates { + predicate_encoder.encode_owned_aliased(ty)?; + } for ty in &used_unique_ref_predicates { predicate_encoder.encode_unique_ref(ty)?; } - Ok(predicate_encoder.into_predicates()) + let predicate_info = predicate_encoder.take_predicate_info(); + Ok((predicate_encoder.into_predicates(), predicate_info)) } fn owned_non_aliased_full_vars( @@ -181,19 +397,32 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { place: &vir_low::VariableDecl, root_address: &vir_low::VariableDecl, snapshot: &vir_low::VariableDecl, + must_be_predicate: bool, ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments, { - self.owned_non_aliased( - context, - ty, - generics, - place.clone().into(), - root_address.clone().into(), - snapshot.clone().into(), - None, - ) + if must_be_predicate { + self.owned_non_aliased_predicate( + context, + ty, + generics, + place.clone().into(), + root_address.clone().into(), + snapshot.clone().into(), + None, + ) + } else { + self.owned_non_aliased( + context, + ty, + generics, + place.clone().into(), + root_address.clone().into(), + snapshot.clone().into(), + None, + ) + } } fn owned_non_aliased( @@ -209,19 +438,179 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { where G: WithLifetimes + WithConstArguments, { - let mut builder = OwnedNonAliasedUseBuilder::new( + let mut builder = + OwnedNonAliasedUseBuilder::new(self, context, ty, generics, place, root_address)?; + builder.add_snapshot_argument(snapshot)?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.set_maybe_permission_amount(permission_amount)?; + builder.build() + } + + fn owned_non_aliased_predicate( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + snapshot: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = + OwnedNonAliasedUseBuilder::new(self, context, ty, generics, place, root_address)?; + if config::use_snapshot_parameters_in_predicates() { + builder.add_snapshot_argument(snapshot)?; + } + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.set_maybe_permission_amount(permission_amount)?; + builder.build() + } + + fn owned_non_aliased_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = OwnedNonAliasedSnapCallBuilder::new( self, context, ty, generics, place, root_address, - snapshot, + position, )?; builder.add_lifetime_arguments()?; builder.add_const_arguments()?; + builder.build() + } + + fn owned_aliased_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = + OwnedAliasedSnapCallBuilder::new(self, context, ty, generics, address, position)?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + // fn wrap_snap_into_bool( + // &mut self, + // ty: &vir_mid::Type, + // expression: vir_low::Expression, + // ) -> SpannedEncodingResult { + // let ty_identifier = ty.get_identifier(); + // let name = format!("snap_call_wrapper${}", ty_identifier); + // let domain_name = "SnapCallWrappers"; + // let position = expression.position(); + // if !self + // .predicates_encoding_state + // .owned + // .snap_wrappers + // .contains(&ty_identifier) + // { + // use vir_low::macros::*; + // var_decls!( + // snapshot: { ty.to_snapshot(self)? } + // ); + // let call = self.create_domain_func_app( + // domain_name, + // name.clone(), + // vec![snapshot.clone().into()], + // vir_low::Type::Bool, + // position, + // )?; + // let body = vir_low::Expression::forall( + // vec![snapshot], + // vec![vir_low::Trigger::new(vec![call.clone()])], + // call, + // ); + // // let body = expr! { + // // forall( snapshot: { ty.to_snapshot(self)? } :: [ {[call.clone()]} ] [call] ) + // // }; + // let axiom = vir_low::DomainAxiomDecl { + // name: format!("snap_call_wrapper_always_true${}", ty_identifier), + // body, + // }; + // self.declare_axiom(domain_name, axiom)?; + // self.predicates_encoding_state + // .owned + // .snap_wrappers + // .insert(ty_identifier); + // } + // self.create_domain_func_app( + // domain_name, + // name, + // vec![expression], + // vir_low::Type::Bool, + // position, + // ) + // } + + fn owned_aliased( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + address: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + self.mark_owned_aliased_as_unfolded(ty)?; + let mut builder = OwnedAliasedUseBuilder::new(self, context, ty, generics, address)?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; builder.set_maybe_permission_amount(permission_amount)?; - Ok(builder.build()) + builder.build() + } + + fn owned_aliased_range( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + address: vir_low::Expression, + start_index: vir_low::Expression, + end_index: vir_low::Expression, + permission_amount: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let builder = OwnedAliasedRangeUseBuilder::new( + self, + context, + ty, + generics, + address, + start_index, + end_index, + permission_amount, + )?; + builder.build() } fn unique_ref_full_vars( @@ -234,6 +623,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { current_snapshot: &vir_low::VariableDecl, final_snapshot: &vir_low::VariableDecl, lifetime: &vir_low::VariableDecl, + target_slice_len: Option, ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments, @@ -247,6 +637,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { current_snapshot.clone().into(), final_snapshot.clone().into(), lifetime.clone().into(), + target_slice_len, ) } @@ -260,6 +651,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { current_snapshot: vir_low::Expression, final_snapshot: vir_low::Expression, lifetime: vir_low::Expression, + target_slice_len: Option, ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments, @@ -271,13 +663,80 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { generics, place, root_address, - current_snapshot, - final_snapshot, + // current_snapshot, + // final_snapshot, + lifetime, + target_slice_len, + )?; + builder.add_current_snapshot_argument(current_snapshot)?; + // builder.add_snapshot_arguments(current_snapshot, final_snapshot)?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + fn unique_ref_predicate( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + current_snapshot: vir_low::Expression, + final_snapshot: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = UniqueRefUseBuilder::new( + self, + context, + ty, + generics, + place, + root_address, + // final_snapshot, + lifetime, + target_slice_len, + )?; + // if config::use_snapshot_parameters_in_predicates() { + // builder.add_snapshot_arguments(current_snapshot, final_snapshot)?; + // } + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + fn unique_ref_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + is_final: bool, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = UniqueRefSnapCallBuilder::new( + self, + context, + ty, + generics, + place, + root_address, lifetime, + target_slice_len, + is_final, )?; builder.add_lifetime_arguments()?; builder.add_const_arguments()?; - Ok(builder.build()) + builder.build() } fn frac_ref_full_vars( @@ -304,6 +763,36 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { ) } + fn frac_ref_full_vars_opt( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: &vir_low::VariableDecl, + root_address: &vir_low::VariableDecl, + current_snapshot: &Option, + lifetime: &vir_low::VariableDecl, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = FracRefUseBuilder::new( + self, + context, + ty, + generics, + place.clone().into(), + root_address.clone().into(), + lifetime.clone().into(), + )?; + if let Some(current_snapshot) = current_snapshot { + builder.add_snapshot_argument(current_snapshot.clone().into())?; + } + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + fn frac_ref( &mut self, context: CallContext, @@ -314,6 +803,55 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { current_snapshot: vir_low::Expression, lifetime: vir_low::Expression, ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + self.frac_ref_opt( + context, + ty, + generics, + place, + root_address, + Some(current_snapshot), + lifetime, + ) + } + + fn frac_ref_opt( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + current_snapshot: Option, + lifetime: vir_low::Expression, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = + FracRefUseBuilder::new(self, context, ty, generics, place, root_address, lifetime)?; + if let Some(current_snapshot) = current_snapshot { + builder.add_snapshot_argument(current_snapshot)?; + } + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + fn frac_ref_predicate( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + // FIXME: Remove the current_snapshto argument. + current_snapshot: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult where G: WithLifetimes + WithConstArguments, { @@ -324,11 +862,78 @@ impl<'p, 'v: 'p, 'tcx: 'v> PredicatesOwnedInterface for Lowerer<'p, 'v, 'tcx> { generics, place, root_address, - current_snapshot, lifetime, + // target_slice_len, + )?; + builder.add_lifetime_arguments()?; + builder.add_const_arguments()?; + builder.build() + } + + fn frac_ref_snap( + &mut self, + context: CallContext, + ty: &vir_mid::Type, + generics: &G, + place: vir_low::Expression, + root_address: vir_low::Expression, + lifetime: vir_low::Expression, + target_slice_len: Option, + ) -> SpannedEncodingResult + where + G: WithLifetimes + WithConstArguments, + { + let mut builder = FracRefSnapCallBuilder::new( + self, + context, + ty, + generics, + place, + root_address, + lifetime, + target_slice_len, )?; builder.add_lifetime_arguments()?; builder.add_const_arguments()?; - Ok(builder.build()) + builder.build() } + + // fn place_snap( + // &mut self, + // context: CallContext, + // ty: &vir_mid::Type, + // generics: &G, + // place: &vir_mid::Expression, + // position: vir_low::Position, + // deref_to_final: bool, + // ) -> SpannedEncodingResult + // where + // G: WithLifetimes + WithConstArguments, + // { + // let place_low = self.encode_expression_as_place(place)?; + // let root_address = self.extract_root_address(place)?; + // if let Some(reference_place) = place.get_first_dereferenced_reference() { + // let vir_mid::Type::Reference(reference_type) = reference_place.get_type() else { + // unreachable!() + // }; + // let TODO_target_slice_len = None; + // match reference_type.uniqueness { + // vir_mid::ty::Uniqueness::Unique => { + // self.unique_ref_snap( + // context, ty, generics, place_low, root_address, + // reference_type.lifetime, TODO_target_slice_len, deref_to_final) + // }, + // vir_mid::ty::Uniqueness::Shared => todo!(), + // } + // } else { + // self.owned_non_aliased_snap( + // context, + // ty, + // generics, + // place_low, + // root_address, + // position, + // ) + // } + // } } diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/mod.rs index 75979d536e4..c8dc3789338 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/owned/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/owned/mod.rs @@ -6,6 +6,9 @@ mod interface; pub(super) use self::interface::PredicatesOwnedState; pub(in super::super) use self::{ - builders::{FracRefUseBuilder, OwnedNonAliasedUseBuilder, UniqueRefUseBuilder}, - interface::PredicatesOwnedInterface, + builders::{ + FracRefUseBuilder, OwnedAliasedSnapCallBuilder, OwnedNonAliasedSnapCallBuilder, + OwnedNonAliasedUseBuilder, UniqueRefUseBuilder, + }, + interface::{PredicateInfo, PredicatesOwnedInterface}, }; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/restoration/interface.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/restoration/interface.rs new file mode 100644 index 00000000000..c6a31a14fa7 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/restoration/interface.rs @@ -0,0 +1,70 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::lowerer::{Lowerer, PredicatesLowererInterface}, +}; +use rustc_hash::FxHashSet; +use vir_crate::{common::identifier::WithIdentifier, low as vir_low, middle as vir_mid}; + +#[derive(Default)] +pub(in super::super) struct RestorationState { + encoded_restore_raw_borrowed_transition_predicate: FxHashSet, +} + +pub(in super::super::super) trait RestorationInterface { + fn encode_restore_raw_borrowed_transition_predicate( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()>; + fn restore_raw_borrowed( + &mut self, + ty: &vir_mid::Type, + place: vir_low::Expression, + address: vir_low::Expression, + ) -> SpannedEncodingResult; +} + +impl<'p, 'v: 'p, 'tcx: 'v> RestorationInterface for Lowerer<'p, 'v, 'tcx> { + fn encode_restore_raw_borrowed_transition_predicate( + &mut self, + ty: &vir_mid::Type, + ) -> SpannedEncodingResult<()> { + let ty_identifier = ty.get_identifier(); + if !self + .predicates_encoding_state + .restoration + .encoded_restore_raw_borrowed_transition_predicate + .contains(&ty_identifier) + { + self.predicates_encoding_state + .restoration + .encoded_restore_raw_borrowed_transition_predicate + .insert(ty_identifier); + + use vir_low::macros::*; + let predicate = vir_low::PredicateDecl::new( + predicate_name! { RestoreRawBorrowed }, + vir_low::PredicateKind::WithoutSnapshotWhole, + vars!(place: Place, address: Address), + None, + ); + self.declare_predicate(predicate)?; + } + Ok(()) + } + fn restore_raw_borrowed( + &mut self, + ty: &vir_mid::Type, + place: vir_low::Expression, + address: vir_low::Expression, + ) -> SpannedEncodingResult { + self.encode_restore_raw_borrowed_transition_predicate(ty)?; + use vir_low::macros::*; + let predicate = expr! { + acc(RestoreRawBorrowed( + [place], + [address] + )) + }; + Ok(predicate) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/restoration/mod.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/restoration/mod.rs new file mode 100644 index 00000000000..58ddd243565 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/restoration/mod.rs @@ -0,0 +1,6 @@ +//! Encoder of predicates that guard restoration of permissions. + +mod interface; + +pub(in super::super) use self::interface::RestorationInterface; +pub(super) use self::interface::RestorationState; diff --git a/prusti-viper/src/encoder/middle/core_proof/predicates/state.rs b/prusti-viper/src/encoder/middle/core_proof/predicates/state.rs index 6b5f047885e..ea46746dee4 100644 --- a/prusti-viper/src/encoder/middle/core_proof/predicates/state.rs +++ b/prusti-viper/src/encoder/middle/core_proof/predicates/state.rs @@ -1,7 +1,11 @@ -use super::{memory_block::PredicatesMemoryBlockState, owned::PredicatesOwnedState}; +use super::{ + memory_block::PredicatesMemoryBlockState, owned::PredicatesOwnedState, + restoration::RestorationState, +}; #[derive(Default)] pub(in super::super) struct PredicatesState { pub(super) owned: PredicatesOwnedState, pub(super) memory_block: PredicatesMemoryBlockState, + pub(super) restoration: RestorationState, } diff --git a/prusti-viper/src/encoder/middle/core_proof/references/interface.rs b/prusti-viper/src/encoder/middle/core_proof/references/interface.rs index 20ec1f726b2..68c437e4bff 100644 --- a/prusti-viper/src/encoder/middle/core_proof/references/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/references/interface.rs @@ -7,6 +7,7 @@ use crate::encoder::{ snapshots::{ IntoSnapshot, SnapshotAdtsInterface, SnapshotDomainsInterface, SnapshotValuesInterface, }, + type_layouts::TypeLayoutsInterface, types::TypesInterface, }, }; @@ -72,6 +73,12 @@ pub(in super::super) trait ReferencesInterface { snapshot: vir_low::Expression, position: vir_low::Position, ) -> SpannedEncodingResult; + fn reference_slice_len( + &mut self, + reference_type: &vir_mid::Type, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult>; fn reference_address_snapshot( &mut self, reference_type: &vir_mid::Type, @@ -133,11 +140,33 @@ impl<'p, 'v: 'p, 'tcx: 'v> ReferencesInterface for Lowerer<'p, 'v, 'tcx> { position: vir_low::Position, ) -> SpannedEncodingResult { assert!(reference_type.is_reference()); - let domain_name = self.encode_snapshot_domain_name(reference_type)?; + // let domain_name = self.encode_snapshot_domain_name(reference_type)?; let return_type = self.address_type()?; - Ok(self - .snapshot_destructor_struct_call(&domain_name, "address", return_type, snapshot)? - .set_default_position(position)) + self.obtain_parameter_snapshot(reference_type, "address", return_type, snapshot, position) + // Ok(self + // .snapshot_destructor_struct_call(&domain_name, "address", return_type, snapshot)? + // .set_default_position(position)) + } + fn reference_slice_len( + &mut self, + reference_type: &vir_mid::Type, + snapshot: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult> { + assert!(reference_type.is_reference()); + let len = if reference_type.is_reference_to_slice() { + let return_type = self.size_type()?; + Some(self.obtain_parameter_snapshot( + reference_type, + "len", + return_type, + snapshot, + position, + )?) + } else { + None + }; + Ok(len) } fn reference_address_snapshot( &mut self, @@ -145,9 +174,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> ReferencesInterface for Lowerer<'p, 'v, 'tcx> { snapshot: vir_low::Expression, position: vir_low::Position, ) -> SpannedEncodingResult { - let address = self.reference_address(reference_type, snapshot, position)?; + let address = self.reference_address(reference_type, snapshot.clone(), position)?; + let mut arguments = vec![address]; let address_type = self.reference_address_type(reference_type)?; - self.construct_struct_snapshot(&address_type, vec![address], position) + if let Some(len) = self.reference_slice_len(reference_type, snapshot, position)? { + arguments.push(len); + }; + self.construct_struct_snapshot(&address_type, arguments, position) } fn reference_address_type( &mut self, diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/bytes/interface.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/bytes/interface.rs index cd6f7741c4c..dd8efe349ec 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/bytes/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/bytes/interface.rs @@ -6,8 +6,9 @@ use crate::encoder::{ snapshots::SnapshotDomainsInterface, }, }; +use prusti_common::config; use vir_crate::{ - common::identifier::WithIdentifier, + common::{expression::QuantifierHelpers, identifier::WithIdentifier}, low::{self as vir_low}, middle::{self as vir_mid}, }; @@ -38,13 +39,66 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotBytesInterface for Lowerer<'p, 'v, 'tcx> { let domain_name = self.encode_snapshot_domain_name(ty)?; let domain_type = self.encode_snapshot_domain_type(ty)?; let return_type = self.bytes_type()?; + let to_bytes = format!("to_bytes${}", ty.get_identifier()); + let snapshot = vir_low::VariableDecl::new("snapshot", domain_type.clone()); self.declare_domain_function( &domain_name, - std::borrow::Cow::Owned(format!("to_bytes${}", ty.get_identifier())), + std::borrow::Cow::Owned(to_bytes.clone()), false, - std::borrow::Cow::Owned(vec![vir_low::VariableDecl::new("snapshot", domain_type)]), - std::borrow::Cow::Owned(return_type), + std::borrow::Cow::Owned(vec![snapshot.clone()]), + std::borrow::Cow::Owned(return_type.clone()), )?; + if !config::use_snapshot_parameters_in_predicates() + && matches!( + ty, + vir_mid::Type::Bool + | vir_mid::Type::Int(_) + | vir_mid::Type::Float(_) + | vir_mid::Type::Pointer(_) + | vir_mid::Type::Sequence(_) + | vir_mid::Type::Map(_) + ) + { + // This is sound only for primitive types. + let from_bytes = format!("from_bytes${}", ty.get_identifier()); + self.declare_domain_function( + &domain_name, + std::borrow::Cow::Owned(from_bytes.clone()), + false, + std::borrow::Cow::Owned(vec![vir_low::VariableDecl::new( + "bytes", + return_type.clone(), + )]), + std::borrow::Cow::Owned(domain_type.clone()), + )?; + use vir_low::macros::*; + let to_bytes_call = vir_low::Expression::domain_function_call( + domain_name.clone(), + to_bytes, + vec![snapshot.clone().into()], + return_type, + ); + let from_bytes_call = vir_low::Expression::domain_function_call( + domain_name.clone(), + from_bytes, + vec![to_bytes_call.clone()], + domain_type, + ); + let body = vir_low::Expression::forall( + vec![snapshot.clone()], + vec![vir_low::Trigger::new(vec![to_bytes_call])], + expr! { + snapshot == [ from_bytes_call ] + }, + ); + let axiom = vir_low::DomainAxiomDecl { + // We use ty identifier to distinguish sequences from arrays. + name: format!("{}${}$to_bytes_injective", domain_name, ty.get_identifier()), + comment: None, + body, + }; + self.declare_axiom(&domain_name, axiom)?; + } } Ok(()) } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/constructor.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/constructor.rs new file mode 100644 index 00000000000..5141f39d8b4 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/constructor.rs @@ -0,0 +1,378 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lowerer::{DomainsLowererInterface, Lowerer}, + places::PlacesInterface, + pointers::PointersInterface, + predicates::PredicatesOwnedInterface, + snapshots::{ + IntoSnapshot, IntoSnapshotLowerer, SnapshotDomainsInterface, SnapshotValidityInterface, + SnapshotValuesInterface, + }, + }, +}; +use rustc_hash::FxHashMap; +use std::collections::BTreeMap; +use vir_crate::{ + common::position::Positioned, + low::{self as vir_low}, + middle::{self as vir_mid, operations::ty::Typed}, +}; + +use super::PredicateKind; + +pub(in super::super::super::super) struct AssertionToSnapshotConstructor<'a> { + predicate_kind: PredicateKind, + ty: &'a vir_mid::Type, + /// Arguments for the regular struct fields. + regular_field_arguments: Vec, + /// A map for replacing `self.field` with a matching argument. Used in + /// assign postcondition. + field_replacement_map: FxHashMap, + /// Mapping from deref fields to their positions in the arguments' list. + deref_fields: BTreeMap, + /// Which places are framed on the path being explored. + framed_places: Vec, + /// If Some, uses the heap variable instead of snap functions. + heap: Option, + /// Whether should wrap all snap calls into old. + is_in_old_state: bool, + /// A flag used to check whether a conditional has nested conditionals. + found_conditional: bool, + position: vir_low::Position, +} + +impl<'a> AssertionToSnapshotConstructor<'a> { + pub(in super::super::super::super) fn for_assign_aggregate_postcondition( + ty: &'a vir_mid::Type, + regular_field_arguments: Vec, + fields: Vec, + deref_fields: Vec<(vir_mid::Expression, String, vir_low::Type)>, + heap: Option, + position: vir_low::Position, + ) -> Self { + let field_replacement_map = fields + .into_iter() + .zip(regular_field_arguments.iter().cloned()) + .collect(); + let deref_fields = deref_fields + .into_iter() + .enumerate() + .map(|(i, (e, _, _))| (i + regular_field_arguments.len(), e)) + .collect(); + Self { + predicate_kind: PredicateKind::Owned, + ty, + regular_field_arguments, + field_replacement_map, + deref_fields, + framed_places: Vec::new(), + heap, + is_in_old_state: true, + found_conditional: false, + position, + } + } + + pub(in super::super::super::super) fn for_function_body( + predicate_kind: PredicateKind, + ty: &'a vir_mid::Type, + regular_field_arguments: Vec, + fields: Vec, + deref_fields: Vec<(vir_mid::Expression, String, vir_low::Type)>, + position: vir_low::Position, + ) -> Self { + let field_replacement_map = fields + .into_iter() + .zip(regular_field_arguments.iter().cloned()) + .collect(); + let deref_fields = deref_fields + .into_iter() + .enumerate() + .map(|(i, (e, _, _))| (i + regular_field_arguments.len(), e)) + .collect(); + Self { + predicate_kind, + ty, + regular_field_arguments, + field_replacement_map, + deref_fields, + framed_places: Vec::new(), + heap: None, + is_in_old_state: false, + found_conditional: false, + position, + } + } + + pub(in super::super::super::super) fn expression_to_snapshot_constructor<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + expression: &vir_mid::Expression, + ) -> SpannedEncodingResult { + let constructor_expression = self.expression_to_snapshot(lowerer, expression, true)?; + if self.found_conditional { + Ok(constructor_expression) + } else { + self.generate_snapshot_constructor(lowerer) + } + } + + // FIXME: Code duplication. + fn snap_call<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + place: vir_low::Expression, + root_address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + match &self.predicate_kind { + PredicateKind::Owned => lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + position, + ), + PredicateKind::FracRef { lifetime } => { + let TODO_target_slice_len = None; + lowerer.frac_ref_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + lifetime.clone(), + TODO_target_slice_len, + ) + } + PredicateKind::UniqueRef { lifetime, is_final } => { + let TODO_target_slice_len = None; + lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + lifetime.clone(), + TODO_target_slice_len, + *is_final, + ) + } + } + } + + fn generate_snapshot_constructor<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult { + let mut arguments = self.regular_field_arguments.clone(); + for deref_field in self.deref_fields.clone().values() { + let ty = deref_field.get_type(); + let deref_field_snapshot = if self.framed_places.contains(deref_field) { + // The place is framed, generate the snap call. + if self.heap.is_some() { + self.expression_to_snapshot(lowerer, deref_field, false)? + } else { + let place = lowerer.encode_expression_as_place(deref_field)?; + let root_address = self.pointer_deref_into_address(lowerer, deref_field)?; + let snap_call = + self.snap_call(lowerer, ty, place, root_address, self.position)?; + // let snap_call = lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // ty, + // ty, + // place, + // root_address, + // self.position, + // )?; + if self.is_in_old_state { + vir_low::Expression::labelled_old(None, snap_call, self.position) + } else { + snap_call + } + } + } else { + // The place is not framed. Create a dangling (null) snapshot. + let domain_name = lowerer.encode_snapshot_domain_name(ty)?; + let function_name = format!("{}$dangling", domain_name); + let return_type = ty.to_snapshot(lowerer)?; + lowerer.create_unique_domain_func_app( + domain_name, + function_name, + Vec::new(), + return_type, + self.position, + )? + }; + arguments.push(deref_field_snapshot); + } + lowerer.construct_struct_snapshot(self.ty, arguments, self.position) + } + + // FIXME: Code duplication. + fn pointer_deref_into_address<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + if let Some(deref_place) = place.get_last_dereferenced_pointer() { + let base_snapshot = self.expression_to_snapshot(lowerer, deref_place, true)?; + let ty = deref_place.get_type(); + lowerer.pointer_address(ty, base_snapshot, place.position()) + } else { + unreachable!() + } + } + + fn conditional_branch_to_snapshot<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + branch: &vir_mid::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + self.found_conditional = false; + let old_framed_places_count = self.framed_places.len(); + let branch_snapshot = self.expression_to_snapshot(lowerer, branch, expect_math_bool)?; + let expression = if !self.found_conditional { + // We reached the lowest level, generate the snapshot constructor. + self.generate_snapshot_constructor(lowerer)? + } else { + branch_snapshot + }; + self.framed_places.truncate(old_framed_places_count); + Ok(expression) + } +} + +impl<'a, 'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> + for AssertionToSnapshotConstructor<'a> +{ + fn conditional_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + conditional: &vir_mid::Conditional, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let guard_snapshot = self.expression_to_snapshot(lowerer, &conditional.guard, true)?; + + let then_expr_snapshot = + self.conditional_branch_to_snapshot(lowerer, &conditional.then_expr, expect_math_bool)?; + let else_expr_snapshot = + self.conditional_branch_to_snapshot(lowerer, &conditional.else_expr, expect_math_bool)?; + + self.found_conditional = true; + Ok(vir_low::Expression::conditional( + guard_snapshot, + then_expr_snapshot, + else_expr_snapshot, + conditional.position, + )) + } + + fn field_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + field: &vir_mid::Field, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + match &*field.base { + vir_mid::Expression::Local(local) + if local.variable.is_self_variable() + && self.field_replacement_map.contains_key(&field.field) => + { + Ok(self.field_replacement_map[&field.field].clone()) + } + _ => self.field_to_snapshot_impl(lowerer, field, expect_math_bool), + } + } + + fn variable_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + variable: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + todo!() + } + + fn labelled_old_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + old: &vir_mid::LabelledOld, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn func_app_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + app: &vir_mid::FuncApp, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn acc_predicate_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + match &*acc_predicate.predicate { + vir_mid::Predicate::LifetimeToken(_) + | vir_mid::Predicate::MemoryBlockStack(_) + | vir_mid::Predicate::MemoryBlockStackDrop(_) => { + unreachable!(); + } + vir_mid::Predicate::MemoryBlockHeap(_) + | vir_mid::Predicate::MemoryBlockHeapRange(_) + | vir_mid::Predicate::MemoryBlockHeapDrop(_) => { + // Do nothing. + } + vir_mid::Predicate::OwnedNonAliased(predicate) => { + self.framed_places.push(predicate.place.clone()); + } + vir_mid::Predicate::OwnedRange(_) => todo!(), + vir_mid::Predicate::OwnedSet(_) => todo!(), + } + Ok(true.into()) + } + + // FIXME: Code duplication. + fn pointer_deref_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + deref: &vir_mid::Deref, + base_snapshot: vir_low::Expression, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + let heap = self + .heap + .clone() + .expect("This function should be reachable only when heap is Some"); + lowerer.pointer_target_snapshot_in_heap( + deref.base.get_type(), + heap, + base_snapshot, + deref.position, + ) + } + + fn call_context(&self) -> CallContext { + todo!() + } + + fn owned_non_aliased_snap( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + todo!() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/mod.rs new file mode 100644 index 00000000000..1fe71adb3d4 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/mod.rs @@ -0,0 +1,30 @@ +use vir_crate::low::{self as vir_low}; + +/// Assertions that are self-framing: each dereference of a pointer needs to be +/// behind `own`. +mod self_framing; +/// Assertions where the places (leaves) are translated to `snap` calls. +mod snap; +/// Assertions where the places are translated by using `heap$` pure variable. +mod pure_heap; +/// The snapshot validity assertion. +mod validity; +/// Structural invariant that needs to be translated into a snapshot +/// constructor. +mod constructor; + +pub(in super::super::super::super) enum PredicateKind { + Owned, + FracRef { + lifetime: vir_low::Expression, + }, + UniqueRef { + lifetime: vir_low::Expression, + is_final: bool, + }, +} + +pub(in super::super::super) use self::{ + constructor::AssertionToSnapshotConstructor, self_framing::SelfFramingAssertionToSnapshot, + validity::ValidityAssertionToSnapshot, +}; diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/pure_heap.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/pure_heap.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/pure_heap.rs @@ -0,0 +1 @@ + diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/self_framing.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/self_framing.rs new file mode 100644 index 00000000000..9b8bd7c6aa1 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/self_framing.rs @@ -0,0 +1,463 @@ +use super::PredicateKind; +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lowerer::Lowerer, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface}, + snapshots::IntoSnapshotLowerer, + }, +}; +use rustc_hash::FxHashMap; +use vir_crate::{ + common::position::Positioned, + low::{self as vir_low}, + middle::{self as vir_mid, operations::ty::Typed}, +}; + +// Based on +// prusti-viper/src/encoder/middle/core_proof/predicates/owned/builders/owned_non_aliased/predicate_decl.rs, +// whch should be deleted. +pub(in super::super::super::super::super) struct SelfFramingAssertionToSnapshot { + predicate_kind: PredicateKind, + created_predicate_types: Vec, + /// Mapping from place to snapshot. We use a vector because we need to know + /// the insertion order. + snap_calls: Vec<(vir_mid::Expression, vir_low::Expression)>, + // Fields for encoding predicate body. In a language with inheritance, we + // would have `place` and `root_address` in a subclass. However, in Rust we + // need to play with `if` statements. + place: Option, + root_address: Option, + /// A map for replacing `self.field` with a matching argument. Used in + /// assign postcondition. + field_replacement_map: FxHashMap, + heap: Option, +} + +impl SelfFramingAssertionToSnapshot { + pub(in super::super::super::super::super) fn for_predicate_body( + place: vir_low::VariableDecl, + root_address: vir_low::VariableDecl, + predicate_kind: PredicateKind, + ) -> Self { + Self { + predicate_kind, + created_predicate_types: Vec::new(), + snap_calls: Vec::new(), + place: Some(place), + root_address: Some(root_address), + field_replacement_map: FxHashMap::default(), + heap: None, + } + } + + pub(in super::super::super::super::super) fn for_assign_precondition( + regular_field_arguments: Vec, + fields: Vec, + heap: Option, + ) -> Self { + let field_replacement_map = fields + .into_iter() + .zip(regular_field_arguments.iter().cloned()) + .collect(); + Self { + predicate_kind: PredicateKind::Owned, + created_predicate_types: Vec::new(), + snap_calls: Vec::new(), + place: None, + root_address: None, + field_replacement_map, + heap, + } + } + + pub(in super::super::super::super::super) fn into_created_predicate_types( + self, + ) -> Vec { + self.created_predicate_types + } + + fn is_predicate_body(&self) -> bool { + self.place.is_some() + } + + fn predicate_place(&self) -> vir_low::Expression { + self.place.clone().unwrap().into() + } + + fn predicate_root_address(&self) -> vir_low::Expression { + self.root_address.clone().unwrap().into() + } + + // FIXME: Code duplication. + fn pointer_deref_into_address<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + if let Some(deref_place) = place.get_last_dereferenced_pointer() { + let base_snapshot = self.expression_to_snapshot(lowerer, deref_place, true)?; + let ty = deref_place.get_type(); + lowerer.pointer_address(ty, base_snapshot, place.position()) + } else { + unreachable!() + } + } + + // FIXME: Code duplication. + fn snap_call<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + place: vir_low::Expression, + root_address: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + match &self.predicate_kind { + PredicateKind::Owned => lowerer.owned_non_aliased_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + position, + ), + PredicateKind::FracRef { lifetime } => { + let TODO_target_slice_len = None; + lowerer.frac_ref_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + lifetime.clone(), + TODO_target_slice_len, + ) + } + PredicateKind::UniqueRef { lifetime, is_final } => { + assert!(!is_final); + let TODO_target_slice_len = None; + lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + lifetime.clone(), + TODO_target_slice_len, + false, + ) + } + } + } + + fn predicate<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + place: vir_low::Expression, + root_address: vir_low::Expression, + ) -> SpannedEncodingResult { + self.created_predicate_types.push(ty.clone()); + let snapshot = true.into(); // Will not be used. + match &self.predicate_kind { + PredicateKind::Owned => lowerer.owned_non_aliased_predicate( + CallContext::BuiltinMethod, + ty, + ty, + place.clone(), + root_address.clone(), + snapshot, + None, + ), + PredicateKind::FracRef { lifetime } => { + let TODO_target_slice_len = None; + lowerer.frac_ref_predicate( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + snapshot, + lifetime.clone(), + TODO_target_slice_len, + ) + } + PredicateKind::UniqueRef { lifetime, is_final } => { + assert!(!is_final); + let TODO_target_slice_len = None; + let final_snapshot = lowerer.unique_ref_snap( + CallContext::BuiltinMethod, + ty, + ty, + place.clone(), + root_address.clone(), + lifetime.clone(), + TODO_target_slice_len.clone(), + true, + )?; + lowerer.unique_ref_predicate( + CallContext::BuiltinMethod, + ty, + ty, + place, + root_address, + snapshot.clone(), + final_snapshot, + lifetime.clone(), + TODO_target_slice_len, + ) + } + } + } +} + +impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for SelfFramingAssertionToSnapshot { + fn expression_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + expression: &vir_mid::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + for (place, call) in &self.snap_calls { + if place == expression { + return Ok(call.clone()); + } + } + self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) + } + + fn binary_op_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + op: &vir_mid::BinaryOp, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let mut introduced_snap = false; + if op.op_kind == vir_mid::BinaryOpKind::And { + if let box vir_mid::Expression::AccPredicate(expression) = &op.left { + if expression.predicate.is_owned_non_aliased() { + // The recursive call to `acc_predicate_to_snapshot` will + // add a snap call to `self.snap_calls`. + introduced_snap = true; + } + } + } + let expression = self.binary_op_to_snapshot_impl(lowerer, op, expect_math_bool)?; + if introduced_snap { + self.snap_calls.pop(); + } + Ok(expression) + } + + fn field_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + field: &vir_mid::Field, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + match &*field.base { + vir_mid::Expression::Local(local) if self.is_predicate_body() => { + assert!(local.variable.is_self_variable()); + let field_place = lowerer.encode_field_place( + &local.variable.ty, + &field.field, + self.predicate_place(), + field.position, + )?; + self.snap_call( + lowerer, + &field.field.ty, + field_place, + self.predicate_root_address(), + local.position, + ) + } + vir_mid::Expression::Local(local) + if local.variable.is_self_variable() + && self.field_replacement_map.contains_key(&field.field) => + { + Ok(self.field_replacement_map[&field.field].clone()) + } + _ => self.field_to_snapshot_impl(lowerer, field, expect_math_bool), + } + } + + fn variable_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + variable: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + assert!(variable.is_self_variable(), "{} must be self", variable); + Ok(vir_low::VariableDecl { + name: variable.name.clone(), + ty: self.type_to_snapshot(lowerer, &variable.ty)?, + }) + } + + fn labelled_old_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _old: &vir_mid::LabelledOld, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn func_app_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _app: &vir_mid::FuncApp, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + assert!(expect_math_bool); + let expression = match &*acc_predicate.predicate { + vir_mid::Predicate::OwnedNonAliased(predicate) => { + let ty = predicate.place.get_type(); + let place = lowerer.encode_expression_as_place(&predicate.place)?; + let root_address = self.pointer_deref_into_address(lowerer, &predicate.place)?; + + let (acc, snap_call) = if self.heap.is_some() { + let snapshot = self.expression_to_snapshot(lowerer, &predicate.place, true)?; + let acc = lowerer.owned_non_aliased( + CallContext::BuiltinMethod, + ty, + ty, + place.clone(), + root_address.clone(), + snapshot.clone(), + None, + )?; + (acc, snapshot) + } else { + // let snapshot = true.into(); // Will not be used. + let acc = self.predicate(lowerer, ty, place.clone(), root_address.clone())?; + // let acc = lowerer.owned_non_aliased_predicate( + // CallContext::BuiltinMethod, + // ty, + // ty, + // place.clone(), + // root_address.clone(), + // snapshot, + // None, + // )?; + let snap_call = self.snap_call( + lowerer, + &ty, + place, + root_address, + predicate.place.position(), + )?; + // let snap_call = lowerer.owned_non_aliased_snap( + // CallContext::BuiltinMethod, + // ty, + // ty, + // place, + // root_address, + // predicate.place.position(), + // )?; + (acc, snap_call) + }; + self.snap_calls.push((predicate.place.clone(), snap_call)); + acc + } + vir_mid::Predicate::MemoryBlockHeap(predicate) => { + match self.predicate_kind { + PredicateKind::Owned => { + let place = lowerer.encode_expression_as_place(&predicate.address)?; + let root_address = + self.pointer_deref_into_address(lowerer, &predicate.address)?; + use vir_low::macros::*; + let compute_address = ty!(Address); + let address = expr! { + ComputeAddress::compute_address([place], [root_address]) + }; + let size = self.expression_to_snapshot( + lowerer, + &predicate.size, + expect_math_bool, + )?; + lowerer.encode_memory_block_stack_acc( + address, + size, + acc_predicate.position, + )? + } + PredicateKind::FracRef { .. } | PredicateKind::UniqueRef { .. } => { + // Memory blocks are not accessible in frac/unique ref predicates. + true.into() + } + } + } + vir_mid::Predicate::MemoryBlockHeapDrop(predicate) => { + match self.predicate_kind { + PredicateKind::Owned => { + let place = self.pointer_deref_into_address(lowerer, &predicate.address)?; + let size = self.expression_to_snapshot( + lowerer, + &predicate.size, + expect_math_bool, + )?; + lowerer.encode_memory_block_heap_drop_acc( + place, + size, + acc_predicate.position, + )? + } + PredicateKind::FracRef { .. } | PredicateKind::UniqueRef { .. } => { + // Memory blocks are not accessible in frac/unique ref predicates. + true.into() + } + } + } + _ => unimplemented!("{acc_predicate}"), + }; + Ok(expression) + } + + // FIXME: Code duplication. + fn pointer_deref_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + deref: &vir_mid::Deref, + base_snapshot: vir_low::Expression, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + let heap = self + .heap + .clone() + .expect("This function should be reachable only when heap is Some"); + lowerer.pointer_target_snapshot_in_heap( + deref.base.get_type(), + heap, + base_snapshot, + deref.position, + ) + } + + fn call_context(&self) -> CallContext { + todo!() + } + + fn owned_non_aliased_snap( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _ty: &vir_mid::Type, + _pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + todo!() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/snap.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/snap.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/snap.rs @@ -0,0 +1 @@ + diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/validity.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/validity.rs new file mode 100644 index 00000000000..3ec81230eac --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/assertions/validity.rs @@ -0,0 +1,139 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{ + builtin_methods::CallContext, + lowerer::Lowerer, + snapshots::{IntoSnapshotLowerer, SnapshotValidityInterface}, + }, +}; +use std::collections::BTreeMap; +use vir_crate::{ + low::{self as vir_low}, + middle::{self as vir_mid, operations::ty::Typed}, +}; + +pub(in super::super::super::super::super) struct ValidityAssertionToSnapshot { + framed_places: Vec, + deref_fields: BTreeMap, +} + +impl ValidityAssertionToSnapshot { + pub(in super::super::super::super::super) fn new( + deref_fields: Vec<(vir_mid::Expression, String, vir_low::Type)>, + ) -> Self { + Self { + framed_places: Vec::new(), + deref_fields: deref_fields + .into_iter() + .map(|(owned_place, name, ty)| (owned_place, vir_low::VariableDecl::new(name, ty))) + .collect(), + } + } +} + +impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for ValidityAssertionToSnapshot { + fn expression_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + expression: &vir_mid::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + if let Some(field) = self.deref_fields.get(expression) { + assert!( + self.framed_places.contains(expression), + "The place {} must be framed", + expression + ); + Ok(field.clone().into()) + } else { + self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) + } + } + + fn variable_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + variable: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + assert!(variable.is_self_variable(), "{} must be self", variable); + Ok(vir_low::VariableDecl { + name: variable.name.clone(), + ty: self.type_to_snapshot(lowerer, &variable.ty)?, + }) + } + + fn labelled_old_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + old: &vir_mid::LabelledOld, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn func_app_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + app: &vir_mid::FuncApp, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn binary_op_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + op: &vir_mid::BinaryOp, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let mut introduced_snap = false; + if op.op_kind == vir_mid::BinaryOpKind::And { + if let box vir_mid::Expression::AccPredicate(expression) = &op.left { + if let vir_mid::Predicate::OwnedNonAliased(predicate) = &*expression.predicate { + self.framed_places.push(predicate.place.clone()); + introduced_snap = true; + } + } + } + let expression = self.binary_op_to_snapshot_impl(lowerer, op, expect_math_bool)?; + if introduced_snap { + self.framed_places.pop(); + } + Ok(expression) + } + + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + assert!(expect_math_bool); + let expression = match &*acc_predicate.predicate { + vir_mid::Predicate::OwnedNonAliased(predicate) => { + self.framed_places.push(predicate.place.clone()); + let place = self.expression_to_snapshot(lowerer, &predicate.place, false)?; + self.framed_places.pop(); + lowerer.encode_snapshot_valid_call_for_type(place, predicate.place.get_type())? + } + vir_mid::Predicate::MemoryBlockHeap(_) | vir_mid::Predicate::MemoryBlockHeapDrop(_) => { + true.into() + } + _ => unimplemented!("{acc_predicate}"), + }; + Ok(expression) + } + + fn call_context(&self) -> CallContext { + todo!() + } + + fn owned_non_aliased_snap( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + todo!() + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/builtin_methods/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/builtin_methods/mod.rs index 21ff7d3ab21..bafe2d3c7e2 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/builtin_methods/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/builtin_methods/mod.rs @@ -1,7 +1,10 @@ use super::common::IntoSnapshotLowerer; use crate::encoder::{ errors::SpannedEncodingResult, - middle::core_proof::lowerer::{FunctionsLowererInterface, Lowerer}, + middle::core_proof::{ + builtin_methods::CallContext, + lowerer::{FunctionsLowererInterface, Lowerer}, + }, }; use vir_crate::{ common::identifier::WithIdentifier, @@ -56,4 +59,35 @@ impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for BuiltinMethodSn // In pure contexts values cannot be mutated, so `old` has no effect. self.expression_to_snapshot(lowerer, &old.base, expect_math_bool) } + + fn acc_predicate_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _predicate: &vir_mid::AccPredicate, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + unreachable!() + } + + fn owned_non_aliased_snap( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _ty: &vir_mid::Type, + _pointer_snapshot: &vir_mid::Expression, + ) -> SpannedEncodingResult { + unimplemented!() + } + + fn call_context(&self) -> CallContext { + CallContext::BuiltinMethod + } + + // fn unfolding_to_snapshot( + // &mut self, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // unfolding: &vir_mid::Unfolding, + // expect_math_bool: bool, + // ) -> SpannedEncodingResult { + // todo!() + // } } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/common/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/common/mod.rs index f88ff2b7ac6..015eb2164de 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/common/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/common/mod.rs @@ -3,20 +3,32 @@ use crate::encoder::{ errors::SpannedEncodingResult, high::types::HighTypeEncoderInterface, middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, lifetimes::*, lowerer::DomainsLowererInterface, + places::PlacesInterface, + pointers::PointersInterface, + predicates::PredicatesOwnedInterface, references::ReferencesInterface, - snapshots::{IntoSnapshot, SnapshotDomainsInterface, SnapshotValuesInterface}, + snapshots::{ + IntoSnapshot, SnapshotDomainsInterface, SnapshotValidityInterface, + SnapshotValuesInterface, + }, types::TypesInterface, }, }; use vir_crate::{ - common::{identifier::WithIdentifier, position::Positioned}, + common::{ + expression::BinaryOperationHelpers, identifier::WithIdentifier, position::Positioned, + }, low::{self as vir_low}, middle::{self as vir_mid, operations::ty::Typed}, }; -pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { +pub(in super::super::super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v>: + Sized +{ fn expression_vec_to_snapshot( &mut self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, @@ -38,6 +50,15 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { lowerer: &mut Lowerer<'p, 'v, 'tcx>, expression: &vir_mid::Expression, expect_math_bool: bool, + ) -> SpannedEncodingResult { + self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) + } + + fn expression_to_snapshot_impl( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + expression: &vir_mid::Expression, + expect_math_bool: bool, ) -> SpannedEncodingResult { match expression { vir_mid::Expression::Local(expression) => { @@ -84,6 +105,12 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { self.builtin_func_app_to_snapshot(lowerer, expression, expect_math_bool) } // vir_mid::Expression::Downcast(expression) => self.downcast_to_snapshot(lowerer, expression, expect_math_bool), + vir_mid::Expression::AccPredicate(expression) => { + self.acc_predicate_to_snapshot(lowerer, expression, expect_math_bool) + } + vir_mid::Expression::Unfolding(expression) => { + self.unfolding_to_snapshot(lowerer, expression, expect_math_bool) + } x => unimplemented!("{:?}", x), } } @@ -106,7 +133,7 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { fn variable_to_snapshot( &mut self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, - local: &vir_mid::VariableDecl, + variable: &vir_mid::VariableDecl, ) -> SpannedEncodingResult; fn local_to_snapshot( @@ -131,7 +158,21 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { for argument in &constructor.arguments { arguments.push(self.expression_to_snapshot(lowerer, argument, false)?); } - lowerer.construct_struct_snapshot(&constructor.ty, arguments, constructor.position) + let struct_snapshot = + lowerer.construct_struct_snapshot(&constructor.ty, arguments, constructor.position)?; + if let vir_mid::Type::Enum(vir_mid::ty::Enum { + variant: Some(_), .. + }) = &constructor.ty + { + let enum_snapshot = lowerer.construct_enum_snapshot( + &constructor.ty, + struct_snapshot, + constructor.position, + )?; + Ok(enum_snapshot) + } else { + Ok(struct_snapshot) + } } fn variant_to_snapshot( @@ -155,6 +196,15 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { lowerer: &mut Lowerer<'p, 'v, 'tcx>, field: &vir_mid::Field, expect_math_bool: bool, + ) -> SpannedEncodingResult { + self.field_to_snapshot_impl(lowerer, field, expect_math_bool) + } + + fn field_to_snapshot_impl( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + field: &vir_mid::Field, + expect_math_bool: bool, ) -> SpannedEncodingResult { let base_snapshot = self.expression_to_snapshot(lowerer, &field.base, expect_math_bool)?; let result = if field.field.is_discriminant() { @@ -187,14 +237,25 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { expect_math_bool: bool, ) -> SpannedEncodingResult { let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; - let result = lowerer.reference_target_current_snapshot( - deref.base.get_type(), - base_snapshot, - Default::default(), - )?; + let ty = deref.base.get_type(); + let result = if ty.is_reference() { + lowerer.reference_target_current_snapshot(ty, base_snapshot, deref.position)? + } else { + self.pointer_deref_to_snapshot(lowerer, deref, base_snapshot, expect_math_bool)? + }; self.ensure_bool_expression(lowerer, deref.get_type(), result, expect_math_bool) } + fn pointer_deref_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _deref: &vir_mid::Deref, + _base_snapshot: vir_low::Expression, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + unreachable!("Should be overriden."); + } + fn addr_of_to_snapshot( &mut self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, @@ -316,6 +377,15 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { lowerer: &mut Lowerer<'p, 'v, 'tcx>, op: &vir_mid::BinaryOp, expect_math_bool: bool, + ) -> SpannedEncodingResult { + self.binary_op_to_snapshot_impl(lowerer, op, expect_math_bool) + } + + fn binary_op_to_snapshot_impl( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + op: &vir_mid::BinaryOp, + expect_math_bool: bool, ) -> SpannedEncodingResult { // FIXME: Binary Operations with MPerm should not be handled manually as special cases // They are difficult because binary operations with MPerm and Integer values are allowed. @@ -376,15 +446,26 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { self.expression_to_snapshot(lowerer, &op.right, expect_math_bool_args)?; let arg_type = op.left.get_type().clone().erase_lifetimes(); assert_eq!(arg_type, op.right.get_type().clone().erase_lifetimes()); - let result = lowerer.construct_binary_op_snapshot( - op.op_kind, - ty, - &arg_type, - left_snapshot, - right_snapshot, - op.position, - )?; - self.ensure_bool_expression(lowerer, ty, result, expect_math_bool) + if expect_math_bool && op.op_kind == vir_mid::BinaryOpKind::EqCmp { + // FIXME: Instead of this ad-hoc optimization, have a proper + // optimization pass. + Ok(vir_low::Expression::binary_op( + vir_low::BinaryOpKind::EqCmp, + left_snapshot, + right_snapshot, + op.position, + )) + } else { + let result = lowerer.construct_binary_op_snapshot( + op.op_kind, + ty, + &arg_type, + left_snapshot, + right_snapshot, + op.position, + )?; + self.ensure_bool_expression(lowerer, ty, result, expect_math_bool) + } } fn binary_op_kind_to_snapshot( @@ -422,8 +503,11 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { self.expression_to_snapshot(lowerer, &conditional.then_expr, expect_math_bool)?; let else_expr_snapshot = self.expression_to_snapshot(lowerer, &conditional.else_expr, expect_math_bool)?; - let arg_type = conditional.then_expr.get_type(); - assert_eq!(arg_type, conditional.else_expr.get_type()); + let arg_type = vir_low::operations::ty::Typed::get_type(&then_expr_snapshot); + assert_eq!( + arg_type, + vir_low::operations::ty::Typed::get_type(&else_expr_snapshot) + ); // We do not need to ensure expect_math_bool because we pushed this // responsibility to the arguments. Ok(vir_low::Expression::conditional( @@ -444,7 +528,7 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { fn builtin_func_app_to_snapshot( &mut self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, - app: &vir_crate::middle::expression::BuiltinFuncApp, + app: &vir_mid::BuiltinFuncApp, expect_math_bool: bool, ) -> SpannedEncodingResult { use vir_low::expression::ContainerOpKind; @@ -514,6 +598,17 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { app.position, ) } + BuiltinFunc::Align => { + assert_eq!(args.len(), 0); + let return_type = self.type_to_snapshot(lowerer, &app.return_type)?; + lowerer.create_domain_func_app( + "Align", + app.get_identifier(), + args, + return_type, + app.position, + ) + } BuiltinFunc::Discriminant => { assert_eq!(args.len(), 1); let discriminant_call = lowerer.obtain_enum_discriminant( @@ -667,9 +762,167 @@ pub(super) trait IntoSnapshotLowerer<'p, 'v: 'p, 'tcx: 'v> { lowerer.construct_constant_snapshot(&vir_mid::Type::Bool, value, app.position) } } + BuiltinFunc::IsNull => { + assert_eq!(args.len(), 1); + let ty = app.arguments[0].get_type(); + let address = lowerer.pointer_address(ty, args[0].clone(), app.position)?; + let null_address = lowerer.address_null(app.position)?; + let equals = vir_low::Expression::equals(address, null_address); + let equals = + lowerer.construct_constant_snapshot(app.get_type(), equals, app.position)?; + self.ensure_bool_expression(lowerer, app.get_type(), equals, expect_math_bool) + } + BuiltinFunc::IsValid => { + assert_eq!(app.arguments.len(), 1); + let argument = args.pop().unwrap(); + let ty = app.arguments[0].get_type(); + lowerer.encode_snapshot_valid_call_for_type(argument, ty) + } + BuiltinFunc::EnsureOwnedPredicate => { + assert_eq!(app.arguments.len(), 1); + fn peel_unfolding<'p, 'v: 'p, 'tcx: 'v>( + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + into_snap_lowerer: &mut impl IntoSnapshotLowerer<'p, 'v, 'tcx>, + place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + match place { + vir_mid::Expression::Unfolding(unfolding) => { + let body = peel_unfolding(lowerer, into_snap_lowerer, &unfolding.body)?; + into_snap_lowerer.unfolding_to_snapshot_with_body( + lowerer, + &unfolding.predicate, + body, + unfolding.position, + true, + ) + } + _ => { + let ty = place.get_type(); + let snap_call = + into_snap_lowerer.owned_non_aliased_snap(lowerer, ty, place)?; + let snapshot = + into_snap_lowerer.expression_to_snapshot(lowerer, place, true)?; + let position = place.position(); + Ok(vir_low::Expression::binary_op( + vir_low::BinaryOpKind::EqCmp, + snap_call, + snapshot, + position, + )) + } + } + } + peel_unfolding(lowerer, self, &app.arguments[0]) + // let argument = &app.arguments[0]; + // let ty = argument.get_type(); + // let snap_call = self.owned_non_aliased_snap(lowerer, ty, argument)?; + // lowerer.wrap_snap_into_bool(ty, snap_call.set_default_position(app.position)) + } + BuiltinFunc::TakeLifetime => { + unimplemented!("TODO: Delete"); + } } } + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult; + + // fn unfolding_to_snapshot( + // &mut self, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // unfolding: &vir_mid::Unfolding, + // expect_math_bool: bool, + // ) -> SpannedEncodingResult; + + fn call_context(&self) -> CallContext; + + fn unfolding_to_snapshot_with_body( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + predicate: &vir_mid::Predicate, + body: vir_low::Expression, + position: vir_low::Position, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let predicate = match predicate { + vir_mid::Predicate::OwnedNonAliased(predicate) => { + let ty = predicate.place.get_type(); + lowerer.mark_owned_non_aliased_as_unfolded(ty)?; + let place = lowerer.encode_expression_as_place(&predicate.place)?; + let root_address = lowerer.extract_root_address(&predicate.place)?; + let snapshot = + self.expression_to_snapshot(lowerer, &predicate.place, expect_math_bool)?; + // predicate.place.to_procedure_snapshot(lowerer)?; // FIXME: This is probably wrong. It should take into account the current old. + lowerer + .owned_non_aliased_predicate( + self.call_context(), + ty, + ty, + place, + root_address, + snapshot, + None, + )? + .unwrap_predicate_access_predicate() + } + _ => unimplemented!("{predicate}"), + }; + let expression = vir_low::Expression::unfolding(predicate, body, position); + Ok(expression) + } + + fn unfolding_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + unfolding: &vir_mid::Unfolding, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let body = self.expression_to_snapshot(lowerer, &unfolding.body, expect_math_bool)?; + self.unfolding_to_snapshot_with_body( + lowerer, + &unfolding.predicate, + body, + unfolding.position, + expect_math_bool, + ) + // let predicate = match &*unfolding.predicate { + // vir_mid::Predicate::OwnedNonAliased(predicate) => { + // let ty = predicate.place.get_type(); + // lowerer.mark_owned_non_aliased_as_unfolded(ty)?; + // let place = lowerer.encode_expression_as_place(&predicate.place)?; + // let root_address = lowerer.extract_root_address(&predicate.place)?; + // let snapshot = self.expression_to_snapshot(lowerer, &predicate.place, expect_math_bool)?; + // // predicate.place.to_procedure_snapshot(lowerer)?; // FIXME: This is probably wrong. It should take into account the current old. + // lowerer + // .owned_non_aliased_predicate( + // self.call_context(), + // ty, + // ty, + // place, + // root_address, + // snapshot, + // None, + // )? + // .unwrap_predicate_access_predicate() + // } + // _ => unimplemented!("{unfolding}"), + // }; + // let body = self.expression_to_snapshot(lowerer, &unfolding.body, expect_math_bool)?; + // let expression = vir_low::Expression::unfolding(predicate, body, unfolding.position); + // Ok(expression) + } + + fn owned_non_aliased_snap( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult; + fn type_to_snapshot( &mut self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/context_independent/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/context_independent/mod.rs index 4b17e40976f..e1e6b55dbb7 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/context_independent/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/context_independent/mod.rs @@ -2,7 +2,10 @@ //! the context. Currently, the only example is types. use super::common::IntoSnapshotLowerer; -use crate::encoder::{errors::SpannedEncodingResult, middle::core_proof::lowerer::Lowerer}; +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::{builtin_methods::CallContext, lowerer::Lowerer}, +}; use vir_crate::{ low::{self as vir_low}, middle::{self as vir_mid}, @@ -41,4 +44,35 @@ impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for ContextIndepend ) -> SpannedEncodingResult { unreachable!("requested context dependent encoding"); } + + fn acc_predicate_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _predicate: &vir_mid::AccPredicate, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + unreachable!("requested context dependent encoding"); + } + + fn owned_non_aliased_snap( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _ty: &vir_mid::Type, + _pointer_snapshot: &vir_mid::Expression, + ) -> SpannedEncodingResult { + unimplemented!() + } + + // fn unfolding_to_snapshot( + // &mut self, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // unfolding: &vir_mid::Unfolding, + // expect_math_bool: bool, + // ) -> SpannedEncodingResult { + // todo!() + // } + + fn call_context(&self) -> CallContext { + todo!() + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/mod.rs new file mode 100644 index 00000000000..8c7a0b6594e --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/mod.rs @@ -0,0 +1,10 @@ +/// Expressions that are framed and used in pure contexts. For example, pure +/// function bodies. +mod pure_framed; +/// Expressions to be used in procedure bodies. For example, arguments of +/// builtin methods. +mod procedure_bodies; + +pub(in super::super::super) use self::{ + procedure_bodies::ProcedureExpressionToSnapshot, pure_framed::FramedExpressionToSnapshot, +}; diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/procedure_bodies.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/procedure_bodies.rs new file mode 100644 index 00000000000..5fd15bd0980 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/procedure_bodies.rs @@ -0,0 +1,189 @@ +use super::super::PredicateKind; +use crate::encoder::{ + errors::SpannedEncodingResult, + high::types::HighTypeEncoderInterface, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + footprint::FootprintInterface, + lowerer::{FunctionsLowererInterface, Lowerer}, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface}, + snapshots::{IntoSnapshotLowerer, SnapshotValuesInterface, SnapshotVariablesInterface}, + }, +}; +use rustc_hash::FxHashMap; +use vir_crate::{ + common::{identifier::WithIdentifier, position::Positioned}, + low::{self as vir_low}, + middle::{self as vir_mid, operations::ty::Typed}, +}; + +pub(in super::super::super::super::super) struct ProcedureExpressionToSnapshot { + old_label: Option, + predicate_kind: PredicateKind, + /// If `true`, uses SSA snapshots instead of `snap` calls for stack + /// variables. + use_pure_stack: bool, +} + +impl ProcedureExpressionToSnapshot { + pub(in super::super::super::super) fn for_address(predicate_kind: PredicateKind) -> Self { + Self { + old_label: None, + predicate_kind, + use_pure_stack: true, + } + } + + pub(in super::super::super::super) fn for_place(predicate_kind: PredicateKind) -> Self { + Self { + old_label: None, + predicate_kind, + use_pure_stack: false, + } + } + + fn snap_call<'p, 'v, 'tcx>( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + let place = lowerer.encode_expression_as_place(pointer_place)?; + let root_address = lowerer.extract_root_address(pointer_place)?; + match &self.predicate_kind { + PredicateKind::Owned => lowerer.owned_non_aliased_snap( + CallContext::Procedure, + ty, + ty, + place, + root_address, + pointer_place.position(), + ), + PredicateKind::FracRef { lifetime } => { + let TODO_target_slice_len = None; + lowerer.frac_ref_snap( + CallContext::Procedure, + ty, + ty, + place, + root_address, + lifetime.clone(), + TODO_target_slice_len, + ) + } + PredicateKind::UniqueRef { lifetime, is_final } => { + let TODO_target_slice_len = None; + lowerer.unique_ref_snap( + CallContext::Procedure, + ty, + ty, + place, + root_address, + lifetime.clone(), + TODO_target_slice_len, + *is_final, + ) + } + } + } +} + +impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for ProcedureExpressionToSnapshot { + fn expression_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + expression: &vir_mid::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + assert!( + expression.is_place(), + "FIXME: Should this be renamed to place encoder and accept only places?" + ); + if lowerer.check_mode.unwrap().check_core_proof() { + // In the core proof checking mode, all places are obtained by + // applying the snap functions. To make the results predictable + // we always call the snap function on the place that the user + // wrote. + if self.use_pure_stack { + // This is a hack for addresses. See `Prusti places` HackMD for + // details. + if expression.get_last_dereferenced_pointer().is_some() { + let ty = expression.get_type(); + self.snap_call(lowerer, ty, expression) + } else { + self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) + } + } else { + let ty = expression.get_type(); + self.snap_call(lowerer, ty, expression) + } + } else { + self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) + } + } + + fn variable_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + variable: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + if let Some(label) = &self.old_label { + lowerer.snapshot_variable_version_at_label(variable, label) + } else { + lowerer.current_snapshot_variable_version(variable) + } + } + + fn labelled_old_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + old: &vir_mid::LabelledOld, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn func_app_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + app: &vir_mid::FuncApp, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn call_context(&self) -> CallContext { + todo!() + } + + fn owned_non_aliased_snap( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + todo!() + } + + fn pointer_deref_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _deref: &vir_mid::Deref, + _base_snapshot: vir_low::Expression, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + unreachable!("Should be overriden."); + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/pure_framed.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/pure_framed.rs new file mode 100644 index 00000000000..9587842c9fc --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/expressions/pure_framed.rs @@ -0,0 +1,162 @@ +use super::super::PredicateKind; +use crate::encoder::{ + errors::SpannedEncodingResult, + high::types::HighTypeEncoderInterface, + middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + footprint::FootprintInterface, + lowerer::{FunctionsLowererInterface, Lowerer}, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface}, + snapshots::{IntoSnapshotLowerer, SnapshotValuesInterface}, + }, +}; +use rustc_hash::FxHashMap; +use vir_crate::{ + common::{identifier::WithIdentifier, position::Positioned}, + low::{self as vir_low}, + middle::{self as vir_mid, operations::ty::Typed}, +}; + +pub(in super::super::super::super::super) struct FramedExpressionToSnapshot<'a> { + framing_variables: &'a [vir_mid::VariableDecl], +} + +impl<'a> FramedExpressionToSnapshot<'a> { + pub(in super::super::super::super::super) fn for_function_body( + framing_variables: &'a [vir_mid::VariableDecl], + ) -> Self { + Self { framing_variables } + } + + /// Find a base of type struct that has an invariant. + fn obtain_invariant<'e>( + &mut self, + lowerer: &mut Lowerer, + expression: &'e vir_mid::Expression, + ) -> SpannedEncodingResult<( + &'e vir_mid::Expression, + vir_mid::Expression, + Vec, + )> { + let ty = expression.get_type(); + if ty.is_struct() { + let type_decl = lowerer.encoder.get_type_decl_mid(ty)?; + if let vir_mid::TypeDecl::Struct(vir_mid::type_decl::Struct { + structural_invariant: Some(invariant), + .. + }) = type_decl + { + let self_place = vir_mid::VariableDecl::self_variable(ty.clone()); + return Ok((expression, self_place.into(), invariant)); + } else { + unimplemented!("TODO: A proper error message that only permissions from non-nested structs are supported."); + } + } else { + let (base_place, parent, invariant) = self.obtain_invariant( + lowerer, + expression + .get_parent_ref() + .expect("TODO: A proper error message that the permission has to be framed."), + )?; + Ok((base_place, expression.with_new_parent(parent), invariant)) + } + } +} + +impl<'a, 'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> + for FramedExpressionToSnapshot<'a> +{ + fn variable_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + variable: &vir_mid::VariableDecl, + ) -> SpannedEncodingResult { + Ok(vir_low::VariableDecl { + name: variable.name.clone(), + ty: self.type_to_snapshot(lowerer, &variable.ty)?, + }) + } + + // FIXME: Code duplication with + // prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/mod.rs + fn labelled_old_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + old: &vir_mid::LabelledOld, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + // In pure contexts values cannot be mutated, so `old` has no effect. + self.expression_to_snapshot(lowerer, &old.base, expect_math_bool) + } + + // FIXME: Code duplication with + // prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/mod.rs + fn func_app_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + app: &vir_mid::FuncApp, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let arguments = + self.expression_vec_to_snapshot(lowerer, &app.arguments, expect_math_bool)?; + let return_type = self.type_to_snapshot(lowerer, &app.return_type)?; + let func_app = lowerer.call_pure_function_in_pure_context( + app.get_identifier(), + arguments, + return_type, + app.position, + )?; + let result = vir_low::Expression::DomainFuncApp(func_app); + self.ensure_bool_expression(lowerer, &app.return_type, result, expect_math_bool) + } + + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + todo!() + } + + fn call_context(&self) -> CallContext { + todo!() + } + + fn owned_non_aliased_snap( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + todo!() + } + + fn pointer_deref_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + deref: &vir_mid::Deref, + _base_snapshot: vir_low::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + let (base_place, framed_place, invariant) = self.obtain_invariant(lowerer, &deref.base)?; + let framed_place = vir_mid::Expression::deref_no_pos(framed_place, deref.ty.clone()); + let deref_fields = lowerer.structural_invariant_to_deref_fields(&invariant)?; + let base_snapshot = self.expression_to_snapshot(lowerer, base_place, expect_math_bool)?; + for (deref_place, name, ty) in deref_fields { + if deref_place == framed_place { + return lowerer.pointer_target_as_snapshot_field( + base_place.get_type(), + &name, + ty, + base_snapshot, + deref.position, + ); + } + } + unimplemented!("TODO: A proper error message that failed to find a framing place.") + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/mod.rs index 282b2ffede8..b0d0112e990 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/mod.rs @@ -1,22 +1,38 @@ -//! The traits for converting expressions into snapshots: -//! -//! + `procedure` contains the traits for converting in procedure contexts where -//! we need to use SSA form and `caller_for` for calling pure functions. -//! + `pure` contains the traits for converting in pure contexts such as axioms -//! and pure function definitions where we do not use neither SSA nor -//! `caller_for`. -//! + `builtin_methods` contains the traits for converting in builtin-method -//! contexts where we do not use SSA, but use `caller_for`. +//! The traits for converting expressions into snapshots. +/// Contains the traits for converting in builtin-method contexts where we do +/// not use SSA, but use `caller_for`. mod builtin_methods; +/// The trait that provides the general sceleton for converting expressions into +/// snapshots. mod common; +/// Contains the traits for converting elements into the snapshots where the +/// context does not matter. Currently, the only example is types. mod context_independent; +/// Contains the traits for converting in procedure contexts where we need to +/// use SSA form and `caller_for` for calling pure functions. mod procedure; +/// Contains the traits for converting in pure contexts such as axioms and pure +/// function definitions where we do not use neither SSA nor `caller_for`. mod pure; +/// Contains structs for converting assertions (potentially containing +/// accessibility predicates) to snapshots. +mod assertions; +/// Contains structs for converting expressions to snapshots. +mod expressions; pub(in super::super) use self::{ + assertions::{ + AssertionToSnapshotConstructor, PredicateKind, SelfFramingAssertionToSnapshot, + ValidityAssertionToSnapshot, + }, builtin_methods::IntoBuiltinMethodSnapshot, + common::IntoSnapshotLowerer, context_independent::IntoSnapshot, - procedure::{IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, IntoProcedureSnapshot}, - pure::{IntoPureBoolExpression, IntoPureSnapshot}, + expressions::{FramedExpressionToSnapshot, ProcedureExpressionToSnapshot}, + procedure::{ + IntoProcedureAssertion, IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, + IntoProcedureSnapshot, ProcedureSnapshot, + }, + pure::{IntoFramedPureSnapshot, IntoPureBoolExpression, IntoPureSnapshot}, }; diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/mod.rs index 776b1e8e7c5..780313f7edb 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/mod.rs @@ -2,17 +2,23 @@ //! procedure bodies. Most important difference from `pure` is that this //! encoding uses SSA. -use super::common::IntoSnapshotLowerer; +use super::{common::IntoSnapshotLowerer, PredicateKind}; use crate::encoder::{ errors::SpannedEncodingResult, middle::core_proof::{ + addresses::AddressesInterface, + builtin_methods::CallContext, + lifetimes::LifetimesInterface, lowerer::{FunctionsLowererInterface, Lowerer}, + places::PlacesInterface, + pointers::PointersInterface, + predicates::{PredicatesMemoryBlockInterface, PredicatesOwnedInterface}, references::ReferencesInterface, snapshots::SnapshotVariablesInterface, }, }; use vir_crate::{ - common::identifier::WithIdentifier, + common::{identifier::WithIdentifier, position::Positioned}, low::{self as vir_low}, middle::{self as vir_mid, operations::ty::Typed}, }; @@ -20,16 +26,51 @@ use vir_crate::{ mod traits; pub(in super::super::super) use self::traits::{ - IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, IntoProcedureSnapshot, + IntoProcedureAssertion, IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, + IntoProcedureSnapshot, }; -#[derive(Default)] -struct ProcedureSnapshot { +pub(in super::super::super::super) struct ProcedureSnapshot { old_label: Option, deref_to_final: bool, + is_assertion: bool, + in_heap_assertions: Vec, + predicate_kind: PredicateKind, +} + +impl ProcedureSnapshot { + pub(in super::super) fn new_for_owned() -> Self { + Self { + old_label: None, + deref_to_final: false, + is_assertion: false, + in_heap_assertions: Vec::new(), + predicate_kind: PredicateKind::Owned, + } + } } impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for ProcedureSnapshot { + fn expression_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + expression: &vir_mid::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + if !lowerer.use_heap_variable()? + && expression.is_place() + && expression.get_last_dereferenced_pointer().is_some() + { + // let address = lowerer.encode_expression_as_place_address(expression)?; + // let place = lowerer.encode_expression_as_place(expression)?; + // let root_address = lowerer.extract_root_address(expression)?; + let ty = expression.get_type(); + // return lowerer.owned_non_aliased_snap(CallContext::Procedure, ty, ty, place, root_address); + return self.owned_non_aliased_snap(lowerer, ty, expression); + } + self.expression_to_snapshot_impl(lowerer, expression, expect_math_bool) + } + fn variable_to_snapshot( &mut self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, @@ -87,17 +128,251 @@ impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for ProcedureSnapsh lowerer.reference_target_final_snapshot( deref.base.get_type(), base_snapshot, - Default::default(), + deref.position, )? } else { let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; - lowerer.reference_target_current_snapshot( - deref.base.get_type(), - base_snapshot, - Default::default(), - )? + if deref.base.get_type().is_reference() { + lowerer.reference_target_current_snapshot( + deref.base.get_type(), + base_snapshot, + deref.position, + )? + } else { + lowerer.pointer_target_snapshot( + deref.base.get_type(), + &self.old_label, + base_snapshot, + deref.position, + )? + } }; self.ensure_bool_expression(lowerer, deref.get_type(), result, expect_math_bool) } + + fn acc_predicate_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + acc_predicate: &vir_mid::AccPredicate, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + assert!(self.is_assertion); + // fn in_heap<'p, 'v, 'tcx>( + // old_label: &Option, + // place: &vir_mid::Expression, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // ) -> SpannedEncodingResult { + // let in_heap = if let Some(pointer_place) = place.get_last_dereferenced_pointer() { + // let pointer = pointer_place.to_procedure_snapshot(lowerer)?; + // let address = + // lowerer.pointer_address(pointer_place.get_type(), pointer, place.position())?; + // let heap = lowerer.heap_variable_version_at_label(old_label)?; + // vir_low::Expression::container_op_no_pos( + // vir_low::ContainerOpKind::MapContains, + // heap.ty.clone(), + // vec![heap.into(), address], + // ) + // } else { + // unimplemented!("TODO: Proper error message: {:?}", place); + // }; + // Ok(in_heap) + // } + let expression = match &*acc_predicate.predicate { + vir_mid::Predicate::OwnedNonAliased(predicate) => { + let ty = predicate.place.get_type(); + let place = lowerer.encode_expression_as_place(&predicate.place)?; + let root_address = lowerer.extract_root_address(&predicate.place)?; + let snapshot = predicate.place.to_procedure_snapshot(lowerer)?; // FIXME: This is probably wrong. It should take into account the current old. + // if lowerer.use_heap_variable()? { + // let in_heap = in_heap(&self.old_label, &predicate.place, lowerer)?; + // self.in_heap_assertions.push(in_heap); + // } + // let acc = + unimplemented!(); + // lowerer.owned_aliased( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // snapshot, + // None, + // )? + // ; + // vir_low::Expression::and(in_heap, acc) + } + vir_mid::Predicate::OwnedRange(predicate) => { + let ty = predicate.address.get_type(); + let address = predicate.address.to_procedure_snapshot(lowerer)?; + let start_index = predicate.start_index.to_procedure_snapshot(lowerer)?; + let end_index = predicate.end_index.to_procedure_snapshot(lowerer)?; + lowerer.owned_aliased_range( + CallContext::Procedure, + ty, + ty, + address, + start_index, + end_index, + None, + )? + } + vir_mid::Predicate::MemoryBlockHeap(predicate) => { + let place = lowerer.encode_expression_as_place_address(&predicate.address)?; + let size = predicate.size.to_procedure_snapshot(lowerer)?; + // if lowerer.use_heap_variable()? { + // let in_heap = in_heap(&self.old_label, &predicate.address, lowerer)?; + // self.in_heap_assertions.push(in_heap); + // } + // let acc = + lowerer.encode_memory_block_stack_acc(place, size, acc_predicate.position)? + //; + // vir_low::Expression::and(in_heap, acc) + } + vir_mid::Predicate::MemoryBlockHeapRange(predicate) => { + let pointer_value = predicate.address.to_procedure_snapshot(lowerer)?; + let address = lowerer.pointer_address( + predicate.address.get_type(), + pointer_value, + predicate.position, + )?; + let size = predicate.size.to_procedure_snapshot(lowerer)?; + let start_index = predicate.start_index.to_procedure_snapshot(lowerer)?; + let end_index = predicate.end_index.to_procedure_snapshot(lowerer)?; + lowerer.encode_memory_block_range_acc( + address, + size, + start_index, + end_index, + acc_predicate.position, + )? + } + vir_mid::Predicate::MemoryBlockHeapDrop(predicate) => { + let place = lowerer.encode_expression_as_place_address(&predicate.address)?; + let size = predicate.size.to_procedure_snapshot(lowerer)?; + // if lowerer.use_heap_variable()? { + // let in_heap = in_heap(&self.old_label, &predicate.address, lowerer)?; + // self.in_heap_assertions.push(in_heap); + // } + // let acc = + lowerer.encode_memory_block_heap_drop_acc(place, size, acc_predicate.position)? + // ; + // vir_low::Expression::and(in_heap, acc) + } + _ => unimplemented!("{acc_predicate}"), + }; + Ok(expression) + } + + fn owned_non_aliased_snap( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ty: &vir_mid::Type, + pointer_place: &vir_mid::Expression, + ) -> SpannedEncodingResult { + unimplemented!(); + // let place = lowerer.encode_expression_as_place(pointer_place)?; + // let root_address = lowerer.extract_root_address(pointer_place)?; + // match &self.predicate_kind { + // PredicateKind::Owned => lowerer.owned_non_aliased_snap( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // pointer_place.position(), + // ), + // PredicateKind::FracRef { lifetime } => todo!(), + // PredicateKind::UniqueRef { lifetime, is_final } => { + // let TODO_target_slice_len = None; + // lowerer.unique_ref_snap( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // lifetime.clone(), + // TODO_target_slice_len, + // *is_final, + // ) + // } + // } + // if let Some(reference_place) = pointer_place.get_first_dereferenced_reference() { + // let vir_mid::Type::Reference(reference_type) = reference_place.get_type() else { + // unreachable!() + // }; + // let TODO_target_slice_len = None; + // let lifetime = lowerer + // .encode_lifetime_const_into_procedure_variable(reference_type.lifetime.clone())?; + // match reference_type.uniqueness { + // vir_mid::ty::Uniqueness::Unique => lowerer.unique_ref_snap( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // lifetime.into(), + // TODO_target_slice_len, + // self.deref_to_final, + // ), + // vir_mid::ty::Uniqueness::Shared => lowerer.frac_ref_snap( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // lifetime.into(), + // TODO_target_slice_len, + // ), + // } + // } else { + // lowerer.owned_non_aliased_snap( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // pointer_place.position(), + // ) + // } + // // TODO: Check whether the place is behind a shared/mutable reference and use the appropriate function + // eprintln!("pointer_place: {}", pointer_place); + // eprintln!("pointer_place: {:?}", pointer_place); + } + + fn call_context(&self) -> CallContext { + CallContext::Procedure + } + + // fn unfolding_to_snapshot( + // &mut self, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // unfolding: &vir_mid::Unfolding, + // expect_math_bool: bool, + // ) -> SpannedEncodingResult { + // let predicate = match &*unfolding.predicate { + // vir_mid::Predicate::OwnedNonAliased(predicate) => { + // let ty = predicate.place.get_type(); + // lowerer.mark_owned_non_aliased_as_unfolded(ty)?; + // let place = lowerer.encode_expression_as_place(&predicate.place)?; + // let root_address = lowerer.extract_root_address(&predicate.place)?; + // let snapshot = predicate.place.to_procedure_snapshot(lowerer)?; // FIXME: This is probably wrong. It should take into account the current old. + // lowerer + // .owned_non_aliased_predicate( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // snapshot, + // None, + // )? + // .unwrap_predicate_access_predicate() + // } + // _ => unimplemented!("{unfolding}"), + // }; + // let body = self.expression_to_snapshot(lowerer, &unfolding.body, expect_math_bool)?; + // let expression = vir_low::Expression::unfolding(predicate, body, unfolding.position); + // Ok(expression) + // } } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/traits.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/traits.rs index 30a50fbb9a1..76cd89c4145 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/traits.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/procedure/traits.rs @@ -6,6 +6,7 @@ use crate::encoder::{ middle::core_proof::{lowerer::Lowerer, snapshots::into_snapshot::common::IntoSnapshotLowerer}, }; use vir_crate::{ + common::expression::ExpressionIterator, low::{self as vir_low}, middle::{self as vir_mid}, }; @@ -25,7 +26,35 @@ impl IntoProcedureBoolExpression for vir_mid::Expression { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - ProcedureSnapshot::default().expression_to_snapshot(lowerer, self, true) + ProcedureSnapshot::new_for_owned().expression_to_snapshot(lowerer, self, true) + } +} + +/// Converts `self` into assertion that evaluates to a Viper Bool. +pub(in super::super::super::super) trait IntoProcedureAssertion { + type Target; + fn to_procedure_assertion<'p, 'v: 'p, 'tcx: 'v>( + &self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult; +} + +impl IntoProcedureAssertion for vir_mid::Expression { + type Target = vir_low::Expression; + fn to_procedure_assertion<'p, 'v: 'p, 'tcx: 'v>( + &self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult { + let mut snapshot_encoder = ProcedureSnapshot { + is_assertion: true, + ..ProcedureSnapshot::new_for_owned() + }; + let expression = snapshot_encoder.expression_to_snapshot(lowerer, self, true)?; + Ok(snapshot_encoder + .in_heap_assertions + .into_iter() + .chain(std::iter::once(expression)) + .conjoin()) } } @@ -43,7 +72,7 @@ impl IntoProcedureSnapshot for vir_mid::VariableDecl { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - ProcedureSnapshot::default().variable_to_snapshot(lowerer, self) + ProcedureSnapshot::new_for_owned().variable_to_snapshot(lowerer, self) } } @@ -53,7 +82,7 @@ impl IntoProcedureSnapshot for vir_mid::Expression { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - ProcedureSnapshot::default().expression_to_snapshot(lowerer, self, false) + ProcedureSnapshot::new_for_owned().expression_to_snapshot(lowerer, self, false) } } @@ -87,7 +116,7 @@ impl IntoProcedureFinalSnapshot for vir_mid::Expression { ) -> SpannedEncodingResult { let mut snapshot_encoder = ProcedureSnapshot { deref_to_final: true, - ..ProcedureSnapshot::default() + ..ProcedureSnapshot::new_for_owned() }; snapshot_encoder.expression_to_snapshot(lowerer, self, false) } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/mod.rs index aa9561a9468..31e724e6388 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/mod.rs @@ -4,7 +4,10 @@ use super::common::IntoSnapshotLowerer; use crate::encoder::{ errors::SpannedEncodingResult, - middle::core_proof::lowerer::{FunctionsLowererInterface, Lowerer}, + middle::core_proof::{ + builtin_methods::CallContext, + lowerer::{FunctionsLowererInterface, Lowerer}, + }, }; use vir_crate::{ common::identifier::WithIdentifier, @@ -14,9 +17,15 @@ use vir_crate::{ mod traits; -pub(in super::super::super) use self::traits::{IntoPureBoolExpression, IntoPureSnapshot}; +pub(in super::super::super) use self::traits::{ + IntoFramedPureSnapshot, IntoPureBoolExpression, IntoPureSnapshot, +}; -struct PureSnapshot; +#[derive(Default)] +struct PureSnapshot { + /// Assume that all pointer accesses are safe. + assume_pointers_to_be_framed: bool, +} impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for PureSnapshot { fn variable_to_snapshot( @@ -58,4 +67,120 @@ impl<'p, 'v: 'p, 'tcx: 'v> IntoSnapshotLowerer<'p, 'v, 'tcx> for PureSnapshot { // In pure contexts values cannot be mutated, so `old` has no effect. self.expression_to_snapshot(lowerer, &old.base, expect_math_bool) } + + fn pointer_deref_to_snapshot( + &mut self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + deref: &vir_mid::Deref, + base_snapshot: vir_low::Expression, + expect_math_bool: bool, + ) -> SpannedEncodingResult { + // FIXME: Delete. + assert!(self.assume_pointers_to_be_framed); + eprintln!("deref: {}", deref); + eprintln!("base_snapshot: {}", base_snapshot); + unimplemented!(); + } + + // fn deref_to_snapshot( + // &mut self, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // deref: &vir_mid::Deref, + // expect_math_bool: bool, + // ) -> SpannedEncodingResult { + // let base_snapshot = self.expression_to_snapshot(lowerer, &deref.base, expect_math_bool)?; + // let ty = deref.base.get_type(); + // let result = if ty.is_reference() { + // lowerer.reference_target_current_snapshot(ty, base_snapshot, deref.position)? + // } else { + // unreachable!(); + // // unimplemented!("TODO: to double-check that this is actually used (and in a correct way)"); + // // This most likely should be unreachable. In axioms we should use snapshot variables + // // instead. + // // let heap = vir_low::VariableDecl::new("pure_heap$", lowerer.heap_type()?); + // // lowerer.pointer_target_snapshot_in_heap( + // // deref.base.get_type(), + // // heap, + // // base_snapshot, + // // deref.position, + // // )? + // // lowerer.pointer_target_snapshot( + // // deref.base.get_type(), + // // &None, + // // base_snapshot, + // // deref.position, + // // )? + // }; + // self.ensure_bool_expression(lowerer, deref.get_type(), result, expect_math_bool) + // } + + // FIXME: Mark as unreachable. + fn acc_predicate_to_snapshot( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _acc_predicate: &vir_mid::AccPredicate, + _expect_math_bool: bool, + ) -> SpannedEncodingResult { + unimplemented!("FIXME: Delete"); + // assert!(self.is_assertion); + // let expression = match &*acc_predicate.predicate { + // vir_mid::Predicate::OwnedNonAliased(predicate) => { + // eprintln!("pure predicate: {}", predicate); + // let ty = predicate.place.get_type(); + // let place = lowerer.encode_expression_as_place(&predicate.place)?; + // // let root_address = lowerer.extract_root_address(&predicate.place)?; + // let root_address = true.into(); + // // let snapshot = predicate.place.to_pure_snapshot(lowerer)?; + // let snapshot = true.into(); + // let acc = lowerer.owned_aliased( + // CallContext::Procedure, + // ty, + // ty, + // place, + // root_address, + // snapshot, + // None, + // )?; + // eprintln!(" → {}", acc); + // acc + // } + // vir_mid::Predicate::MemoryBlockHeap(predicate) => { + // // let place = lowerer.encode_expression_as_place_address(&predicate.address)?; + // let place = true.into(); + // let size = predicate.size.to_pure_snapshot(lowerer)?; + // lowerer.encode_memory_block_stack_acc(place, size, acc_predicate.position)? + // } + // vir_mid::Predicate::MemoryBlockHeapDrop(predicate) => { + // // let place = lowerer.encode_expression_as_place_address(&predicate.address)?; + // let place = true.into(); + // // let size = predicate.size.to_pure_snapshot(lowerer)?; + // let size = true.into(); + // lowerer.encode_memory_block_heap_drop_acc(place, size, acc_predicate.position)? + // } + // _ => unimplemented!("{acc_predicate}"), + // }; + // Ok(expression) + } + + fn owned_non_aliased_snap( + &mut self, + _lowerer: &mut Lowerer<'p, 'v, 'tcx>, + _ty: &vir_mid::Type, + _pointer_snapshot: &vir_mid::Expression, + ) -> SpannedEncodingResult { + unimplemented!() + } + + // fn unfolding_to_snapshot( + // &mut self, + // lowerer: &mut Lowerer<'p, 'v, 'tcx>, + // unfolding: &vir_mid::Unfolding, + // expect_math_bool: bool, + // ) -> SpannedEncodingResult { + // todo!() + // } + + fn call_context(&self) -> CallContext { + todo!() + } } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/traits.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/traits.rs index 3cda221b325..c06e5b9c6e1 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/traits.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/into_snapshot/pure/traits.rs @@ -28,7 +28,7 @@ impl IntoPureBoolExpression for vir_mid::Expression { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - PureSnapshot.expression_to_snapshot(lowerer, self, true) + PureSnapshot::default().expression_to_snapshot(lowerer, self, true) } } @@ -38,7 +38,31 @@ impl IntoPureBoolExpression for Vec { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - PureSnapshot.expression_vec_to_snapshot(lowerer, self, true) + PureSnapshot::default().expression_vec_to_snapshot(lowerer, self, true) + } +} + +/// Converts `self` into expression that evaluates to a snapshot. It assumes +/// that all pointers can be safely dereferenced. +pub(in super::super::super::super) trait IntoFramedPureSnapshot { + type Target; + fn to_framed_pure_snapshot<'p, 'v: 'p, 'tcx: 'v>( + &self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult; +} + +impl IntoFramedPureSnapshot for vir_mid::Expression { + type Target = vir_low::Expression; + fn to_framed_pure_snapshot<'p, 'v: 'p, 'tcx: 'v>( + &self, + lowerer: &mut Lowerer<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult { + let mut snapshot_encoder = PureSnapshot { + assume_pointers_to_be_framed: true, + ..PureSnapshot::default() + }; + snapshot_encoder.expression_to_snapshot(lowerer, self, true) } } @@ -57,7 +81,7 @@ impl IntoPureSnapshot for vir_mid::Expression { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - PureSnapshot.expression_to_snapshot(lowerer, self, false) + PureSnapshot::default().expression_to_snapshot(lowerer, self, false) } } @@ -69,7 +93,7 @@ impl IntoPureSnapshot for Vec { ) -> SpannedEncodingResult { let mut variables = Vec::new(); for variable in self { - variables.push(PureSnapshot.variable_to_snapshot(lowerer, variable)?); + variables.push(PureSnapshot::default().variable_to_snapshot(lowerer, variable)?); } Ok(variables) } @@ -81,7 +105,7 @@ impl IntoPureSnapshot for vir_mid::VariableDecl { &self, lowerer: &mut Lowerer<'p, 'v, 'tcx>, ) -> SpannedEncodingResult { - PureSnapshot.variable_to_snapshot(lowerer, self) + PureSnapshot::default().variable_to_snapshot(lowerer, self) } } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/mod.rs index 51806b98828..32304bffbd5 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/mod.rs @@ -14,11 +14,15 @@ pub(super) use self::{ bytes::SnapshotBytesInterface, domains::SnapshotDomainsInterface, into_snapshot::{ - IntoBuiltinMethodSnapshot, IntoProcedureBoolExpression, IntoProcedureFinalSnapshot, - IntoProcedureSnapshot, IntoPureBoolExpression, IntoPureSnapshot, IntoSnapshot, + AssertionToSnapshotConstructor, FramedExpressionToSnapshot, IntoBuiltinMethodSnapshot, + IntoFramedPureSnapshot, IntoProcedureAssertion, IntoProcedureBoolExpression, + IntoProcedureFinalSnapshot, IntoProcedureSnapshot, IntoPureBoolExpression, + IntoPureSnapshot, IntoSnapshot, IntoSnapshotLowerer, PredicateKind, + ProcedureExpressionToSnapshot, ProcedureSnapshot, SelfFramingAssertionToSnapshot, + ValidityAssertionToSnapshot, }, state::SnapshotsState, validity::{valid_call, valid_call2, SnapshotValidityInterface}, values::SnapshotValuesInterface, - variables::SnapshotVariablesInterface, + variables::{AllVariablesMap, SnapshotVariablesInterface, VariableVersionMap}, }; diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/state.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/state.rs index 3d7591b02b9..e9e0c4451b3 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/state.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/state.rs @@ -15,10 +15,11 @@ pub(in super::super) struct SnapshotsState { pub(super) encoded_to_bytes: FxHashSet, /// The list of types for which sequence_repeat_constructor was encoded. pub(super) encoded_sequence_repeat_constructor: FxHashSet, - pub(super) all_variables: AllVariablesMap, - pub(super) variables: BTreeMap, - pub(super) variables_at_label: BTreeMap, - pub(super) current_variables: Option, + pub(super) ssa_state: vir_low::ssa::SSAState, + // pub(super) all_variables: AllVariablesMap, + // pub(super) variables: BTreeMap, + // pub(super) variables_at_label: BTreeMap, + // pub(super) current_variables: Option, /// Mapping from low types to their domain names. pub(super) type_domains: FxHashMap, } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/validity/interface.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/validity/interface.rs index 70cc436c562..6ab29932790 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/validity/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/validity/interface.rs @@ -54,6 +54,12 @@ pub(in super::super::super) trait SnapshotValidityInterface { &mut self, domain_name: &str, parameters: Vec, + ) -> SpannedEncodingResult<()>; + fn encode_validity_axioms_struct_with_invariant( + &mut self, + domain_name: &str, + parameters: Vec, + parameters_with_validity: usize, invariant: vir_low::Expression, ) -> SpannedEncodingResult<()>; fn encode_validity_axioms_struct_alternative_constructor( @@ -61,6 +67,7 @@ pub(in super::super::super) trait SnapshotValidityInterface { domain_name: &str, variant_name: &str, parameters: Vec, + parameters_with_validity: usize, invariant: vir_low::Expression, ) -> SpannedEncodingResult<()>; /// `variants` is `(variant_name, variant_domain, discriminant)`. @@ -120,31 +127,56 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValidityInterface for Lowerer<'p, 'v, 'tcx> { ) -> SpannedEncodingResult<()> { use vir_low::macros::*; let parameters = vars! { value: {parameter_type}}; - self.encode_validity_axioms_struct(domain_name, parameters, invariant) + let parameters_with_validity = parameters.len(); + self.encode_validity_axioms_struct_with_invariant( + domain_name, + parameters, + parameters_with_validity, + invariant, + ) } fn encode_validity_axioms_struct( &mut self, domain_name: &str, parameters: Vec, + ) -> SpannedEncodingResult<()> { + let parameters_with_validity = parameters.len(); + self.encode_validity_axioms_struct_with_invariant( + domain_name, + parameters, + parameters_with_validity, + true.into(), + ) + } + fn encode_validity_axioms_struct_with_invariant( + &mut self, + domain_name: &str, + parameters: Vec, + parameters_with_validity: usize, invariant: vir_low::Expression, ) -> SpannedEncodingResult<()> { self.encode_validity_axioms_struct_alternative_constructor( domain_name, "", parameters, + parameters_with_validity, invariant, ) } + /// `parameters_with_validity` – how many of `parameters` should have a + /// conjoined validity call. For all Rust types without permissions in their + /// structural invariants, `parameters_with_validity == parameters.len()`. fn encode_validity_axioms_struct_alternative_constructor( &mut self, domain_name: &str, variant_name: &str, parameters: Vec, + parameters_with_validity: usize, invariant: vir_low::Expression, ) -> SpannedEncodingResult<()> { use vir_low::macros::*; let mut valid_parameters = Vec::new(); - for parameter in ¶meters { + for parameter in parameters.iter().take(parameters_with_validity) { if let Some(domain_name) = self.get_non_primitive_domain(¶meter.ty) { let domain_name = domain_name.to_string(); valid_parameters @@ -159,7 +191,8 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValidityInterface for Lowerer<'p, 'v, 'tcx> { .map(|parameter| parameter.clone().into()) .collect(), )?; - let valid_constructor = self.encode_snapshot_valid_call(domain_name, constructor_call)?; + let valid_constructor = + self.encode_snapshot_valid_call(domain_name, constructor_call.clone())?; if parameters.is_empty() { let axiom = vir_low::DomainAxiomDecl { comment: None, @@ -181,8 +214,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValidityInterface for Lowerer<'p, 'v, 'tcx> { // parameters, the bottom-up and top-down axioms are equivalent. let mut top_down_validity_expression = validity_expression.clone(); var_decls! { snapshot: {vir_low::Type::domain(domain_name.to_string())}}; + let snapshot_expression = snapshot.clone().into(); + top_down_validity_expression = + top_down_validity_expression.replace_self(&snapshot_expression); let valid_constructor = - self.encode_snapshot_valid_call(domain_name, snapshot.clone().into())?; + self.encode_snapshot_valid_call(domain_name, snapshot_expression)?; let mut triggers = Vec::new(); for parameter in ¶meters { if self.get_non_primitive_domain(¶meter.ty).is_some() { @@ -214,6 +250,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValidityInterface for Lowerer<'p, 'v, 'tcx> { }; self.declare_axiom(domain_name, axiom_top_down)?; } + let bottom_up_validity_expression = validity_expression.replace_self(&constructor_call); let axiom_bottom_up_body = { let mut trigger = vec![valid_constructor.clone()]; trigger.extend(valid_parameters.clone()); @@ -221,7 +258,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValidityInterface for Lowerer<'p, 'v, 'tcx> { parameters, vec![vir_low::Trigger::new(trigger)], expr! { - [ valid_constructor ] == [ validity_expression ] + [ valid_constructor ] == [ bottom_up_validity_expression ] }, ) }; diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/values/interface.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/values/interface.rs index 98af319413c..c751376d32f 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/values/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/values/interface.rs @@ -23,6 +23,14 @@ pub(in super::super::super) trait SnapshotValuesInterface { argument: vir_low::Expression, position: vir_mid::Position, ) -> SpannedEncodingResult; + fn obtain_parameter_snapshot( + &mut self, + base_type: &vir_mid::Type, + parameter_name: &str, + parameter_type: vir_low::Type, + base_snapshot: vir_low::Expression, + position: vir_mid::Position, + ) -> SpannedEncodingResult; fn obtain_struct_field_snapshot( &mut self, base_type: &vir_mid::Type, @@ -118,19 +126,46 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValuesInterface for Lowerer<'p, 'v, 'tcx> { position, ) } - fn obtain_struct_field_snapshot( + fn obtain_parameter_snapshot( &mut self, base_type: &vir_mid::Type, - field: &vir_mid::FieldDecl, + parameter_name: &str, + parameter_type: vir_low::Type, base_snapshot: vir_low::Expression, position: vir_mid::Position, ) -> SpannedEncodingResult { let domain_name = self.encode_snapshot_domain_name(base_type)?; - let return_type = field.ty.to_snapshot(self)?; + let return_type = parameter_type; Ok(self - .snapshot_destructor_struct_call(&domain_name, &field.name, return_type, base_snapshot)? + .snapshot_destructor_struct_call( + &domain_name, + parameter_name, + return_type, + base_snapshot, + )? .set_default_position(position)) } + fn obtain_struct_field_snapshot( + &mut self, + base_type: &vir_mid::Type, + field: &vir_mid::FieldDecl, + base_snapshot: vir_low::Expression, + position: vir_mid::Position, + ) -> SpannedEncodingResult { + let parameter_type = field.ty.to_snapshot(self)?; + self.obtain_parameter_snapshot( + base_type, + &field.name, + parameter_type, + base_snapshot, + position, + ) + // let domain_name = self.encode_snapshot_domain_name(base_type)?; + // let return_type = field.ty.to_snapshot(self)?; + // Ok(self + // .snapshot_destructor_struct_call(&domain_name, &field.name, return_type, base_snapshot)? + // .set_default_position(position)) + } fn obtain_enum_variant_snapshot( &mut self, base_type: &vir_mid::Type, @@ -209,7 +244,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotValuesInterface for Lowerer<'p, 'v, 'tcx> { vir_mid::Type::Reference(_) => self.address_type()?, x => unimplemented!("{:?}", x), }; - vir_low::operations::ty::Typed::set_type(&mut argument, low_type); + if !ty.is_bool() { + vir_low::operations::ty::Typed::set_type(&mut argument, low_type); + } Ok(self .snapshot_constructor_constant_call(&domain_name, vec![argument])? .set_default_position(position)) diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/interface.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/interface.rs index 8f63e6eec50..49b660dbadc 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/interface.rs @@ -2,7 +2,10 @@ use crate::encoder::{ errors::{ErrorCtxt, SpannedEncodingResult}, high::types::HighTypeEncoderInterface, middle::core_proof::{ + addresses::AddressesInterface, + heap::HeapInterface, lowerer::{Lowerer, VariablesLowererInterface}, + pointers::PointersInterface, references::ReferencesInterface, snapshots::{ IntoProcedureSnapshot, IntoSnapshot, SnapshotValidityInterface, SnapshotValuesInterface, @@ -15,6 +18,7 @@ use crate::encoder::{ use std::collections::BTreeMap; use vir_crate::{ + common::check_mode::CheckMode, low::{self as vir_low}, middle::{self as vir_mid, operations::ty::Typed}, }; @@ -29,11 +33,14 @@ trait Private { version: u64, ) -> SpannedEncodingResult; #[allow(clippy::ptr_arg)] // Clippy false positive. + /// Note: if `new_snapshot_root` is `Some`, the current encoding assumes + /// that the `place` is not behind a raw pointer. fn snapshot_copy_except( &mut self, statements: &mut Vec, - old_snapshot_root: vir_low::VariableDecl, - new_snapshot_root: vir_low::VariableDecl, + base: vir_mid::VariableDecl, + // old_snapshot_root: vir_low::Expression, + // new_snapshot_root: vir_low::Expression, place: &vir_mid::Expression, position: vir_low::Position, ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)>; @@ -46,9 +53,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { ty: &vir_mid::Type, version: u64, ) -> SpannedEncodingResult { - let name = format!("{name}$snapshot${version}"); + // let name = format!("{}$snapshot${}", name, version); let ty = ty.to_snapshot(self)?; - self.create_variable(name, ty) + // self.create_variable(name, ty) + self.create_snapshot_variable_low(name, ty, version) } /// Copy all values of the old snapshot into the new snapshot, except the /// ones that belong to `place`. @@ -57,27 +65,72 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { fn snapshot_copy_except( &mut self, statements: &mut Vec, - old_snapshot_root: vir_low::VariableDecl, - new_snapshot_root: vir_low::VariableDecl, + base: vir_mid::VariableDecl, + // old_snapshot_root: vir_low::Expression, + // new_snapshot_root: vir_low::Expression, place: &vir_mid::Expression, position: vir_low::Position, ) -> SpannedEncodingResult<(vir_low::Expression, vir_low::Expression)> { use vir_low::macros::*; if let Some(parent) = place.get_parent_ref() { - let (old_snapshot, new_snapshot) = self.snapshot_copy_except( - statements, - old_snapshot_root, - new_snapshot_root, - parent, - position, - )?; let parent_type = parent.get_type(); + let (old_snapshot, new_snapshot) = + if let vir_mid::Type::Pointer(pointer_type) = parent_type { + let fresh_heap_chunk = self.fresh_heap_chunk(position)?; + let heap_chunk = self.heap_chunk_to_snapshot( + &pointer_type.target_type, + fresh_heap_chunk.clone().into(), + position, + )?; + if self.use_heap_variable()? { + let old_snapshot = parent.to_procedure_snapshot(self)?; // FIXME: This is most likely wrong. + let old_target_snapshot = self.pointer_target_snapshot( + parent.get_type(), + &None, + old_snapshot.clone(), + position, + )?; + let old_heap = self.heap_variable_version_at_label(&None)?; + + // Note: All `old_*` need to be computed before the heap version + // is incremented. + let new_heap = self.new_heap_variable_version(position)?; + let address = + self.pointer_address(parent.get_type(), old_snapshot, position)?; + statements.push(vir_low::Statement::assign( + new_heap, + self.heap_update( + old_heap.into(), + address, + fresh_heap_chunk.clone().into(), + position, + )?, + // vir_low::Expression::container_op( + // vir_low::ContainerOpKind::MapUpdate, + // self.heap_type()?, + // vec![old_heap.into(), address, fresh_heap_chunk.into()], + // position, + // ), + position, + )); + return Ok((old_target_snapshot, heap_chunk)); + } else { + return Ok((heap_chunk.clone(), heap_chunk)); + } + } else { + self.snapshot_copy_except( + statements, base, + // old_snapshot_root, + // new_snapshot_root, + parent, position, + )? + }; + let type_decl = self.encoder.get_type_decl_mid(parent_type)?; match &type_decl { vir_mid::TypeDecl::Bool | vir_mid::TypeDecl::Int(_) - | vir_mid::TypeDecl::Float(_) - | vir_mid::TypeDecl::Pointer(_) => { + | vir_mid::TypeDecl::Float(_) => { unreachable!("place: {}", place); } vir_mid::TypeDecl::Trusted(_) | vir_mid::TypeDecl::TypeVar(_) => { @@ -231,6 +284,38 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { unimplemented!("Place: {}", place); } } + vir_mid::TypeDecl::Pointer(_decl) => { + unreachable!("Should be handled by the caller."); + // let fresh_heap_chunk = self.fresh_heap_chunk()?; + // let heap_chunk = self.heap_chunk_to_snapshot( + // &decl.target_type, + // fresh_heap_chunk.clone().into(), + // position, + // )?; + // let old_heap = self.heap_variable_version_at_label(&None)?; + // let new_heap = self.new_heap_variable_version(position)?; + // let address = + // self.pointer_address(parent_type, old_snapshot.clone(), position)?; + // statements.push(vir_low::Statement::assign( + // new_heap, + // vir_low::Expression::container_op( + // vir_low::ContainerOpKind::MapUpdate, + // self.heap_type()?, + // vec![old_heap.into(), address, fresh_heap_chunk.into()], + // position, + // ), + // position, + // )); + // // statements.push(vir_low::Statement::assume( + // // vir_low::Expression::equals( + // // heap_chunk.clone(), + + // // ) + // // )); + // let old_target_snapshot = + // self.pointer_target_snapshot(parent_type, &None, old_snapshot, position)?; + // Ok((old_target_snapshot, heap_chunk)) + } vir_mid::TypeDecl::Sequence(_) => unimplemented!("ty: {}", type_decl), vir_mid::TypeDecl::Map(_) => unimplemented!("ty: {}", type_decl), vir_mid::TypeDecl::Never => unimplemented!("ty: {}", type_decl), @@ -238,6 +323,8 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { vir_mid::TypeDecl::Unsupported(_) => unimplemented!("ty: {}", type_decl), } } else { + let old_snapshot_root = base.to_procedure_snapshot(self)?; + let new_snapshot_root = self.new_snapshot_variable_version(&base, position)?; // We reached the root. Nothing to do here. Ok((old_snapshot_root.into(), new_snapshot_root.into())) } @@ -245,6 +332,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { } pub(in super::super::super) trait SnapshotVariablesInterface { + fn create_snapshot_variable_low( + &mut self, + name: &str, + ty: vir_low::Type, + version: u64, + ) -> SpannedEncodingResult; fn new_snapshot_variable_version( &mut self, variable: &vir_mid::VariableDecl, @@ -263,21 +356,40 @@ pub(in super::super::super) trait SnapshotVariablesInterface { variable: &vir_mid::VariableDecl, label: &str, ) -> SpannedEncodingResult; + fn use_heap_variable(&self) -> SpannedEncodingResult; + fn heap_variable_name(&self) -> SpannedEncodingResult<&'static str>; + fn new_heap_variable_version( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult; + fn heap_variable_version_at_label( + &mut self, + old_label: &Option, + ) -> SpannedEncodingResult; + fn address_variable_version_at_label( + &mut self, + variable_name: &str, + old_label: &Option, + ) -> SpannedEncodingResult; + fn fresh_heap_chunk( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult; fn encode_snapshot_havoc( &mut self, statements: &mut Vec, target: &vir_mid::Expression, position: vir_low::Position, - new_snapshot: Option, - ) -> SpannedEncodingResult<()>; + // new_snapshot: Option, + ) -> SpannedEncodingResult; fn encode_snapshot_update_with_new_snapshot( &mut self, statements: &mut Vec, target: &vir_mid::Expression, value: vir_low::Expression, position: vir_low::Position, - new_snapshot: Option, - ) -> SpannedEncodingResult<()>; + // new_snapshot: Option, + ) -> SpannedEncodingResult; #[allow(clippy::ptr_arg)] // Clippy false positive. fn encode_snapshot_update( &mut self, @@ -292,8 +404,15 @@ pub(in super::super::super) trait SnapshotVariablesInterface { predecessors: &BTreeMap>, basic_block_edges: &mut BTreeMap< vir_mid::BasicBlockId, - BTreeMap>, + BTreeMap< + vir_mid::BasicBlockId, + Vec<(String, vir_low::Type, vir_low::Position, u64, u64)>, + >, >, + // basic_block_edges: &mut BTreeMap< + // vir_mid::BasicBlockId, + // BTreeMap>, + // >, ) -> SpannedEncodingResult<()>; fn unset_current_block_for_snapshots( &mut self, @@ -303,32 +422,51 @@ pub(in super::super::super) trait SnapshotVariablesInterface { } impl<'p, 'v: 'p, 'tcx: 'v> SnapshotVariablesInterface for Lowerer<'p, 'v, 'tcx> { + fn create_snapshot_variable_low( + &mut self, + name: &str, + ty: vir_low::Type, + version: u64, + ) -> SpannedEncodingResult { + let name = format!("{}$snapshot${}", name, version); + self.create_variable(name, ty) + } fn new_snapshot_variable_version( &mut self, variable: &vir_mid::VariableDecl, position: vir_low::Position, ) -> SpannedEncodingResult { - let new_version = self - .snapshots_state - .all_variables - .new_version_or_default(variable, position); - self.snapshots_state - .current_variables - .as_mut() - .unwrap() - .set(variable.name.clone(), new_version); - self.create_snapshot_variable(&variable.name, &variable.ty, new_version) + let ty = variable.ty.to_snapshot(self)?; + // let new_version = self.snapshots_state.all_variables.new_version_or_default( + // &variable.name, + // &ty, + // position, + // ); + // self.snapshots_state + // .current_variables + // .as_mut() + // .unwrap() + // .set(variable.name.clone(), new_version); + let new_version = + self.snapshots_state + .ssa_state + .new_variable_version(&variable.name, &ty, position); + self.create_snapshot_variable_low(&variable.name, ty, new_version) } fn current_snapshot_variable_version( &mut self, variable: &vir_mid::VariableDecl, ) -> SpannedEncodingResult { + // let version = self + // .snapshots_state + // .current_variables + // .as_ref() + // .unwrap() + // .get_or_default(&variable.name); let version = self .snapshots_state - .current_variables - .as_ref() - .unwrap() - .get_or_default(&variable.name); + .ssa_state + .current_variable_version(&variable.name); self.create_snapshot_variable(&variable.name, &variable.ty, version) } fn initial_snapshot_variable_version( @@ -342,45 +480,201 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotVariablesInterface for Lowerer<'p, 'v, 'tcx> variable: &vir_mid::VariableDecl, label: &str, ) -> SpannedEncodingResult { + // let version = self + // .snapshots_state + // .variables_at_label + // .get(label) + // .unwrap_or_else(|| panic!("not found label {}", label)) + // .get_or_default(&variable.name); let version = self .snapshots_state - .variables_at_label - .get(label) - .unwrap_or_else(|| panic!("not found label {label}")) - .get_or_default(&variable.name); + .ssa_state + .variable_version_at_label(&variable.name, label); self.create_snapshot_variable(&variable.name, &variable.ty, version) } + fn use_heap_variable(&self) -> SpannedEncodingResult { + Ok(self.check_mode.unwrap().is_purification_group()) + } + fn heap_variable_name(&self) -> SpannedEncodingResult<&'static str> { + assert!( + self.use_heap_variable()?, + "The heap variable is not used when the check mode is Both" + ); + Ok("heap$") + } + fn new_heap_variable_version( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult { + // let name = "heap$"; + let name = self.heap_variable_name()?; + let ty = self.heap_type()?; + let new_version = self + .snapshots_state + .ssa_state + .new_variable_version(name, &ty, position); + // let new_version = self + // .snapshots_state + // .all_variables + // .new_version_or_default(name, &ty, position); + // self.snapshots_state + // .current_variables + // .as_mut() + // .unwrap() + // .set(name.to_string(), new_version); + self.create_snapshot_variable_low(name, ty, new_version) + } + fn heap_variable_version_at_label( + &mut self, + old_label: &Option, + ) -> SpannedEncodingResult { + // let name = "heap$"; + let name = self.heap_variable_name()?; + let version = self + .snapshots_state + .ssa_state + .variable_version_at_maybe_label(name, old_label); + // let version = if let Some(label) = old_label { + // self.snapshots_state + // .variables_at_label + // .get(label) + // .unwrap_or_else(|| panic!("not found label {}", label)) + // .get_or_default(name) + // } else { + // self.snapshots_state + // .current_variables + // .as_ref() + // .unwrap() + // .get_or_default(name) + // }; + let ty = self.heap_type()?; + // let name = format!("{}${}", name, version); + // self.create_variable(name, ty) + self.create_snapshot_variable_low(name, ty, version) + } + fn address_variable_version_at_label( + &mut self, + variable_name: &str, + old_label: &Option, + ) -> SpannedEncodingResult { + let name = format!("{}$address", variable_name); + let version = self + .snapshots_state + .ssa_state + .variable_version_at_maybe_label(&name, old_label); + // let version = if let Some(label) = old_label { + // self.snapshots_state + // .variables_at_label + // .get(label) + // .unwrap_or_else(|| panic!("not found label {}", label)) + // .get_or_default(&name) + // } else { + // self.snapshots_state + // .current_variables + // .as_ref() + // .unwrap() + // .get_or_default(&name) + // }; + let ty = self.address_type()?; + self.create_snapshot_variable_low(&name, ty, version) + } + fn fresh_heap_chunk( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let name = "heap_chunk$"; + let ty = self.heap_chunk_type()?; + let new_version = self + .snapshots_state + .ssa_state + .new_variable_version(name, &ty, position); + // let new_version = self + // .snapshots_state + // .all_variables + // .new_version_or_default(name, &ty, position); + // self.snapshots_state + // .current_variables + // .as_mut() + // .unwrap() + // .set(name.to_string(), new_version); + // let name = format!("{}${}", name, new_version); + // self.create_variable(name, ty) + self.create_snapshot_variable_low(name, ty, new_version) + } fn encode_snapshot_havoc( &mut self, statements: &mut Vec, target: &vir_mid::Expression, position: vir_low::Position, - new_snapshot: Option, - ) -> SpannedEncodingResult<()> { + // new_snapshot_root: Option, + ) -> SpannedEncodingResult { + // let base = target.get_base(); + // self.ensure_type_definition(&base.ty)?; + // let old_snapshot = base.to_procedure_snapshot(self)?; + // let new_snapshot = if let Some(new_snapshot) = new_snapshot { + // new_snapshot + // } else { + // self.new_snapshot_variable_version(&base, position)? + // }; + // self.snapshot_copy_except(statements, old_snapshot, new_snapshot, target, position)?; + // Ok(()) let base = target.get_base(); self.ensure_type_definition(&base.ty)?; - let old_snapshot = base.to_procedure_snapshot(self)?; - let new_snapshot = if let Some(new_snapshot) = new_snapshot { - new_snapshot - } else { - self.new_snapshot_variable_version(&base, position)? - }; - self.snapshot_copy_except(statements, old_snapshot, new_snapshot, target, position)?; - Ok(()) + + // if let Some(pointer_place) = target.get_last_dereferenced_pointer() { + // let pointer_type = pointer_place.get_type().clone().unwrap_pointer(); + // let fresh_heap_chunk = self.fresh_heap_chunk()?; + // let heap_chunk = self.heap_chunk_to_snapshot( + // &pointer_type.target_type, + // fresh_heap_chunk.clone().into(), + // position, + // )?; + // let old_heap = self.heap_variable_version_at_label(&None)?; + // let new_heap = self.new_heap_variable_version(position)?; + // let address = + // self.pointer_address(pointer_place.get_type(), old_snapshot.clone().into(), position)?; + // statements.push(vir_low::Statement::assign( + // new_heap, + // vir_low::Expression::container_op( + // vir_low::ContainerOpKind::MapUpdate, + // self.heap_type()?, + // vec![old_heap.into(), address, fresh_heap_chunk.into()], + // position, + // ), + // position, + // )); + // let old_target_snapshot = + // self.pointer_target_snapshot(pointer_place.get_type(), &None, old_snapshot.into(), position)?; + // // Ok((old_target_snapshot, heap_chunk) + // let (_old_snapshot, new_snapshot) = + // self.snapshot_copy_except(statements, old_target_snapshot, heap_chunk, pointer_place, position)?; + // Ok(new_snapshot) + // } else { + + // let (_old_snapshot, new_snapshot) = + // self.snapshot_copy_except(statements, old_snapshot, new_snapshot, target, position)?; + let (_old_snapshot, new_snapshot) = + self.snapshot_copy_except(statements, base, target, position)?; + Ok(new_snapshot) + // } } + /// `new_snapshot_root` is used when we want to use a specific variable + /// version as the root of the new snapshot. fn encode_snapshot_update_with_new_snapshot( &mut self, statements: &mut Vec, target: &vir_mid::Expression, value: vir_low::Expression, position: vir_low::Position, - new_snapshot: Option, - ) -> SpannedEncodingResult<()> { + // new_snapshot_root: Option, + ) -> SpannedEncodingResult { use vir_low::macros::*; - self.encode_snapshot_havoc(statements, target, position, new_snapshot)?; - statements - .push(stmtp! { position => assume ([target.to_procedure_snapshot(self)?] == [value]) }); - Ok(()) + // self.encode_snapshot_havoc(statements, target, position, new_snapshot)?; + // statements + // .push(stmtp! { position => assume ([target.to_procedure_snapshot(self)?] == [value]) }); + let new_snapshot = self.encode_snapshot_havoc(statements, target, position)?; + statements.push(stmtp! { position => assume ([new_snapshot.clone()] == [value]) }); + Ok(new_snapshot) } fn encode_snapshot_update( &mut self, @@ -389,7 +683,8 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotVariablesInterface for Lowerer<'p, 'v, 'tcx> value: vir_low::Expression, position: vir_low::Position, ) -> SpannedEncodingResult<()> { - self.encode_snapshot_update_with_new_snapshot(statements, target, value, position, None) + self.encode_snapshot_update_with_new_snapshot(statements, target, value, position)?; + Ok(()) } /// `basic_block_edges` are statements to be executed then going from one /// block to another. @@ -399,73 +694,88 @@ impl<'p, 'v: 'p, 'tcx: 'v> SnapshotVariablesInterface for Lowerer<'p, 'v, 'tcx> predecessors: &BTreeMap>, basic_block_edges: &mut BTreeMap< vir_mid::BasicBlockId, - BTreeMap>, + BTreeMap< + vir_mid::BasicBlockId, + Vec<(String, vir_low::Type, vir_low::Position, u64, u64)>, + >, >, + // basic_block_edges: &mut BTreeMap< + // vir_mid::BasicBlockId, + // BTreeMap>, + // >, ) -> SpannedEncodingResult<()> { - let predecessor_labels = &predecessors[label]; - let mut new_map = VariableVersionMap::default(); - for variable in self.snapshots_state.all_variables.names_clone() { - let predecessor_maps = predecessor_labels - .iter() - .map(|label| &self.snapshots_state.variables[label]) - .collect::>(); - let first_version = predecessor_maps[0].get_or_default(&variable); - let different = predecessor_maps - .iter() - .any(|map| map.get_or_default(&variable) != first_version); - if different { - let new_version = self.snapshots_state.all_variables.new_version(&variable); - let ty = self - .snapshots_state - .all_variables - .get_type(&variable) - .clone(); - let new_variable = self.create_snapshot_variable(&variable, &ty, new_version)?; - for predecessor_label in predecessor_labels { - let old_version = - self.snapshots_state.variables[predecessor_label].get_or_default(&variable); - let statements = basic_block_edges - .entry(predecessor_label.clone()) - .or_default() - .entry(label.clone()) - .or_default(); - let old_variable = - self.create_snapshot_variable(&variable, &ty, old_version)?; - let position = self.encoder.change_error_context( - // FIXME: Get a more precise span. - self.snapshots_state.all_variables.get_position(&variable), - ErrorCtxt::Unexpected, - ); - let statement = vir_low::macros::stmtp! { position => assume (new_variable == old_variable) }; - statements.push(statement); - } - new_map.set(variable, new_version); - } else { - new_map.set(variable, first_version); - } - } - self.snapshots_state.current_variables = Some(new_map); + self.snapshots_state.ssa_state.prepare_new_current_block( + label, + predecessors, + basic_block_edges, + ); + // let predecessor_labels = &predecessors[label]; + // let mut new_map = VariableVersionMap::default(); + // for variable in self.snapshots_state.all_variables.names_clone() { + // let predecessor_maps = predecessor_labels + // .iter() + // .map(|label| &self.snapshots_state.variables[label]) + // .collect::>(); + // let first_version = predecessor_maps[0].get_or_default(&variable); + // let different = predecessor_maps + // .iter() + // .any(|map| map.get_or_default(&variable) != first_version); + // if different { + // let new_version = self.snapshots_state.all_variables.new_version(&variable); + // let ty = self + // .snapshots_state + // .all_variables + // .get_type(&variable) + // .clone(); + // let new_variable = + // self.create_snapshot_variable_low(&variable, ty.clone(), new_version)?; + // for predecessor_label in predecessor_labels { + // let old_version = + // self.snapshots_state.variables[predecessor_label].get_or_default(&variable); + // let statements = basic_block_edges + // .entry(predecessor_label.clone()) + // .or_default() + // .entry(label.clone()) + // .or_default(); + // let old_variable = + // self.create_snapshot_variable_low(&variable, ty.clone(), old_version)?; + // let position = self.encoder.change_error_context( + // // FIXME: Get a more precise span. + // self.snapshots_state.all_variables.get_position(&variable), + // ErrorCtxt::Unexpected, + // ); + // let statement = vir_low::macros::stmtp! { position => assume (new_variable == old_variable) }; + // statements.push(statement); + // } + // new_map.set(variable, new_version); + // } else { + // new_map.set(variable, first_version); + // } + // } + // self.snapshots_state.current_variables = Some(new_map); Ok(()) } fn unset_current_block_for_snapshots( &mut self, label: vir_mid::BasicBlockId, ) -> SpannedEncodingResult<()> { - let current_variables = self.snapshots_state.current_variables.take().unwrap(); - assert!(self - .snapshots_state - .variables - .insert(label, current_variables) - .is_none()); + self.snapshots_state.ssa_state.finish_current_block(label); + // let current_variables = self.snapshots_state.current_variables.take().unwrap(); + // assert!(self + // .snapshots_state + // .variables + // .insert(label, current_variables) + // .is_none()); Ok(()) } fn save_old_label(&mut self, label: String) -> SpannedEncodingResult<()> { - let current_variables = self.snapshots_state.current_variables.clone().unwrap(); - assert!(self - .snapshots_state - .variables_at_label - .insert(label, current_variables) - .is_none()); + self.snapshots_state.ssa_state.save_state_at_label(label); + // let current_variables = self.snapshots_state.current_variables.clone().unwrap(); + // assert!(self + // .snapshots_state + // .variables_at_label + // .insert(label, current_variables) + // .is_none()); Ok(()) } } diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/mod.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/mod.rs index eb61a1f4135..ad92ebf902c 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/mod.rs @@ -3,5 +3,7 @@ mod interface; mod state; -pub(in super::super) use self::interface::SnapshotVariablesInterface; -pub(super) use self::state::{AllVariablesMap, VariableVersionMap}; +pub(in super::super) use self::{ + interface::SnapshotVariablesInterface, + state::{AllVariablesMap, VariableVersionMap}, +}; diff --git a/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/state.rs b/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/state.rs index f60795708a3..7b757ae4c3e 100644 --- a/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/state.rs +++ b/prusti-viper/src/encoder/middle/core_proof/snapshots/variables/state.rs @@ -1,11 +1,8 @@ use std::collections::BTreeMap; -use vir_crate::{ - low::{self as vir_low}, - middle::{self as vir_mid}, -}; +use vir_crate::low::{self as vir_low}; #[derive(Default, Clone)] -pub(in super::super) struct VariableVersionMap { +pub(in super::super::super) struct VariableVersionMap { /// Mapping from variable names to their versions. variable_versions: BTreeMap, } @@ -31,9 +28,9 @@ impl VariableVersionMap { } #[derive(Default)] -pub(in super::super) struct AllVariablesMap { +pub(in super::super::super) struct AllVariablesMap { versions: BTreeMap, - types: BTreeMap, + types: BTreeMap, positions: BTreeMap, } @@ -41,7 +38,7 @@ impl AllVariablesMap { pub(super) fn names_clone(&self) -> Vec { self.versions.keys().cloned().collect() } - pub(super) fn get_type(&self, variable: &str) -> &vir_mid::Type { + pub(super) fn get_type(&self, variable: &str) -> &vir_low::Type { &self.types[variable] } pub(super) fn get_position(&self, variable: &str) -> vir_low::Position { @@ -54,18 +51,18 @@ impl AllVariablesMap { } pub(super) fn new_version_or_default( &mut self, - variable: &vir_mid::VariableDecl, + variable: &str, + ty: &vir_low::Type, position: vir_low::Position, ) -> u64 { - if self.versions.contains_key(&variable.name) { - let version = self.versions.get_mut(&variable.name).unwrap(); + if self.versions.contains_key(variable) { + let version = self.versions.get_mut(variable).unwrap(); *version += 1; *version } else { - self.versions.insert(variable.name.clone(), 1); - self.types - .insert(variable.name.clone(), variable.ty.clone()); - self.positions.insert(variable.name.clone(), position); + self.versions.insert(variable.to_string(), 1); + self.types.insert(variable.to_string(), ty.clone()); + self.positions.insert(variable.to_string(), position); 1 } } diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/effects/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/effects/mod.rs new file mode 100644 index 00000000000..21cc51e6ba0 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/effects/mod.rs @@ -0,0 +1,541 @@ +use super::{ + permission_mask::{ + PermissionMaskKind, PermissionMaskKindAliasedBool, + PermissionMaskKindAliasedFractionalBoundedPerm, PermissionMaskOperations, + PredicatePermissionMaskKind, TPermissionMaskOperations, + }, + HeapEncoder, +}; +use crate::encoder::errors::SpannedEncodingResult; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, SyntacticEvaluation}, + low::{self as vir_low}, +}; + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + /// Note: this method assumes that it is called only as a top level assert + /// because it performs creating a new permission mask and rolling it back. + /// + /// Note: this method also evaluates accessibility predicates in + /// `expression_evaluation_state_label`. + pub(super) fn encode_expression_assert( + &mut self, + statements: &mut Vec, + expression: vir_low::Expression, + position: vir_low::Position, + expression_evaluation_state_label: Option, + ) -> SpannedEncodingResult<()> { + assert!(!position.is_default(), "expression: {}", expression); + if expression.is_pure() { + let expression = self.encode_pure_expression( + statements, + expression, + expression_evaluation_state_label, + position, + )?; + statements.push(vir_low::Statement::assert(expression, position)); + } else { + let check_point = self.fresh_label(); + self.ssa_state.save_state_at_label(check_point.clone()); + let evaluation_state = if let Some(label) = &expression_evaluation_state_label { + // This call is needed because we want to evaluate the + // accessibility predicates in the specified state. + self.ssa_state.change_state_to_label(label); + label + } else { + &check_point + }; + self.encode_expression_exhale(statements, expression, position, evaluation_state)?; + self.ssa_state.change_state_to_label(&check_point); + } + Ok(()) + } + + /// This method is similar to `encode_expression_assert` but it is intended + /// for asserting function preconditions. The key difference between + /// asserting function preconditions and regular assert statements is that + /// in function preconditions we ignore the permission amounts used in the + /// accessibility predicates: we only check that the permission amounts are + /// positive. + pub(super) fn encode_function_precondition_assert( + &mut self, + statements: &mut Vec, + expression: vir_low::Expression, + position: vir_low::Position, + expression_evaluation_state_label: Option, + ) -> SpannedEncodingResult<()> { + assert!(!position.is_default(), "expression: {}", expression); + if expression.is_pure() { + let expression = self.encode_pure_expression( + statements, + expression, + expression_evaluation_state_label, + position, + )?; + statements.push(vir_low::Statement::assert(expression, position)); + } else { + let check_point = self.fresh_label(); + self.ssa_state.save_state_at_label(check_point.clone()); + let evaluation_state = if let Some(label) = &expression_evaluation_state_label { + // This call is needed because we want to evaluate the + // accessibility predicates in the specified state. + self.ssa_state.change_state_to_label(label); + label + } else { + &check_point + }; + self.encode_function_precondition_assert_rec( + statements, + expression, + position, + evaluation_state, + )?; + self.ssa_state.change_state_to_label(&check_point); + } + Ok(()) + } + + pub(super) fn encode_function_precondition_assert_rec( + &mut self, + statements: &mut Vec, + expression: vir_low::Expression, + position: vir_low::Position, + expression_evaluation_state_label: &str, + ) -> SpannedEncodingResult<()> { + assert!(!position.is_default(), "expression: {}", expression); + if expression.is_pure() { + let expression = self.encode_pure_expression( + statements, + expression, + Some(expression_evaluation_state_label.to_string()), + position, + )?; + statements.push(vir_low::Statement::assert(expression, position)); + } else { + match expression { + vir_low::Expression::PredicateAccessPredicate(expression) => { + // FIXME: evaluate predicate arguments in expression_evaluation_state_label + match self.get_predicate_permission_mask_kind(&expression.name)? { + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBool => { + let operations = + PermissionMaskOperations::::new( + self, + statements, + &expression, + Some(expression_evaluation_state_label.to_string()), + position, + )?; + self.encode_function_precondition_assert_rec_predicate( + statements, + &expression, + position, + operations, + )? + } + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => { + let operations = PermissionMaskOperations::< + PermissionMaskKindAliasedFractionalBoundedPerm, + >::new( + self, + statements, + &expression, + Some(expression_evaluation_state_label.to_string()), + position, + )?; + self.encode_function_precondition_assert_rec_predicate( + statements, + &expression, + position, + operations, + )? + } + } + } + vir_low::Expression::Unfolding(_) => todo!(), + vir_low::Expression::LabelledOld(_) => todo!(), + vir_low::Expression::BinaryOp(expression) => match expression.op_kind { + vir_low::BinaryOpKind::And => { + self.encode_function_precondition_assert_rec( + statements, + *expression.left, + position, + expression_evaluation_state_label, + )?; + self.encode_function_precondition_assert_rec( + statements, + *expression.right, + position, + expression_evaluation_state_label, + )?; + } + vir_low::BinaryOpKind::Implies if expression.left.is_true() => { + self.encode_function_precondition_assert_rec( + statements, + *expression.right, + position, + expression_evaluation_state_label, + )?; + } + vir_low::BinaryOpKind::Implies => { + unimplemented!(); + } + _ => unreachable!("expression: {}", expression), + }, + vir_low::Expression::Conditional(_) => todo!(), + vir_low::Expression::FuncApp(_) => todo!(), + vir_low::Expression::DomainFuncApp(_) => todo!(), + _ => { + unimplemented!("expression: {:?}", expression); + } + } + } + Ok(()) + } + + fn encode_function_precondition_assert_rec_predicate( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + position: vir_low::Position, + operations: PermissionMaskOperations, + ) -> SpannedEncodingResult<()> + where + PermissionMaskOperations: TPermissionMaskOperations, + { + statements.push(vir_low::Statement::comment(format!( + "assert-function-precondition-predicate {}", + predicate + ))); + // assert perm

(r1, r2, v_old) >= p + statements.push(vir_low::Statement::assert( + operations.perm_old_positive(), + position, // FIXME: use position of expression.permission with proper ErrorCtxt. + )); + Ok(()) + } + + pub(super) fn encode_expression_exhale( + &mut self, + statements: &mut Vec, + expression: vir_low::Expression, + position: vir_low::Position, + expression_evaluation_state_label: &str, + ) -> SpannedEncodingResult<()> { + assert!(!position.is_default(), "expression: {}", expression); + if expression.is_pure() { + let expression = self.encode_pure_expression( + statements, + expression, + Some(expression_evaluation_state_label.to_string()), + position, + )?; + statements.push(vir_low::Statement::assert(expression, position)); + } else { + match expression { + vir_low::Expression::PredicateAccessPredicate(expression) => { + // FIXME: evaluate predicate arguments in expression_evaluation_state_label + match self.get_predicate_permission_mask_kind(&expression.name)? { + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBool => { + let operations = + PermissionMaskOperations::::new( + self, + statements, + &expression, + Some(expression_evaluation_state_label.to_string()), + position, + )?; + self.encode_expression_exhale_predicate( + statements, + &expression, + position, + Some(expression_evaluation_state_label.to_string()), + operations, + )? + } + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => { + let operations = PermissionMaskOperations::< + PermissionMaskKindAliasedFractionalBoundedPerm, + >::new( + self, + statements, + &expression, + Some(expression_evaluation_state_label.to_string()), + position, + )?; + self.encode_expression_exhale_predicate( + statements, + &expression, + position, + Some(expression_evaluation_state_label.to_string()), + operations, + )? + } + } + } + vir_low::Expression::Unfolding(_) => todo!(), + vir_low::Expression::LabelledOld(_) => todo!(), + vir_low::Expression::BinaryOp(expression) => match expression.op_kind { + vir_low::BinaryOpKind::And => { + self.encode_expression_exhale( + statements, + *expression.left, + position, + expression_evaluation_state_label, + )?; + self.encode_expression_exhale( + statements, + *expression.right, + position, + expression_evaluation_state_label, + )?; + } + vir_low::BinaryOpKind::Implies if expression.left.is_true() => { + self.encode_expression_exhale( + statements, + *expression.right, + position, + expression_evaluation_state_label, + )?; + } + vir_low::BinaryOpKind::Implies => { + unimplemented!("Merge the heap versions in the commented out code below."); + // let guard = self.encode_pure_expression( + // statements, + // *expression.left, + // Some(expression_evaluation_state_label.to_string()), + // position, + // )?; + // let mut body = Vec::new(); + // self.encode_expression_exhale( + // &mut body, + // *expression.right, + // position, + // expression_evaluation_state_label, + // )?; + // // FIXME: Permission mask and heap versions need to be + // // unified after the branch merge. + // statements.push(vir_low::Statement::conditional( + // guard, + // body, + // Vec::new(), + // position, + // )) + } + _ => unreachable!("expression: {}", expression), + }, + vir_low::Expression::Conditional(_) => todo!(), + vir_low::Expression::FuncApp(_) => todo!(), + vir_low::Expression::DomainFuncApp(_) => todo!(), + _ => { + unimplemented!("expression: {:?}", expression); + } + } + } + Ok(()) + } + + fn encode_expression_exhale_predicate( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + position: vir_low::Position, + expression_evaluation_state_label: Option, + operations: PermissionMaskOperations, + ) -> SpannedEncodingResult<()> + where + PermissionMaskOperations: TPermissionMaskOperations, + { + statements.push(vir_low::Statement::comment(format!( + "exhale-predicate {}", + predicate + ))); + // assert perm

(r1, r2, v_old) >= p + statements.push(vir_low::Statement::assert( + operations.perm_old_greater_equals(&predicate.permission), + position, // FIXME: use position of expression.permission with proper ErrorCtxt. + )); + let perm_new_value = operations.perm_old_sub(&predicate.permission); + // assume perm

(r1, r2, v_new) == perm

(r1, r2, v_old) - p + statements.push(vir_low::Statement::assume( + vir_low::Expression::equals(operations.perm_new(), perm_new_value.clone()), + position, // FIXME: use position of expression.permission with proper ErrorCtxt. + )); + // assume forall arg1: Ref, arg2: Ref :: + // {perm

(arg1, arg2, v_new)} + // !(r1 == arg1 && r2 == arg2) ==> + // perm

(arg1, arg2, v_new) == perm

(arg1, arg2, v_old) + self.encode_perm_unchanged_quantifier( + statements, + &predicate, + operations.old_permission_mask_version(), + operations.new_permission_mask_version(), + position, + expression_evaluation_state_label, + perm_new_value, + )?; + // assume forall arg1: Ref, arg2: Ref :: + // {heap

(arg1, arg2, vh_new)} + // perm

(arg1, arg2, v_new) > 0 ==> + // heap

(arg1, arg2, vh_new) == heap

(arg1, arg2, vh_old) + self.encode_heap_unchanged_quantifier( + statements, + &predicate, + operations.new_permission_mask_version(), + position, + )?; + Ok(()) + } + + pub(super) fn encode_expression_inhale( + &mut self, + statements: &mut Vec, + expression: vir_low::Expression, + position: vir_low::Position, + expression_evaluation_state_label: Option, + ) -> SpannedEncodingResult<()> { + if expression.is_pure() { + let expression = self.encode_pure_expression( + statements, + expression, + expression_evaluation_state_label, + position, + )?; + statements.push(vir_low::Statement::assume(expression, position)); + } else { + match expression { + vir_low::Expression::PredicateAccessPredicate(expression) => { + match self.get_predicate_permission_mask_kind(&expression.name)? { + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBool => { + let operations = + PermissionMaskOperations::::new( + self, + statements, + &expression, + expression_evaluation_state_label.clone(), + position, + )?; + self.encode_expression_inhale_predicate( + statements, + &expression, + position, + expression_evaluation_state_label, + operations, + )? + } + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => { + let operations = PermissionMaskOperations::< + PermissionMaskKindAliasedFractionalBoundedPerm, + >::new( + self, + statements, + &expression, + expression_evaluation_state_label.clone(), + position, + )?; + self.encode_expression_inhale_predicate( + statements, + &expression, + position, + expression_evaluation_state_label, + operations, + )? + } + } + } + vir_low::Expression::Unfolding(_) => todo!(), + vir_low::Expression::LabelledOld(_) => todo!(), + vir_low::Expression::BinaryOp(expression) => match expression.op_kind { + vir_low::BinaryOpKind::And => { + self.encode_expression_inhale( + statements, + *expression.left, + position, + expression_evaluation_state_label.clone(), + )?; + self.encode_expression_inhale( + statements, + *expression.right, + position, + expression_evaluation_state_label, + )?; + } + vir_low::BinaryOpKind::Implies => { + let guard = self.encode_pure_expression( + statements, + *expression.left, + expression_evaluation_state_label.clone(), + position, + )?; + let mut body = Vec::new(); + self.encode_expression_inhale( + &mut body, + *expression.right, + position, + expression_evaluation_state_label, + )?; + statements.push(vir_low::Statement::conditional( + guard, + body, + Vec::new(), + position, + )) + } + _ => unreachable!("expression: {}", expression), + }, + vir_low::Expression::Conditional(_) => todo!(), + vir_low::Expression::FuncApp(_) => todo!(), + vir_low::Expression::DomainFuncApp(_) => todo!(), + _ => { + unimplemented!("expression: {:?}", expression); + } + } + } + Ok(()) + } + + fn encode_expression_inhale_predicate( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + position: vir_low::Position, + expression_evaluation_state_label: Option, + operations: PermissionMaskOperations, + ) -> SpannedEncodingResult<()> + where + PermissionMaskOperations: TPermissionMaskOperations, + { + statements.push(vir_low::Statement::comment(format!( + "inhale-predicate {}", + predicate + ))); + if operations.can_assume_old_permission_is_none(&predicate.permission) { + statements.push(vir_low::Statement::assume( + operations.perm_old_equal_none(), + position, // FIXME: use position of expression.permission with proper ErrorCtxt. + )); + } + let perm_new_value = operations.perm_old_add(&predicate.permission); + // assume perm

(r1, r2, v_new) == perm

(r1, r2, v_old) + p + statements.push(vir_low::Statement::assume( + vir_low::Expression::equals(operations.perm_new(), perm_new_value.clone()), + position, // FIXME: use position of expression.permission with proper ErrorCtxt. + )); + // assume forall arg1: Ref, arg2: Ref :: + // {perm

(arg1, arg2, v_new)} + // !(r1 == arg1 && r2 == arg2) ==> + // perm

(arg1, arg2, v_new) == perm

(arg1, arg2, v_old) + self.encode_perm_unchanged_quantifier( + statements, + &predicate, + operations.old_permission_mask_version(), + operations.new_permission_mask_version(), + position, + expression_evaluation_state_label, + perm_new_value, + )?; + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/heap/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/heap/mod.rs new file mode 100644 index 00000000000..3fa3573c907 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/heap/mod.rs @@ -0,0 +1,152 @@ +use super::HeapEncoder; +use crate::encoder::errors::SpannedEncodingResult; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, QuantifierHelpers}, + low::{self as vir_low}, +}; + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + fn heap_version_type(&self) -> vir_low::Type { + vir_low::Type::domain("HeapVersion".to_string()) + } + + pub(super) fn heap_function_name(&self, predicate_name: &str) -> String { + format!("heap${}", predicate_name) + } + + pub(super) fn heap_call( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + mut arguments: Vec, + heap_version: vir_low::Expression, + ) -> SpannedEncodingResult> { + let call = + if let Some(snapshot_type) = self.get_snapshot_type_for_predicate(&predicate.name) { + let heap_function_name = self.heap_function_name(&predicate.name); + arguments.push(heap_version); + Some(vir_low::Expression::domain_function_call( + "HeapFunctions", + heap_function_name, + arguments, + snapshot_type, + )) + } else { + None + }; + Ok(call) + } + + pub(super) fn heap_call_for_predicate_def( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + heap_version: vir_low::Expression, + ) -> SpannedEncodingResult> { + let arguments = self.get_predicate_parameters_as_arguments(&predicate.name)?; + self.heap_call(predicate, arguments, heap_version) + } + + pub(super) fn encode_heap_unchanged_quantifier( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + new_permission_mask: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult<()> { + let heap_version_old = self.get_current_heap_version_for(&predicate.name)?; + if let Some(heap_old) = + self.heap_call_for_predicate_def(&predicate, heap_version_old.clone())? + { + let heap_version_new = self.get_new_heap_version_for(&predicate.name, position)?; + let heap_new = self + .heap_call_for_predicate_def(&predicate, heap_version_new.clone())? + .unwrap(); + let predicate_parameters = self.get_predicate_parameters(&predicate.name).to_owned(); + let triggers = vec![vir_low::Trigger::new(vec![heap_new.clone()])]; + let guard = self.positive_permission_mask_call_for_predicate_def( + predicate, + new_permission_mask.clone(), + )?; + let body = vir_low::Expression::implies( + guard, + vir_low::Expression::equals(heap_old, heap_new), + ); + statements.push(vir_low::Statement::assume( + vir_low::Expression::forall(predicate_parameters, triggers, body), + position, + )); + } + Ok(()) + } + + pub(super) fn get_current_heap_version_for( + &mut self, + predicate_name: &str, + ) -> SpannedEncodingResult { + let variable_name = self.heap_names.get(predicate_name).unwrap(); + let version = self.ssa_state.current_variable_version(variable_name); + let ty = self.heap_version_type(); + Ok(self + .new_variables + .create_variable(variable_name, ty, version)? + .into()) + } + + fn get_new_heap_version_for( + &mut self, + predicate_name: &str, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let variable_name = self.heap_names.get(predicate_name).unwrap(); + let ty = self.heap_version_type(); + let version = self + .ssa_state + .new_variable_version(variable_name, &ty, position); + Ok(self + .new_variables + .create_variable(variable_name, ty, version)? + .into()) + } + + pub(super) fn get_heap_version_at_label( + &mut self, + predicate_name: &str, + label: &str, + ) -> SpannedEncodingResult { + let variable_name = self.heap_names.get(predicate_name).unwrap(); + let version = self + .ssa_state + .variable_version_at_label(variable_name, label); + let ty = self.heap_version_type(); + Ok(self + .new_variables + .create_variable(variable_name, ty, version)? + .into()) + } + + pub(super) fn generate_heap_domains( + &self, + domains: &mut Vec, + ) -> SpannedEncodingResult<()> { + let heap_version_domain = vir_low::DomainDecl::new("HeapVersion", Vec::new(), Vec::new()); + domains.push(heap_version_domain); + let mut functions = Vec::new(); + for predicate in self.predicates.iter_decls() { + if let Some(snapshot_type) = self.get_snapshot_type_for_predicate(&predicate.name) { + let mut parameters = predicate.parameters.clone(); + parameters.push(vir_low::VariableDecl::new( + "version", + self.heap_version_type(), + )); + functions.push(vir_low::DomainFunctionDecl::new( + self.heap_function_name(&predicate.name), + false, + parameters, + snapshot_type, + )); + } + } + let heap_domain = vir_low::DomainDecl::new("HeapFunctions", functions, Vec::new()); + domains.push(heap_domain); + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/mod.rs new file mode 100644 index 00000000000..a5d65f6816c --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/mod.rs @@ -0,0 +1,153 @@ +mod statements; +mod pure_expressions; +mod heap; +mod effects; +mod predicates; +mod permission_mask; + +use self::predicates::Predicates; +use super::variable_declarations::VariableDeclarations; +use crate::encoder::{ + errors::{ErrorCtxt, SpannedEncodingError, SpannedEncodingResult}, + middle::core_proof::{ + lowerer::LoweringResult, + predicates::PredicateInfo, + snapshots::{AllVariablesMap, VariableVersionMap}, + }, + mir::errors::ErrorInterface, + Encoder, +}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::collections::{BTreeMap, BTreeSet}; +use vir_crate::{ + common::{ + cfg::Cfg, + check_mode::CheckMode, + expression::{ + BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers, UnaryOperationHelpers, + }, + graphviz::ToGraphviz, + }, + low::{ + self as vir_low, + expression::visitors::{ExpressionFallibleFolder, ExpressionFolder}, + operations::ty::Typed, + }, + middle as vir_mid, +}; + +pub(super) struct HeapEncoder<'p, 'v: 'p, 'tcx: 'v> { + encoder: &'p mut Encoder<'v, 'tcx>, + new_variables: VariableDeclarations, + predicates: Predicates<'p>, + functions: FxHashMap, + ssa_state: vir_low::ssa::SSAState, + permission_mask_names: FxHashMap, + heap_names: FxHashMap, + /// A counter used for generating fresh labels. + fresh_label_counter: u64, +} + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + pub(super) fn new( + encoder: &'p mut Encoder<'v, 'tcx>, + predicates: &'p [vir_low::PredicateDecl], + predicate_info: BTreeMap, + functions: &'p [vir_low::FunctionDecl], + ) -> Self { + Self { + encoder, + new_variables: Default::default(), + permission_mask_names: predicates + .iter() + .map(|predicate| { + let mask_name = format!("perm${}", predicate.name); + (predicate.name.clone(), mask_name) + }) + .collect(), + heap_names: predicates + .iter() + .map(|predicate| { + let heap_name = format!("heap${}", predicate.name); + (predicate.name.clone(), heap_name) + }) + .collect(), + predicates: Predicates::new(predicates, predicate_info), + functions: functions + .iter() + .map(|function| (function.name.clone(), function)) + .collect(), + ssa_state: Default::default(), + fresh_label_counter: 0, + } + } + + pub(super) fn encode_statement( + &mut self, + statements: &mut Vec, + statement: vir_low::Statement, + ) -> SpannedEncodingResult<()> { + self.encode_statement_internal(statements, statement) + } + + pub(super) fn prepare_new_current_block( + &mut self, + label: &vir_low::Label, + predecessors: &BTreeMap>, + basic_block_edges: &mut BTreeMap< + vir_low::Label, + BTreeMap>, + >, + ) -> SpannedEncodingResult<()> { + self.ssa_state + .prepare_new_current_block(label, predecessors, basic_block_edges); + Ok(()) + } + + pub(super) fn finish_current_block( + &mut self, + label: vir_low::Label, + ) -> SpannedEncodingResult<()> { + self.ssa_state.finish_current_block(label); + Ok(()) + } + + pub(super) fn generate_init_permissions_to_zero( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult> { + self.generate_init_permissions_to_zero_internal(position) + } + + pub(super) fn generate_necessary_domains( + &self, + ) -> SpannedEncodingResult> { + let mut domains = Vec::new(); + self.generate_permission_mask_domains(&mut domains)?; + self.generate_heap_domains(&mut domains)?; + Ok(domains) + } + + pub(super) fn create_variable( + &mut self, + variable_name: &str, + ty: vir_low::Type, + version: u64, + ) -> SpannedEncodingResult { + self.new_variables + .create_variable(variable_name, ty, version) + } + + pub(super) fn take_variables(&mut self) -> FxHashSet { + self.new_variables.take_variables() + } + + pub(super) fn encoder(&mut self) -> &mut Encoder<'v, 'tcx> { + &mut self.encoder + } + + fn fresh_label(&mut self) -> String { + self.fresh_label_counter += 1; + format!("heap_label${}", self.fresh_label_counter) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/permission_mask/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/permission_mask/mod.rs new file mode 100644 index 00000000000..f3be86d3316 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/permission_mask/mod.rs @@ -0,0 +1,323 @@ +mod operations; + +use super::HeapEncoder; +use crate::encoder::errors::SpannedEncodingResult; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers}, + low::{self as vir_low}, +}; + +pub(super) use self::operations::{ + PermissionMaskKind, PermissionMaskKindAliasedBool, + PermissionMaskKindAliasedFractionalBoundedPerm, PermissionMaskOperations, + TPermissionMaskOperations, +}; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(super) enum PredicatePermissionMaskKind { + /// The permission amounts can be either full or none. + AliasedWholeBool, + /// The permission amounts can be fractional, but we are always guaranteed + /// to operate on the same amount. Therefore, we do not need to perform + /// arithmetic operations on permissions and can use a boolean permission + /// mask with a third parameter that specifies the permission amount that we + /// are currently tracking. + AliasedFractionalBool, + /// The permission amounts can be fractional and we need to perform + /// arithmetic operations on them. However, the permission amount is bounded + /// by `write` and, therefore, when inhaling `write` we can assume that the + /// current amount is `none`. + AliasedFractionalBoundedPerm, +} + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + fn perm_version_type(&self) -> vir_low::Type { + vir_low::Type::domain("PermVersion".to_string()) + } + + fn mask_function_return_type(&self, kind: PredicatePermissionMaskKind) -> vir_low::Type { + match kind { + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBool => vir_low::Type::Bool, + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => vir_low::Type::Perm, + } + } + + fn no_permission(&self, kind: PredicatePermissionMaskKind) -> vir_low::Expression { + match kind { + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBool => false.into(), + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => { + vir_low::Expression::no_permission() + } + } + } + + fn permission_amount_parameter( + &self, + kind: PredicatePermissionMaskKind, + ) -> Option { + match kind { + PredicatePermissionMaskKind::AliasedFractionalBool => Some(vir_low::VariableDecl::new( + "permission_amount", + vir_low::Type::Perm, + )), + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => None, + } + } + + fn permission_mask_function_name(&self, predicate_name: &str) -> String { + format!("perm${}", predicate_name) + } + + pub(super) fn get_current_permission_mask_for( + &mut self, + predicate_name: &str, + ) -> SpannedEncodingResult { + let variable_name = self.permission_mask_names.get(predicate_name).unwrap(); + let version = self.ssa_state.current_variable_version(variable_name); + let ty = self.perm_version_type(); + Ok(self + .new_variables + .create_variable(variable_name, ty, version)? + .into()) + } + + pub(super) fn get_new_permission_mask_for( + &mut self, + predicate_name: &str, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let variable_name = self.permission_mask_names.get(predicate_name).unwrap(); + let ty = self.perm_version_type(); + let version = self + .ssa_state + .new_variable_version(variable_name, &ty, position); + Ok(self + .new_variables + .create_variable(variable_name, ty, version)? + .into()) + } + + pub(super) fn permission_mask_call( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + mut arguments: Vec, + permission_mask_version: vir_low::Expression, + ) -> SpannedEncodingResult { + let perm_function_name = self.permission_mask_function_name(&predicate.name); + arguments.push(permission_mask_version); + let kind = self.get_predicate_permission_mask_kind(&predicate.name)?; + if kind == PredicatePermissionMaskKind::AliasedFractionalBool { + arguments.push((*predicate.permission).clone()); + } + let return_type = self.mask_function_return_type(kind); + Ok(vir_low::Expression::domain_function_call( + "PermFunctions", + perm_function_name.clone(), + arguments, + return_type, + )) + } + + pub(super) fn permission_mask_call_for_predicate_use( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + permission_mask: vir_low::Expression, + expression_evaluation_state_label: Option, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let arguments = self.purify_predicate_arguments( + statements, + predicate, + expression_evaluation_state_label, + position, + )?; + self.permission_mask_call(predicate, arguments, permission_mask) + } + + pub(super) fn permission_mask_call_for_predicate_def( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + permission_mask: vir_low::Expression, + ) -> SpannedEncodingResult { + let arguments = self.get_predicate_parameters_as_arguments(&predicate.name)?; + self.permission_mask_call(predicate, arguments, permission_mask) + } + + pub(super) fn positive_permission_mask_call_for_predicate_def( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + permission_mask: vir_low::Expression, + ) -> SpannedEncodingResult { + let perm_call = self.permission_mask_call_for_predicate_def(predicate, permission_mask)?; + let kind = self.get_predicate_permission_mask_kind(&predicate.name)?; + let positivity_check = match kind { + PredicatePermissionMaskKind::AliasedWholeBool + | PredicatePermissionMaskKind::AliasedFractionalBool => perm_call, + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm => { + vir_low::Expression::greater_than(perm_call, vir_low::Expression::no_permission()) + } + }; + Ok(positivity_check) + } + + pub(super) fn encode_perm_unchanged_quantifier( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + old_permission_mask_version: vir_low::Expression, + new_permission_mask_version: vir_low::Expression, + position: vir_low::Position, + expression_evaluation_state_label: Option, + perm_new_value: vir_low::Expression, + ) -> SpannedEncodingResult<()> { + let perm_new = self.permission_mask_call_for_predicate_def( + &predicate, + new_permission_mask_version.clone(), + )?; + let perm_old = self.permission_mask_call_for_predicate_def( + &predicate, + old_permission_mask_version.clone(), + )?; + let predicate_parameters = self.get_predicate_parameters(&predicate.name).to_owned(); + let predicate_arguments = self.get_predicate_parameters_as_arguments(&predicate.name)?; + let arguments = self.purify_predicate_arguments( + statements, + predicate, + expression_evaluation_state_label, + position, + )?; + let triggers = vec![vir_low::Trigger::new(vec![perm_new.clone()])]; + let guard = predicate_arguments + .into_iter() + .zip(arguments) + .map(|(parameter, argument)| vir_low::Expression::equals(parameter, argument)) + .conjoin(); + let body = vir_low::Expression::equals( + perm_new, + vir_low::Expression::conditional_no_pos(guard, perm_new_value, perm_old), + ); + statements.push(vir_low::Statement::assume( + vir_low::Expression::forall(predicate_parameters, triggers, body), + position, + )); + Ok(()) + } + + pub(super) fn generate_permission_mask_domains( + &self, + domains: &mut Vec, + ) -> SpannedEncodingResult<()> { + let perm_version_domain = vir_low::DomainDecl::new("PermVersion", Vec::new(), Vec::new()); + domains.push(perm_version_domain); + let mut functions = Vec::new(); + let mut axioms = Vec::new(); + for predicate in self.predicates.iter_decls() { + let mut parameters = predicate.parameters.clone(); + parameters.push(vir_low::VariableDecl::new( + "version", + self.perm_version_type(), + )); + let function_name = self.permission_mask_function_name(&predicate.name); + let kind = self.get_predicate_permission_mask_kind(&predicate.name)?; + parameters.extend(self.permission_amount_parameter(kind)); + let return_type = self.mask_function_return_type(kind); + let function = vir_low::DomainFunctionDecl::new( + function_name.clone(), + false, + parameters.clone(), + return_type, + ); + functions.push(function); + if matches!( + kind, + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm + ) { + let function_call = vir_low::Expression::domain_func_app_no_pos( + "PermFunctions".to_string(), + function_name.clone(), + parameters + .clone() + .into_iter() + .map(|parameter| parameter.into()) + .collect(), + parameters.clone(), + vir_low::Type::Perm, + ); + use vir_low::macros::*; + let body = vir_low::Expression::forall( + parameters, + vec![vir_low::Trigger::new(vec![function_call.clone()])], + expr! { + ([vir_low::Expression::no_permission()] <= [function_call.clone()]) && + ([function_call] <= [vir_low::Expression::full_permission()]) + }, + ); + let axiom = + vir_low::DomainAxiomDecl::new(None, format!("{}$bounds", function_name), body); + axioms.push(axiom); + } + } + let perm_domain = vir_low::DomainDecl::new("PermFunctions", functions, axioms); + domains.push(perm_domain); + Ok(()) + } + + pub(super) fn generate_init_permissions_to_zero_internal( + &mut self, + position: vir_low::Position, + ) -> SpannedEncodingResult> { + let mut statements = Vec::new(); + for predicate in self.predicates.iter_decls() { + let initial_permission_mask_name = + self.permission_mask_names.get(&predicate.name).unwrap(); + let initial_permission_mask_version = self + .ssa_state + .initial_variable_version(initial_permission_mask_name); + let initial_permission_mask_ty = self.perm_version_type(); + let initial_permission_mask: vir_low::Expression = self + .new_variables + .create_variable( + initial_permission_mask_name, + initial_permission_mask_ty, + initial_permission_mask_version, + )? + .into(); + let kind = self.get_predicate_permission_mask_kind(&predicate.name)?; + let mut arguments: Vec<_> = predicate + .parameters + .iter() + .map(|parameter| parameter.clone().into()) + .collect(); + arguments.push(initial_permission_mask.clone()); + arguments.extend( + self.permission_amount_parameter(kind) + .map(|parameter| parameter.into()), + ); + // if matches!(kind, PredicatePermissionMaskKind::AliasedFractionalBool) { + // arguments.push(.into()); + // } + let perm_function_name = self.permission_mask_function_name(&predicate.name); + let return_type = self.mask_function_return_type(kind); + let perm = vir_low::Expression::domain_function_call( + "PermFunctions", + perm_function_name.clone(), + arguments, + return_type, + ); + let no_permission = self.no_permission(kind); + let triggers = vec![vir_low::Trigger::new(vec![perm.clone()])]; + let body = vir_low::Expression::equals(perm, no_permission); + let mut parameters = predicate.parameters.clone(); + parameters.extend(self.permission_amount_parameter(kind)); + statements.push(vir_low::Statement::assume( + vir_low::Expression::forall(parameters, triggers, body), + position, + )); + } + Ok(statements) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/permission_mask/operations.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/permission_mask/operations.rs new file mode 100644 index 00000000000..179c0c7c564 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/permission_mask/operations.rs @@ -0,0 +1,169 @@ +use crate::encoder::{ + errors::SpannedEncodingResult, + middle::core_proof::transformations::custom_heap_encoding::heap_encoder::HeapEncoder, +}; +use vir_crate::{ + common::expression::BinaryOperationHelpers, + low::{self as vir_low}, +}; + +pub(in super::super) trait PermissionMaskKind {} +pub(in super::super) struct PermissionMaskKindAliasedFractionalBoundedPerm {} +impl PermissionMaskKind for PermissionMaskKindAliasedFractionalBoundedPerm {} +pub(in super::super) struct PermissionMaskKindAliasedBool {} +impl PermissionMaskKind for PermissionMaskKindAliasedBool {} + +pub(in super::super) struct PermissionMaskOperations { + _kind: std::marker::PhantomData, + old_permission_mask_version: vir_low::Expression, + new_permission_mask_version: vir_low::Expression, + perm_old: vir_low::Expression, + perm_new: vir_low::Expression, +} + +impl PermissionMaskOperations { + pub(in super::super) fn new<'p, 'v: 'p, 'tcx: 'v>( + heap_encoder: &mut HeapEncoder<'p, 'v, 'tcx>, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + expression_evaluation_state_label: Option, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let old_permission_mask_version = + heap_encoder.get_current_permission_mask_for(&predicate.name)?; + let new_permission_mask_version = + heap_encoder.get_new_permission_mask_for(&predicate.name, position)?; + let perm_old = heap_encoder.permission_mask_call_for_predicate_use( + statements, + predicate, + old_permission_mask_version.clone(), + expression_evaluation_state_label.clone(), + position, + )?; + let perm_new = heap_encoder.permission_mask_call_for_predicate_use( + statements, + predicate, + new_permission_mask_version.clone(), + expression_evaluation_state_label, + position, + )?; + Ok(Self { + _kind: std::marker::PhantomData, + old_permission_mask_version, + new_permission_mask_version, + perm_old, + perm_new, + }) + } + + pub(in super::super) fn perm_new(&self) -> vir_low::Expression { + self.perm_new.clone() + } + + pub(in super::super) fn old_permission_mask_version(&self) -> vir_low::Expression { + self.old_permission_mask_version.clone() + } + + pub(in super::super) fn new_permission_mask_version(&self) -> vir_low::Expression { + self.new_permission_mask_version.clone() + } +} + +pub(in super::super) trait TPermissionMaskOperations { + fn perm_old_greater_equals( + &self, + permission_amount: &vir_low::Expression, + ) -> vir_low::Expression; + + fn perm_old_positive(&self) -> vir_low::Expression; + + fn perm_old_sub(&self, permission_amount: &vir_low::Expression) -> vir_low::Expression; + + fn perm_old_add(&self, permission_amount: &vir_low::Expression) -> vir_low::Expression; + + fn perm_old_equal_none(&self) -> vir_low::Expression; + + fn can_assume_old_permission_is_none(&self, permission_amount: &vir_low::Expression) -> bool; +} + +impl TPermissionMaskOperations + for PermissionMaskOperations +{ + fn perm_old_greater_equals( + &self, + permission_amount: &vir_low::Expression, + ) -> vir_low::Expression { + vir_low::Expression::greater_equals(self.perm_old.clone(), permission_amount.clone()) + } + + fn perm_old_positive(&self) -> vir_low::Expression { + vir_low::Expression::greater_equals( + self.perm_old.clone(), + vir_low::Expression::no_permission(), + ) + } + + fn perm_old_sub(&self, permission_amount: &vir_low::Expression) -> vir_low::Expression { + if permission_amount.is_full_permission() { + vir_low::Expression::no_permission() + } else { + vir_low::Expression::perm_binary_op_no_pos( + vir_low::ast::expression::PermBinaryOpKind::Sub, + self.perm_old.clone(), + permission_amount.clone(), + ) + } + } + + fn perm_old_add(&self, permission_amount: &vir_low::Expression) -> vir_low::Expression { + if permission_amount.is_full_permission() { + vir_low::Expression::full_permission() + } else { + vir_low::Expression::perm_binary_op_no_pos( + vir_low::ast::expression::PermBinaryOpKind::Add, + self.perm_old.clone(), + permission_amount.clone(), + ) + } + } + + fn perm_old_equal_none(&self) -> vir_low::Expression { + vir_low::Expression::equals(self.perm_old.clone(), vir_low::Expression::no_permission()) + } + + fn can_assume_old_permission_is_none(&self, permission_amount: &vir_low::Expression) -> bool { + permission_amount.is_full_permission() + } +} + +impl TPermissionMaskOperations for PermissionMaskOperations { + fn perm_old_greater_equals( + &self, + permission_amount: &vir_low::Expression, + ) -> vir_low::Expression { + assert!(permission_amount.is_full_permission()); + self.perm_old.clone() + } + + fn perm_old_positive(&self) -> vir_low::Expression { + self.perm_old.clone() + } + + fn perm_old_sub(&self, permission_amount: &vir_low::Expression) -> vir_low::Expression { + assert!(permission_amount.is_full_permission()); + false.into() + } + + fn perm_old_add(&self, permission_amount: &vir_low::Expression) -> vir_low::Expression { + assert!(permission_amount.is_full_permission()); + true.into() + } + + fn perm_old_equal_none(&self) -> vir_low::Expression { + vir_low::Expression::equals(self.perm_old.clone(), false.into()) + } + + fn can_assume_old_permission_is_none(&self, _: &vir_low::Expression) -> bool { + true + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/predicates.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/predicates.rs new file mode 100644 index 00000000000..6ebd5510a09 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/predicates.rs @@ -0,0 +1,159 @@ +use super::{permission_mask::PredicatePermissionMaskKind, HeapEncoder}; +use crate::encoder::{ + errors::SpannedEncodingResult, middle::core_proof::predicates::PredicateInfo, +}; +use rustc_hash::FxHashMap; +use std::collections::BTreeMap; +use vir_crate::low::{self as vir_low}; + +pub(super) struct Predicates<'p> { + predicate_decls: FxHashMap, + snapshot_functions_to_predicates: BTreeMap, + predicates_to_snapshot_types: BTreeMap, +} + +impl<'p> Predicates<'p> { + pub(super) fn new( + predicate_decls: &'p [vir_low::PredicateDecl], + predicate_info: BTreeMap, + ) -> Self { + let mut snapshot_functions_to_predicates = BTreeMap::new(); + let mut predicates_to_snapshot_types = BTreeMap::new(); + for ( + predicate_name, + PredicateInfo { + snapshot_function_name, + snapshot_type, + }, + ) in predicate_info + { + snapshot_functions_to_predicates.insert(snapshot_function_name, predicate_name.clone()); + predicates_to_snapshot_types.insert(predicate_name, snapshot_type); + } + Self { + predicate_decls: predicate_decls + .iter() + .map(|predicate| (predicate.name.clone(), predicate)) + .collect(), + snapshot_functions_to_predicates, + predicates_to_snapshot_types, + } + } + + pub(super) fn iter_decls(&self) -> impl Iterator { + self.predicate_decls.values().cloned() + } +} + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + pub(super) fn get_predicate_decl( + &self, + predicate_name: &str, + ) -> SpannedEncodingResult<&'p vir_low::PredicateDecl> { + let decl = self + .predicates + .predicate_decls + .get(predicate_name) + .cloned() + .unwrap(); + Ok(decl) + } + + pub(super) fn get_predicate_parameters( + &self, + predicate_name: &str, + ) -> &[vir_low::VariableDecl] { + self.predicates + .predicate_decls + .get(predicate_name) + .unwrap() + .parameters + .as_slice() + } + + pub(super) fn get_predicate_parameters_as_arguments( + &mut self, + predicate_name: &str, + ) -> SpannedEncodingResult> { + let predicate_parameters = self.get_predicate_parameters(predicate_name).to_owned(); + Ok(predicate_parameters + .iter() + .map(|parameter| parameter.clone().into()) + .collect()) + } + + pub(super) fn get_predicate_name_for_function<'a>( + &'a self, + function_name: &str, + ) -> SpannedEncodingResult { + let function = self.functions[function_name]; + let predicate_name = match function.kind { + vir_low::FunctionKind::MemoryBlockBytes => todo!(), + vir_low::FunctionKind::CallerFor => todo!(), + vir_low::FunctionKind::Snap => { + self.predicates.snapshot_functions_to_predicates[function_name].clone() + } + }; + Ok(predicate_name) + } + + pub(super) fn get_snapshot_type_for_predicate( + &self, + predicate_name: &str, + ) -> Option { + let predicate = self.predicates.predicate_decls[predicate_name]; + match predicate.kind { + vir_low::PredicateKind::MemoryBlock => { + use vir_low::macros::*; + Some(ty!(Bytes)) + } + vir_low::PredicateKind::Owned => Some( + self.predicates + .predicates_to_snapshot_types + .get(predicate_name) + .unwrap_or_else(|| unreachable!("predicate not found: {}", predicate_name)) + .clone(), + ), + vir_low::PredicateKind::WithoutSnapshotWhole + | vir_low::PredicateKind::WithoutSnapshotFrac => None, + } + } + + pub(super) fn purify_predicate_arguments( + &mut self, + statements: &mut Vec, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + expression_evaluation_state_label: Option, + position: vir_low::Position, + ) -> SpannedEncodingResult> { + let mut arguments = Vec::new(); + for argument in &predicate.arguments { + arguments.push(self.encode_pure_expression( + statements, + argument.clone(), + expression_evaluation_state_label.clone(), + position, + )?); + } + Ok(arguments) + } + + pub(super) fn get_predicate_permission_mask_kind( + &self, + predicate_name: &str, + ) -> SpannedEncodingResult { + let predicate_decl = self.get_predicate_decl(predicate_name)?; + let mask_kind = match predicate_decl.kind { + vir_low::PredicateKind::MemoryBlock | vir_low::PredicateKind::Owned => { + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm + } + vir_low::PredicateKind::WithoutSnapshotFrac => { + PredicatePermissionMaskKind::AliasedFractionalBoundedPerm + } + vir_low::PredicateKind::WithoutSnapshotWhole => { + PredicatePermissionMaskKind::AliasedWholeBool + } + }; + Ok(mask_kind) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/pure_expressions.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/pure_expressions.rs new file mode 100644 index 00000000000..99be1520d14 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/pure_expressions.rs @@ -0,0 +1,146 @@ +use super::HeapEncoder; +use crate::encoder::errors::{SpannedEncodingError, SpannedEncodingResult}; +use vir_crate::{ + common::expression::{BinaryOperationHelpers, ExpressionIterator, UnaryOperationHelpers}, + low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}, +}; + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + pub(super) fn encode_pure_expression( + &mut self, + statements: &mut Vec, + expression: vir_low::Expression, + expression_evaluation_state_label: Option, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let mut purifier = Purifier { + expression_evaluation_state_label, + heap_encoder: self, + statements, + path_condition: Vec::new(), + position, + }; + purifier.fallible_fold_expression(expression) + } +} + +struct Purifier<'e, 'p, 'v: 'p, 'tcx: 'v> { + /// The state in which we should evaluate the heap expressions. If `None`, + /// takes the latest heap. + expression_evaluation_state_label: Option, + heap_encoder: &'e mut HeapEncoder<'p, 'v, 'tcx>, + statements: &'e mut Vec, + path_condition: Vec, + position: vir_low::Position, +} + +impl<'e, 'p, 'v: 'p, 'tcx: 'v> ExpressionFallibleFolder for Purifier<'e, 'p, 'v, 'tcx> { + type Error = SpannedEncodingError; + + fn fallible_fold_func_app_enum( + &mut self, + func_app: vir_low::expression::FuncApp, + ) -> Result { + let function = self.heap_encoder.functions[&func_app.function_name]; + assert_eq!(function.parameters.len(), func_app.arguments.len()); + let mut arguments = func_app + .arguments + .into_iter() + .map(|argument| self.fallible_fold_expression(argument)) + .collect::, _>>()?; + let path_condition = self.path_condition.iter().cloned().conjoin(); + let replacements = function.parameters.iter().zip(arguments.iter()).collect(); + let pres = function + .pres + .iter() + .cloned() + .conjoin() + .substitute_variables(&replacements); + let pres = self.fallible_fold_expression(pres)?; + let assert_precondition = vir_low::Expression::implies(path_condition, pres); + self.heap_encoder.encode_function_precondition_assert( + self.statements, + assert_precondition, + self.position, + self.expression_evaluation_state_label.clone(), + )?; + match function.kind { + vir_low::FunctionKind::MemoryBlockBytes => todo!(), + vir_low::FunctionKind::CallerFor => todo!(), + vir_low::FunctionKind::Snap => { + let predicate_name = self + .heap_encoder + .get_predicate_name_for_function(&func_app.function_name)?; + let heap_version = if let Some(expression_evaluation_state_label) = + &self.expression_evaluation_state_label + { + self.heap_encoder.get_heap_version_at_label( + &predicate_name, + expression_evaluation_state_label, + )? + } else { + self.heap_encoder + .get_current_heap_version_for(&predicate_name)? + }; + arguments.push(heap_version); + let heap_function_name = self.heap_encoder.heap_function_name(&predicate_name); + let return_type = self + .heap_encoder + .get_snapshot_type_for_predicate(&predicate_name) + .unwrap(); + Ok(vir_low::Expression::domain_function_call( + "HeapFunctions", + heap_function_name, + arguments, + return_type, + )) + } + } + } + + fn fallible_fold_binary_op( + &mut self, + mut binary_op: vir_low::expression::BinaryOp, + ) -> Result { + binary_op.left = self.fallible_fold_expression_boxed(binary_op.left)?; + if binary_op.op_kind == vir_low::BinaryOpKind::Implies { + self.path_condition.push((*binary_op.left).clone()); + } + binary_op.right = self.fallible_fold_expression_boxed(binary_op.right)?; + if binary_op.op_kind == vir_low::BinaryOpKind::Implies { + self.path_condition.pop(); + } + Ok(binary_op) + } + + fn fallible_fold_conditional( + &mut self, + mut conditional: vir_low::expression::Conditional, + ) -> Result { + conditional.guard = self.fallible_fold_expression_boxed(conditional.guard)?; + self.path_condition.push((*conditional.guard).clone()); + conditional.then_expr = self.fallible_fold_expression_boxed(conditional.then_expr)?; + self.path_condition.pop(); + self.path_condition + .push(vir_low::Expression::not((*conditional.guard).clone())); + conditional.else_expr = self.fallible_fold_expression_boxed(conditional.else_expr)?; + self.path_condition.pop(); + Ok(conditional) + } + + fn fallible_fold_labelled_old_enum( + &mut self, + mut labelled_old: vir_low::expression::LabelledOld, + ) -> Result { + std::mem::swap( + &mut labelled_old.label, + &mut self.expression_evaluation_state_label, + ); + let mut labelled_old = self.fallible_fold_labelled_old(labelled_old)?; + std::mem::swap( + &mut labelled_old.label, + &mut self.expression_evaluation_state_label, + ); + Ok(vir_low::Expression::LabelledOld(labelled_old)) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/statements.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/statements.rs new file mode 100644 index 00000000000..51f3f49e8e9 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/statements.rs @@ -0,0 +1,73 @@ +use super::HeapEncoder; +use crate::encoder::errors::SpannedEncodingResult; +use vir_crate::low::{self as vir_low}; + +impl<'p, 'v: 'p, 'tcx: 'v> HeapEncoder<'p, 'v, 'tcx> { + pub(super) fn encode_statement_internal( + &mut self, + statements: &mut Vec, + statement: vir_low::Statement, + ) -> SpannedEncodingResult<()> { + match statement { + vir_low::Statement::Comment(_) + | vir_low::Statement::LogEvent(_) + | vir_low::Statement::Assign(_) => { + statements.push(statement); + } + vir_low::Statement::Label(statement) => { + self.ssa_state.save_state_at_label(statement.label.clone()); + statements.push(vir_low::Statement::Label(statement)); + } + vir_low::Statement::Assume(statement) => { + assert!(statement.expression.is_pure()); + let expression = self.encode_pure_expression( + statements, + statement.expression, + None, + statement.position, + )?; + statements.push(vir_low::Statement::assume(expression, statement.position)); + } + vir_low::Statement::Assert(statement) => { + self.encode_expression_assert( + statements, + statement.expression, + statement.position, + None, + )?; + } + vir_low::Statement::Inhale(statement) => { + statements.push(vir_low::Statement::comment(format!("{}", statement))); + self.encode_expression_inhale( + statements, + statement.expression, + statement.position, + None, + )?; + } + vir_low::Statement::Exhale(statement) => { + statements.push(vir_low::Statement::comment(format!("{}", statement))); + let evaluation_state = self.fresh_label(); + self.ssa_state.save_state_at_label(evaluation_state.clone()); + self.encode_expression_exhale( + statements, + statement.expression, + statement.position, + &evaluation_state, + )?; + } + vir_low::Statement::Fold(_) => todo!(), + vir_low::Statement::Unfold(_) => todo!(), + vir_low::Statement::ApplyMagicWand(_) => { + unimplemented!("magic wands are not supported yet"); + } + vir_low::Statement::MethodCall(statement) => { + unreachable!("method call: {}", statement); + } + vir_low::Statement::Conditional(conditional) => { + unreachable!("conditional: {}", conditional); + } + } + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/mod.rs new file mode 100644 index 00000000000..8505370527e --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/mod.rs @@ -0,0 +1,109 @@ +//! This module contains a custom heap encoding that can be used instead of the +//! Viper builtin heap encoding. This module depends on `ErrorManager` and, +//! therefore, has to be in the `prusti-viper` crate. + +mod heap_encoder; +mod variable_declarations; + +use self::heap_encoder::HeapEncoder; +use crate::encoder::{ + errors::{ErrorCtxt, SpannedEncodingResult}, + middle::core_proof::predicates::PredicateInfo, + mir::errors::ErrorInterface, + Encoder, +}; +use std::collections::BTreeMap; +use vir_crate::{ + common::cfg::Cfg, + low::{self as vir_low}, +}; + +pub(in super::super) fn custom_heap_encoding<'p, 'v: 'p, 'tcx: 'v>( + encoder: &'p mut Encoder<'v, 'tcx>, + program: &mut vir_low::Program, + predicate_info: BTreeMap, +) -> SpannedEncodingResult<()> { + let mut procedures = Vec::new(); + let mut heap_encoder = HeapEncoder::new( + encoder, + &program.predicates, + predicate_info, + &program.functions, + ); + for procedure in std::mem::take(&mut program.procedures) { + let procedure = custom_heap_encoding_for_procedure(&mut heap_encoder, procedure)?; + procedures.push(procedure); + } + program.procedures = procedures; + program + .domains + .extend(heap_encoder.generate_necessary_domains()?); + Ok(()) +} + +fn custom_heap_encoding_for_procedure<'p, 'v: 'p, 'tcx: 'v>( + heap_encoder: &mut HeapEncoder<'p, 'v, 'tcx>, + mut procedure: vir_low::ProcedureDecl, +) -> SpannedEncodingResult { + let predecessors = procedure.predecessors_owned(); + let traversal_order = procedure.get_topological_sort(); + let mut basic_block_edges = BTreeMap::new(); + for label in &traversal_order { + heap_encoder.prepare_new_current_block(&label, &predecessors, &mut basic_block_edges)?; + let mut statements = Vec::new(); + let block = procedure.basic_blocks.get_mut(label).unwrap(); + for statement in std::mem::take(&mut block.statements) { + heap_encoder.encode_statement(&mut statements, statement)?; + } + block.statements = statements; + heap_encoder.finish_current_block(label.clone())?; + } + for label in traversal_order { + if let Some(intermediate_blocks) = basic_block_edges.remove(&label) { + let mut block = procedure.basic_blocks.remove(&label).unwrap(); + for (successor_label, equalities) in intermediate_blocks { + let intermediate_block_label = vir_low::Label::new(format!( + "label__from__{}__to__{}", + label.name, successor_label.name + )); + block + .successor + .replace_label(&successor_label, intermediate_block_label.clone()); + let mut successor_statements = Vec::new(); + for (variable_name, ty, position, old_version, new_version) in equalities { + let new_variable = + heap_encoder.create_variable(&variable_name, ty.clone(), new_version)?; + let old_variable = + heap_encoder.create_variable(&variable_name, ty.clone(), old_version)?; + let position = heap_encoder.encoder().change_error_context( + // FIXME: Get a more precise span. + position, + ErrorCtxt::Unexpected, + ); + let statement = vir_low::macros::stmtp! { + position => assume (new_variable == old_variable) + }; + successor_statements.push(statement); + } + procedure.basic_blocks.insert( + intermediate_block_label, + vir_low::BasicBlock { + statements: successor_statements, + successor: vir_low::Successor::Goto(successor_label), + }, + ); + } + procedure.basic_blocks.insert(label, block); + } + } + let init_permissions_to_zero = + heap_encoder.generate_init_permissions_to_zero(procedure.position)?; + procedure.locals.extend(heap_encoder.take_variables()); + procedure + .basic_blocks + .get_mut(&procedure.entry) + .unwrap() + .statements + .splice(0..0, init_permissions_to_zero); + Ok(procedure) +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/variable_declarations.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/variable_declarations.rs new file mode 100644 index 00000000000..d60df85754b --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/variable_declarations.rs @@ -0,0 +1,50 @@ +use crate::encoder::{ + errors::{ErrorCtxt, SpannedEncodingError, SpannedEncodingResult}, + middle::core_proof::{ + lowerer::LoweringResult, + predicates::PredicateInfo, + snapshots::{AllVariablesMap, VariableVersionMap}, + }, + mir::errors::ErrorInterface, + Encoder, +}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::collections::{BTreeMap, BTreeSet}; +use vir_crate::{ + common::{ + cfg::Cfg, + check_mode::CheckMode, + expression::{ + BinaryOperationHelpers, ExpressionIterator, QuantifierHelpers, UnaryOperationHelpers, + }, + graphviz::ToGraphviz, + }, + low::{ + self as vir_low, + expression::visitors::{ExpressionFallibleFolder, ExpressionFolder}, + operations::ty::Typed, + }, + middle as vir_mid, +}; + +#[derive(Default)] +pub(super) struct VariableDeclarations { + variables: FxHashSet, +} + +impl VariableDeclarations { + pub(super) fn create_variable( + &mut self, + variable_name: &str, + ty: vir_low::Type, + version: u64, + ) -> SpannedEncodingResult { + let variable = vir_low::VariableDecl::new(format!("{}_{}", variable_name, version), ty); + self.variables.insert(variable.clone()); + Ok(variable) + } + + pub(super) fn take_variables(&mut self) -> FxHashSet { + std::mem::take(&mut self.variables) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/desugar_conditionals.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/desugar_conditionals.rs new file mode 100644 index 00000000000..222fff676d3 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/desugar_conditionals.rs @@ -0,0 +1,95 @@ +use vir_crate::{ + common::{ + expression::{ExpressionIterator, UnaryOperationHelpers}, + graphviz::ToGraphviz, + }, + low::{self as vir_low}, +}; + +pub(crate) fn desugar_conditionals( + source_filename: &str, + mut program: vir_low::Program, +) -> vir_low::Program { + let mut procedures = Vec::new(); + for procedure in std::mem::take(&mut program.procedures) { + let procedure = desugar_conditionals_in_procedure(procedure); + if prusti_common::config::dump_debug_info() { + prusti_common::report::log::report_with_writer( + "graphviz_method_vir_low_after_desugar_conditionals", + format!("{}.{}.dot", source_filename, procedure.name), + |writer| procedure.to_graphviz(writer).unwrap(), + ); + } + procedures.push(procedure); + } + program.procedures = procedures; + program +} + +fn new_label(prefix: &str, label_counter: &mut usize) -> vir_low::Label { + let label = format!("{}${}", prefix, label_counter); + *label_counter += 1; + vir_low::Label::new(label) +} + +fn desugar_conditionals_in_procedure( + mut procedure: vir_low::ProcedureDecl, +) -> vir_low::ProcedureDecl { + let mut work_queue: Vec<_> = procedure.basic_blocks.keys().cloned().collect(); + let mut label_counter = 0; + while let Some(current_label) = work_queue.pop() { + let block = procedure.basic_blocks.get_mut(¤t_label).unwrap(); + if let Some(conditional_position) = block + .statements + .iter() + .position(|statement| matches!(statement, vir_low::Statement::Conditional(_))) + { + let remaining_statements = block.statements.split_off(conditional_position + 1); + let vir_low::Statement::Conditional(conditional) = block.statements.pop().unwrap() else { + unreachable!(); + }; + let remaining_block_label = new_label("remaining_block_label", &mut label_counter); + let then_block_label = new_label("then_block_label", &mut label_counter); + let else_block_label = new_label("else_block_label", &mut label_counter); + let then_block = vir_low::BasicBlock { + statements: conditional.then_branch, + successor: vir_low::Successor::Goto(remaining_block_label.clone()), + }; + + let mut targets = vec![(conditional.guard.clone(), then_block_label.clone())]; + let negated_guard = vir_low::Expression::not(conditional.guard.clone()); + let else_block = if conditional.else_branch.is_empty() { + targets.push((negated_guard, remaining_block_label.clone())); + None + } else { + let else_block = vir_low::BasicBlock { + statements: conditional.else_branch, + successor: vir_low::Successor::Goto(remaining_block_label.clone()), + }; + targets.push((negated_guard, else_block_label.clone())); + Some(else_block) + }; + let new_successor = vir_low::Successor::GotoSwitch(targets); + let original_successor = std::mem::replace(&mut block.successor, new_successor); + let remaining_block = vir_low::BasicBlock { + statements: remaining_statements, + successor: original_successor, + }; + work_queue.push(remaining_block_label.clone()); + work_queue.push(then_block_label.clone()); + procedure + .basic_blocks + .insert(then_block_label.clone(), then_block); + if let Some(else_block) = else_block { + work_queue.push(else_block_label.clone()); + procedure + .basic_blocks + .insert(else_block_label.clone(), else_block); + } + procedure + .basic_blocks + .insert(remaining_block_label.clone(), remaining_block); + } + } + procedure +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/desugar_method_calls.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/desugar_method_calls.rs new file mode 100644 index 00000000000..5f47a1bb355 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/desugar_method_calls.rs @@ -0,0 +1,80 @@ +use rustc_hash::{FxHashMap, FxHashSet}; +use vir_crate::{ + common::{expression::ExpressionIterator, graphviz::ToGraphviz}, + low::{self as vir_low}, +}; + +pub(crate) fn desugar_method_calls( + source_filename: &str, + mut program: vir_low::Program, +) -> vir_low::Program { + let mut procedures = Vec::new(); + let methods: FxHashMap<_, _> = program + .methods + .iter() + .map(|procedure| (&procedure.name, procedure)) + .collect(); + for mut procedure in std::mem::take(&mut program.procedures) { + let mut label_counter = 0; + for block in procedure.basic_blocks.values_mut() { + let mut statements = Vec::new(); + for statement in std::mem::take(&mut block.statements) { + if let vir_low::Statement::MethodCall(statement) = statement { + statements.push(vir_low::Statement::comment(format!("{}", statement))); + let old_label = format!("method_call_label${}", label_counter); + procedure + .custom_labels + .push(vir_low::Label::new(old_label.clone())); + label_counter += 1; + statements.push(vir_low::Statement::label( + old_label.clone(), + statement.position, + )); + let method = methods[&statement.method_name]; + let arguments: Vec<_> = statement + .arguments + .iter() + .map(|argument| { + vir_low::Expression::labelled_old( + Some(old_label.clone()), + argument.clone(), + statement.position, + ) + }) + .collect(); + let mut replacements = method.parameters.iter().zip(arguments.iter()).collect(); + let assertion = method + .pres + .clone() + .into_iter() + .conjoin() + .substitute_variables(&replacements) + .remove_unnecessary_old(); + statements.push(vir_low::Statement::exhale(assertion, statement.position)); + replacements.extend(method.targets.iter().zip(statement.targets.iter())); + let assertion = method + .posts + .clone() + .into_iter() + .conjoin() + .substitute_variables(&replacements) + .remove_unnecessary_old(); + statements.push(vir_low::Statement::inhale(assertion, statement.position)); + } else { + statements.push(statement); + } + } + block.statements = statements; + } + if prusti_common::config::dump_debug_info() { + prusti_common::report::log::report_with_writer( + "graphviz_method_vir_low_after_desugar_method_calls", + format!("{}.{}.dot", source_filename, procedure.name), + |writer| procedure.to_graphviz(writer).unwrap(), + ); + } + procedures.push(procedure); + } + program.procedures = procedures; + program +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/inline_functions.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/inline_functions.rs index f7878f44769..d12464bd4e4 100644 --- a/prusti-viper/src/encoder/middle/core_proof/transformations/inline_functions.rs +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/inline_functions.rs @@ -1,20 +1,30 @@ use rustc_hash::FxHashMap; use vir_crate::{ - common::expression::{ExpressionIterator, UnaryOperationHelpers}, + common::{ + expression::{ExpressionIterator, UnaryOperationHelpers}, + graphviz::ToGraphviz, + }, low::{self as vir_low}, }; use vir_low::expression::visitors::ExpressionFolder; -pub(crate) fn inline_caller_for(program: &mut vir_low::Program) { +pub(crate) fn inline_caller_for(source_filename: &str, program: &mut vir_low::Program) { let caller_for_functions = program .functions .drain_filter(|function| function.kind == vir_low::FunctionKind::CallerFor) .map(|function| (function.name.clone(), function)) .collect(); for procedure in &mut program.procedures { - for block in &mut procedure.basic_blocks { + for block in procedure.basic_blocks.values_mut() { inline_in_statements(&mut block.statements, &caller_for_functions); } + if prusti_common::config::dump_debug_info() { + prusti_common::report::log::report_with_writer( + "graphviz_method_vir_low_after_inline_caller_for", + format!("{}.{}.dot", source_filename, procedure.name), + |writer| procedure.to_graphviz(writer).unwrap(), + ); + } } } @@ -49,6 +59,7 @@ fn inline_in_statements( .push(vir_low::Statement::Assert(statement)); } vir_low::Statement::Comment(_) + | vir_low::Statement::Label(_) | vir_low::Statement::LogEvent(_) | vir_low::Statement::Inhale(_) | vir_low::Statement::Exhale(_) diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/mod.rs index b678daa9e5f..ef684f072b2 100644 --- a/prusti-viper/src/encoder/middle/core_proof/transformations/mod.rs +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/mod.rs @@ -1,3 +1,7 @@ pub(super) mod inline_functions; pub(super) mod remove_predicates; pub(super) mod remove_unvisited_blocks; +pub(super) mod custom_heap_encoding; +pub(super) mod desugar_method_calls; +pub(super) mod desugar_conditionals; +pub(super) mod symbolic_execution; diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/remove_predicates.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/remove_predicates.rs index b53badeb770..bf9741aa204 100644 --- a/prusti-viper/src/encoder/middle/core_proof/transformations/remove_predicates.rs +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/remove_predicates.rs @@ -1,5 +1,5 @@ use rustc_hash::{FxHashMap, FxHashSet}; -use vir_crate::low as vir_low; +use vir_crate::low::{self as vir_low, expression::visitors::default_fold_func_app}; use vir_low::expression::visitors::ExpressionFolder; pub(in super::super) fn remove_predicates( @@ -22,7 +22,7 @@ fn from_procedure( removed_functions: &FxHashSet, predicates: &FxHashMap, ) { - for block in &mut procedure.basic_blocks { + for block in procedure.basic_blocks.values_mut() { from_statements( &mut block.statements, removed_methods, @@ -63,6 +63,7 @@ fn from_statements( for statement in std::mem::take(statements) { match statement { vir_low::Statement::Comment(_) + | vir_low::Statement::Label(_) | vir_low::Statement::LogEvent(_) | vir_low::Statement::Assume(_) | vir_low::Statement::Assert(_) @@ -149,7 +150,7 @@ impl<'a> ExpressionFolder for PredicateRemover<'a> { if self.removed_functions.contains(&func_app.function_name) { self.drop_parent_binary_op = true; } - func_app + default_fold_func_app(self, func_app) } fn fold_binary_op_enum( &mut self, @@ -163,6 +164,12 @@ impl<'a> ExpressionFolder for PredicateRemover<'a> { vir_low::Expression::BinaryOp(binary_op) } } + fn fold_unfolding_enum( + &mut self, + unfolding: vir_low::expression::Unfolding, + ) -> vir_low::Expression { + self.fold_expression(*unfolding.base) + } } struct PredicateInliner<'a> { diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/remove_unvisited_blocks.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/remove_unvisited_blocks.rs index 90964cf6507..1ce12e1dbcc 100644 --- a/prusti-viper/src/encoder/middle/core_proof/transformations/remove_unvisited_blocks.rs +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/remove_unvisited_blocks.rs @@ -10,8 +10,8 @@ pub(in super::super) fn remove_unvisited_blocks( label_markers: &FxHashMap, ) -> SpannedEncodingResult<()> { for procedure in procedures { - for block in &mut procedure.basic_blocks { - if !label_markers.get(&block.label.name).unwrap_or(&true) { + for (label, block) in &mut procedure.basic_blocks { + if !label_markers.get(&label.name).unwrap_or(&true) { // The block was not visited. Replace with assume false. let mut position = None; for statement in &block.statements { diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/egg/language.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/egg/language.rs new file mode 100644 index 00000000000..dcba467b014 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/egg/language.rs @@ -0,0 +1,27 @@ +use egg::{define_language, Id, Symbol}; + +define_language! { + pub(super) enum ExpressionLanguage { + "true" = True, + "false" = False, + "==" = EqCmp([Id; 2]), + "!=" = NeCmp([Id; 2]), + ">" = GtCmp([Id; 2]), + ">=" = GeCmp([Id; 2]), + "<=" = LtCmp([Id; 2]), + "<" = LeCmp([Id; 2]), + "+" = Add([Id; 2]), + "-" = Sub([Id; 2]), + "*" = Mul([Id; 2]), + "/" = Div([Id; 2]), + "%" = Mod([Id; 2]), + "&&" = And([Id; 2]), + "||" = Or([Id; 2]), + "==>" = Implies([Id; 2]), + "!" = Not(Id), + "neg" = Minus(Id), + Int(i64), + Variable(Symbol), + FuncApp(Symbol, Vec), + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/egg/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/egg/mod.rs new file mode 100644 index 00000000000..73cdb0bfb1d --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/egg/mod.rs @@ -0,0 +1,298 @@ +use self::language::ExpressionLanguage; +use super::VerificationResult; +use egg::{EGraph, Id, Symbol}; +use rustc_hash::FxHashMap; +use vir_crate::low::{self as vir_low}; + +mod language; + +#[derive(Clone)] +pub(super) struct EGraphState { + egraph: EGraph, + simplification_rules: Vec>, + false_id: Id, + true_id: Id, + interned_terms: FxHashMap, +} + +impl EGraphState { + pub(super) fn new() -> Self { + let mut egraph = EGraph::default(); + let true_id = egraph.add(ExpressionLanguage::True); + let false_id = egraph.add(ExpressionLanguage::False); + let rule = { + let place_var: egg::Var = "?place".parse().unwrap(); + let address_var: egg::Var = "?address".parse().unwrap(); + let mut pattern: egg::RecExpr> = + egg::RecExpr::default(); + let place = pattern.add(egg::ENodeOrVar::Var(place_var)); + let address = pattern.add(egg::ENodeOrVar::Var(address_var)); + pattern.add(egg::ENodeOrVar::ENode(ExpressionLanguage::FuncApp( + Symbol::from("compute_address"), + vec![place, address], + ))); + let match_pattern = egg::Pattern::new(pattern); + let mut pattern: egg::RecExpr> = + egg::RecExpr::default(); + pattern.add(egg::ENodeOrVar::Var(address_var)); + let target_pattern = egg::Pattern::new(pattern); + egg::rewrite!("evaluate_compute_address"; match_pattern => target_pattern) + }; + let simplification_rules = vec![rule]; + Self { + egraph, + simplification_rules, + true_id, + false_id, + interned_terms: Default::default(), + } + } + + pub(super) fn assume_heap_independent_conjuncts( + &mut self, + expression: &vir_low::Expression, + ) -> VerificationResult<()> { + if let vir_low::Expression::BinaryOp(binary_expression) = expression { + match binary_expression.op_kind { + vir_low::BinaryOpKind::EqCmp => { + if expression.is_heap_independent() { + self.assume_equal(&binary_expression.left, &binary_expression.right)?; + return Ok(()); + } + } + vir_low::BinaryOpKind::And => { + self.assume_heap_independent_conjuncts(&binary_expression.left)?; + self.assume_heap_independent_conjuncts(&binary_expression.right)?; + return Ok(()); + } + _ => {} + } + } + if expression.is_heap_independent() { + self.assume(expression)?; + } + Ok(()) + } + + pub(super) fn assume(&mut self, term: &vir_low::Expression) -> VerificationResult<()> { + let term_id = self.intern_term(term)?; + self.egraph.union(term_id, self.true_id); + Ok(()) + } + + pub(super) fn assume_equal( + &mut self, + left: &vir_low::Expression, + right: &vir_low::Expression, + ) -> VerificationResult<()> { + let left_id = self.intern_term(left)?; + let right_id = self.intern_term(right)?; + self.egraph.union(left_id, right_id); + Ok(()) + } + + /// If the graph was modified, `sature` must be called before `is_equal` can + /// be used. + pub(super) fn is_equal( + &self, + left: &vir_low::Expression, + right: &vir_low::Expression, + ) -> VerificationResult { + let left_id = self.lookup_term(left)?; + let right_id = self.lookup_term(right)?; + Ok(self.egraph.find(left_id) == self.egraph.find(right_id)) + } + + pub(super) fn saturate(&mut self) -> VerificationResult<()> { + self.egraph.rebuild(); + let runner: egg::Runner<_, _, ()> = egg::Runner::new(Default::default()) + .with_egraph(std::mem::take(&mut self.egraph)) + .run(&self.simplification_rules); + assert!(matches!( + runner.stop_reason.unwrap(), + egg::StopReason::Saturated + )); + self.egraph = runner.egraph; + Ok(()) + } + + // pub(super) fn is_equal( + // &mut self, + // left: &vir_low::Expression, + // right: &vir_low::Expression, + // ) -> VerificationResult { + // // return Ok(left == right); + // // if left == right { + // // return Ok(true); + // // } + // // eprintln!("is_equal({}, {}): {}", left, right, self.counter); + // let left_id = self.intern_term(left)?; + // let right_id = self.intern_term(right)?; + // // return Ok(left_id == right_id); + // self.egraph.rebuild(); + // let counter = self.counter; + // self.egraph + // .dot() + // .to_dot(format!("/tmp/egraph-before-{}.dot", counter)) + // .unwrap(); + // self.counter += 1; + // // for rule in &self.simplification_rules { + // // eprintln!(" rule: {}", rule.name); + // // let matches = rule.search(&self.egraph); + // // for m in &matches { + // // eprintln!(" match: {:?}", m); + // // } + // // for id in rule.apply(&mut self.egraph, &matches) { + // // eprintln!(" applied: {:?}", id); + // // } + // // } + // // self.egraph.dot().to_dot(format!("/tmp/egraph-after-{}.dot", counter)).unwrap(); + // // self.egraph.rebuild(); + // // self.egraph.dot().to_dot(format!("/tmp/egraph-rebuild-{}.dot", counter)).unwrap(); + // let runner: egg::Runner<_, _, ()> = egg::Runner::new(Default::default()) + // .with_egraph(std::mem::take(&mut self.egraph)) + // // .with_hook(move |runner| { + // // runner.egraph.dot().to_dot(format!("/tmp/egraph{}.dot", counter)).unwrap(); + // // Ok(())}) + // .run(&self.simplification_rules); + // assert!(matches!( + // runner.stop_reason.unwrap(), + // egg::StopReason::Saturated + // )); + // self.egraph = runner.egraph; + // self.egraph + // .dot() + // .to_dot(format!("/tmp/egraph-after2-{}.dot", counter)) + // .unwrap(); + // Ok(self.egraph.find(left_id) == self.egraph.find(right_id)) + // } + + pub(super) fn is_inconsistent(&mut self) -> VerificationResult { + self.egraph.rebuild(); + Ok(self.egraph.find(self.true_id) == self.egraph.find(self.false_id)) + } + + /// Lookup the id of a previously interned term. + fn lookup_term(&self, term: &vir_low::Expression) -> VerificationResult { + Ok(*self.interned_terms.get(term).unwrap_or_else(|| { + panic!("term {} is not interned", term); + })) + } + + pub(super) fn intern_term(&mut self, term: &vir_low::Expression) -> VerificationResult { + if let Some(id) = self.interned_terms.get(term) { + Ok(*id) + } else { + assert!(term.is_heap_independent(), "{} is heap independent", term); + let id = self.intern_term_rec(term)?; + self.interned_terms.insert(term.clone(), id); + Ok(id) + } + } + + /// This method must be called only through `intern_term` that checks its + /// precondition. + fn intern_term_rec(&mut self, term: &vir_low::Expression) -> VerificationResult { + let id = match term { + vir_low::Expression::Local(expression) => { + let symbol = Symbol::from(&expression.variable.name); + self.egraph.add(ExpressionLanguage::Variable(symbol)) + } + vir_low::Expression::Constant(expression) => match expression.value { + vir_low::ConstantValue::Bool(true) => self.true_id, + vir_low::ConstantValue::Bool(false) => self.false_id, + vir_low::ConstantValue::Int(value) => { + self.egraph.add(ExpressionLanguage::Int(value)) + } + vir_low::ConstantValue::BigInt(_) => todo!(), + }, + vir_low::Expression::UnaryOp(expression) => { + let operand_id = self.intern_term_rec(&expression.argument)?; + match expression.op_kind { + vir_low::UnaryOpKind::Not => { + self.egraph.add(ExpressionLanguage::Not(operand_id)) + } + vir_low::UnaryOpKind::Minus => { + self.egraph.add(ExpressionLanguage::Minus(operand_id)) + } + } + } + vir_low::Expression::BinaryOp(expression) => { + let left_id = self.intern_term_rec(&expression.left)?; + let right_id = self.intern_term_rec(&expression.right)?; + match expression.op_kind { + vir_low::BinaryOpKind::EqCmp => self + .egraph + .add(ExpressionLanguage::EqCmp([left_id, right_id])), + vir_low::BinaryOpKind::NeCmp => self + .egraph + .add(ExpressionLanguage::NeCmp([left_id, right_id])), + vir_low::BinaryOpKind::GtCmp => self + .egraph + .add(ExpressionLanguage::GtCmp([left_id, right_id])), + vir_low::BinaryOpKind::GeCmp => self + .egraph + .add(ExpressionLanguage::GeCmp([left_id, right_id])), + vir_low::BinaryOpKind::LtCmp => self + .egraph + .add(ExpressionLanguage::LtCmp([left_id, right_id])), + vir_low::BinaryOpKind::LeCmp => self + .egraph + .add(ExpressionLanguage::LeCmp([left_id, right_id])), + vir_low::BinaryOpKind::Add => self + .egraph + .add(ExpressionLanguage::Add([left_id, right_id])), + vir_low::BinaryOpKind::Sub => self + .egraph + .add(ExpressionLanguage::Sub([left_id, right_id])), + vir_low::BinaryOpKind::Mul => self + .egraph + .add(ExpressionLanguage::Mul([left_id, right_id])), + vir_low::BinaryOpKind::Div => self + .egraph + .add(ExpressionLanguage::Div([left_id, right_id])), + vir_low::BinaryOpKind::Mod => self + .egraph + .add(ExpressionLanguage::Mod([left_id, right_id])), + vir_low::BinaryOpKind::And => self + .egraph + .add(ExpressionLanguage::And([left_id, right_id])), + vir_low::BinaryOpKind::Or => { + self.egraph.add(ExpressionLanguage::Or([left_id, right_id])) + } + vir_low::BinaryOpKind::Implies => self + .egraph + .add(ExpressionLanguage::Implies([left_id, right_id])), + } + } + vir_low::Expression::PermBinaryOp(expression) => todo!("expresion: {}", expression), + vir_low::Expression::ContainerOp(expression) => todo!("expresion: {}", expression), + vir_low::Expression::DomainFuncApp(expression) => { + let symbol = Symbol::from(&expression.function_name); + let arguments = expression + .arguments + .iter() + .map(|argument| self.intern_term_rec(argument)) + .collect::>>()?; + self.egraph + .add(ExpressionLanguage::FuncApp(symbol, arguments)) + } + vir_low::Expression::LabelledOld(expression) => { + self.intern_term_rec(&expression.base)? + } + vir_low::Expression::MagicWand(_) + | vir_low::Expression::PredicateAccessPredicate(_) + | vir_low::Expression::FieldAccessPredicate(_) + | vir_low::Expression::Unfolding(_) + | vir_low::Expression::Conditional(_) + | vir_low::Expression::Quantifier(_) + | vir_low::Expression::LetExpr(_) + | vir_low::Expression::FuncApp(_) + | vir_low::Expression::InhaleExhale(_) + | vir_low::Expression::Field(_) => { + unreachable!("term: {}", term); + } + }; + Ok(id) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/entry.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/entry.rs new file mode 100644 index 00000000000..b4d38d87f9c --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/entry.rs @@ -0,0 +1,41 @@ +use vir_crate::low::{self as vir_low}; + +pub(in super::super) enum HeapEntry { + Comment(vir_low::ast::statement::Comment), + Label(vir_low::ast::statement::Label), + /// An inhale entry that can be purified. + InhalePredicate( + vir_low::ast::expression::PredicateAccessPredicate, + vir_low::Position, + ), + /// An exhale entry that can be purified. + ExhalePredicate( + vir_low::ast::expression::PredicateAccessPredicate, + vir_low::Position, + ), + /// A generic inhale entry that cannot be purified. + InhaleGeneric(vir_low::ast::statement::Inhale), + /// A generic exhale entry that cannot be purified. + ExhaleGeneric(vir_low::ast::statement::Exhale), + Assume(vir_low::ast::statement::Assume), + Assert(vir_low::ast::statement::Assert), +} + +impl std::fmt::Display for HeapEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HeapEntry::Comment(statement) => write!(f, "{}", statement), + HeapEntry::Label(statement) => write!(f, "{}", statement), + HeapEntry::InhalePredicate(predicate, _position) => { + write!(f, "inhale-predicate {}", predicate) + } + HeapEntry::ExhalePredicate(predicate, _position) => { + write!(f, "exhale-predicate {}", predicate) + } + HeapEntry::InhaleGeneric(statement) => write!(f, "{}", statement), + HeapEntry::ExhaleGeneric(statement) => write!(f, "{}", statement), + HeapEntry::Assume(statement) => write!(f, "{}", statement), + HeapEntry::Assert(statement) => write!(f, "{}", statement), + } + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/finalizer.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/finalizer.rs new file mode 100644 index 00000000000..f037543d82d --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/finalizer.rs @@ -0,0 +1,248 @@ +use super::{ + predicate_snapshots::PredicateSnapshots, state::PredicateInstanceState, HeapEntry, HeapState, + Location, +}; +use crate::encoder::middle::core_proof::transformations::symbolic_execution::{ + egg::EGraphState, program_context::ProgramContext, trace::Trace, VerificationResult, +}; +use std::collections::BTreeMap; +use vir_crate::low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}; + +pub(super) struct TraceFinalizer<'a> { + final_state: &'a HeapState, + trace: Vec, + new_variables: Vec, + new_labels: Vec, + predicate_snapshots: PredicateSnapshots, + predicate_snapshots_at_label: BTreeMap, + solver: &'a EGraphState, + program: &'a ProgramContext<'a>, +} + +impl<'a> TraceFinalizer<'a> { + pub(super) fn new( + final_state: &'a HeapState, + solver: &'a EGraphState, + program: &'a ProgramContext<'a>, + ) -> Self { + Self { + final_state, + trace: Vec::new(), + new_variables: Vec::new(), + new_labels: Vec::new(), + predicate_snapshots: Default::default(), + predicate_snapshots_at_label: Default::default(), + solver, + program, + } + } + + pub(super) fn into_trace(self) -> Trace { + let mut variables = self.new_variables; + variables.extend(self.predicate_snapshots.into_variables()); + Trace { + statements: self.trace, + variables, + labels: self.new_labels, + } + } + + pub(super) fn add_variables( + &mut self, + new_variables: &[vir_low::VariableDecl], + ) -> VerificationResult<()> { + self.new_variables.extend_from_slice(new_variables); + Ok(()) + } + + pub(super) fn add_labels(&mut self, new_labels: &[vir_low::Label]) -> VerificationResult<()> { + self.new_labels.extend_from_slice(new_labels); + Ok(()) + } + + pub(super) fn add_entry( + &mut self, + location: Location, + entry: &HeapEntry, + ) -> VerificationResult<()> { + match entry { + HeapEntry::Comment(statement) => { + self.trace + .push(vir_low::Statement::Comment(statement.clone())); + } + HeapEntry::Label(statement) => { + self.save_state(statement.label.clone()); + self.trace + .push(vir_low::Statement::Label(statement.clone())); + } + HeapEntry::InhalePredicate(predicate, position) => { + if self.is_purified_inhale(location, predicate) { + self.trace.push(vir_low::Statement::comment(format!( + "purified out: {}", + entry + ))); + self.predicate_snapshots + .create_predicate_snapshot(self.program, predicate); + } else { + self.trace.push(vir_low::Statement::inhale( + vir_low::Expression::PredicateAccessPredicate(predicate.clone()), + *position, + )); + } + } + HeapEntry::ExhalePredicate(predicate, position) => { + if self.is_purified_exhale(location, predicate) { + self.trace.push(vir_low::Statement::comment(format!( + "purified out: {}", + entry + ))); + self.predicate_snapshots + .destroy_predicate_snapshot(predicate, self.solver)?; + } else { + self.trace.push(vir_low::Statement::exhale( + vir_low::Expression::PredicateAccessPredicate(predicate.clone()), + *position, + )); + } + } + HeapEntry::InhaleGeneric(statement) => { + let mut statement = statement.clone(); + statement.expression = self.purify_snap_calls(statement.expression)?; + self.trace.push(vir_low::Statement::Inhale(statement)); + } + HeapEntry::ExhaleGeneric(statement) => { + let mut statement = statement.clone(); + statement.expression = self.purify_snap_calls(statement.expression)?; + self.trace.push(vir_low::Statement::Exhale(statement)); + } + HeapEntry::Assume(statement) => { + let mut statement = statement.clone(); + statement.expression = self.purify_snap_calls(statement.expression)?; + self.trace.push(vir_low::Statement::Assume(statement)); + } + HeapEntry::Assert(statement) => { + let mut statement = statement.clone(); + statement.expression = self.purify_snap_calls(statement.expression)?; + self.trace.push(vir_low::Statement::Assert(statement)); + } + } + Ok(()) + } + + fn is_purified_inhale( + &self, + location: Location, + predicate: &vir_low::expression::PredicateAccessPredicate, + ) -> bool { + if let Some(predicate_state) = self.final_state.get_predicate(&predicate.name) { + for predicate_instance in predicate_state.get_instances() { + if predicate_instance.inhale_location == location { + if let PredicateInstanceState::Exhaled(_) = predicate_instance.state { + return true; + } + } + } + } + false + } + + fn is_purified_exhale( + &self, + location: Location, + predicate: &vir_low::expression::PredicateAccessPredicate, + ) -> bool { + if let Some(predicate_state) = self.final_state.get_predicate(&predicate.name) { + for predicate_instance in predicate_state.get_instances() { + if let PredicateInstanceState::Exhaled(exhale_location) = predicate_instance.state { + if exhale_location == location { + return true; + } + } + } + } + false + } + + fn save_state(&mut self, label: String) { + assert!(self + .predicate_snapshots_at_label + .insert(label, self.predicate_snapshots.clone()) + .is_none()); + } + + fn purify_snap_calls( + &mut self, + expression: vir_low::Expression, + ) -> VerificationResult { + struct Purifier<'a> { + predicate_snapshots: &'a PredicateSnapshots, + predicate_snapshots_at_label: &'a BTreeMap, + solver: &'a EGraphState, + program: &'a ProgramContext<'a>, + label: Option, + } + impl<'a> ExpressionFallibleFolder for Purifier<'a> { + type Error = super::super::Error; + + fn fallible_fold_func_app_enum( + &mut self, + func_app: vir_low::expression::FuncApp, + ) -> Result { + let func_app = self.fallible_fold_func_app(func_app)?; + let function = self.program.get_function(&func_app.function_name); + assert_eq!(function.parameters.len(), func_app.arguments.len()); + match function.kind { + vir_low::FunctionKind::MemoryBlockBytes => todo!(), + vir_low::FunctionKind::CallerFor => todo!(), + vir_low::FunctionKind::Snap => { + if let Some(snapshot_variable) = + self.resolve_snapshot(&func_app.function_name, &func_app.arguments)? + { + Ok(vir_low::Expression::local( + snapshot_variable, + func_app.position, + )) + } else { + Ok(vir_low::Expression::FuncApp(func_app)) + } + } + } + } + + fn fallible_fold_labelled_old( + &mut self, + mut labelled_old: vir_low::expression::LabelledOld, + ) -> Result { + std::mem::swap(&mut labelled_old.label, &mut self.label); + labelled_old.base = self.fallible_fold_expression_boxed(labelled_old.base)?; + std::mem::swap(&mut labelled_old.label, &mut self.label); + Ok(labelled_old) + } + } + impl<'a> Purifier<'a> { + fn resolve_snapshot( + &mut self, + function_name: &str, + arguments: &[vir_low::Expression], + ) -> VerificationResult> { + let predicate_snapshots = if let Some(label) = &self.label { + self.predicate_snapshots_at_label.get(label).unwrap() + } else { + self.predicate_snapshots + }; + // FIXME: Do not use strings here. + let predicate_name = + function_name.replace("snap_owned_non_aliased$", "OwnedNonAliased$"); + predicate_snapshots.find_snapshot(&predicate_name, arguments, self.solver) + } + } + let mut purifier = Purifier { + predicate_snapshots: &self.predicate_snapshots, + predicate_snapshots_at_label: &self.predicate_snapshots_at_label, + solver: self.solver, + program: self.program, + label: None, + }; + purifier.fallible_fold_expression(expression) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/graphviz.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/graphviz.rs new file mode 100644 index 00000000000..b991255f7f2 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/graphviz.rs @@ -0,0 +1,26 @@ +use super::HeapEntry; +use crate::encoder::middle::core_proof::transformations::symbolic_execution::trace_builder::ExecutionTraceHeapView; +use vir_crate::common::graphviz::{escape_html, Graph, ToGraphviz}; + +impl<'a> ToGraphviz for ExecutionTraceHeapView<'a> { + fn to_graph(&self) -> Graph { + let mut graph = Graph::with_columns(&["statement"]); + for (block_id, block) in self.iter_blocks().enumerate() { + let mut node_builder = graph.create_node(format!("block{}", block_id)); + for statement in block.iter_entries() { + let statement_string = match statement { + HeapEntry::Comment(statement) => { + format!("{}", escape_html(statement)) + } + _ => escape_html(statement.to_string()), + }; + node_builder.add_row_sequence(vec![statement_string]); + } + node_builder.build(); + if let Some(parent) = block.parent() { + graph.add_regular_edge(format!("block{}", parent), format!("block{}", block_id)); + } + } + graph + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/mod.rs new file mode 100644 index 00000000000..98f1c7c1a14 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/mod.rs @@ -0,0 +1,577 @@ +mod graphviz; +mod entry; +mod state; +mod utils; +mod finalizer; +mod predicate_snapshots; + +use super::{ + egg::EGraphState, + program_context::ProgramContext, + trace::Trace, + trace_builder::{ExecutionTraceBuilder, ExecutionTraceHeapView}, + VerificationResult, +}; +use log::debug; +use rustc_hash::FxHashMap; +use std::collections::BTreeMap; +use vir_crate::{ + common::{ + ast::predicate, + display, + graphviz::{escape_html, Graph, ToGraphviz}, + }, + low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}, +}; + +pub(super) use self::{entry::HeapEntry, state::HeapState}; +use self::{finalizer::TraceFinalizer, utils::arguments_match}; + +impl ExecutionTraceBuilder { + pub(super) fn heap_comment( + &mut self, + statement: vir_low::ast::statement::Comment, + ) -> VerificationResult<()> { + self.add_heap_entry(HeapEntry::Comment(statement))?; + Ok(()) + } + + pub(super) fn heap_label( + &mut self, + statement: vir_low::ast::statement::Label, + ) -> VerificationResult<()> { + self.add_heap_entry(HeapEntry::Label(statement))?; + Ok(()) + } + + pub(super) fn heap_assume( + &mut self, + expression: vir_low::Expression, + position: vir_low::Position, + ) -> VerificationResult<()> { + assert!( + !position.is_default(), + "assume {expression} with default position" + ); + self.add_heap_entry(HeapEntry::Assume(vir_low::ast::statement::Assume { + expression, + position, + }))?; + Ok(()) + } + + pub(super) fn heap_assert( + &mut self, + expression: vir_low::Expression, + position: vir_low::Position, + ) -> VerificationResult<()> { + self.add_heap_entry(HeapEntry::Assert(vir_low::ast::statement::Assert { + expression, + position, + }))?; + Ok(()) + } + + fn next_location(&self) -> Location { + let view = self.heap_view(); + Location { + block_id: view.block_count() - 1, + entry_id: view.last_block_entry_count(), + } + } + + pub(super) fn heap_inhale_predicate( + &mut self, + predicate: vir_low::ast::expression::PredicateAccessPredicate, + position: vir_low::Position, + ) -> VerificationResult<()> { + let next_location = self.next_location(); + // let current_block = self.blocks.last_mut().unwrap(); + // let state = &mut current_block.state; + let state = self.current_heap_state(); + state.add_predicate_instance(&predicate, next_location); + self.add_heap_entry(HeapEntry::InhalePredicate(predicate, position)) + } + + pub(super) fn heap_inhale_generic( + &mut self, + expression: vir_low::Expression, + position: vir_low::Position, + ) -> VerificationResult<()> { + // let current_block = self.blocks.last_mut().unwrap(); + // let state = &mut current_block.state; + let state = self.current_heap_state(); + for predicate_name in expression.collect_access_predicate_names() { + state.mark_predicate_instances_as_inhaled(&predicate_name); + } + self.add_heap_entry(HeapEntry::InhaleGeneric(vir_low::ast::statement::Inhale { + expression, + position, + })) + } + + pub(super) fn heap_exhale_predicate( + &mut self, + predicate: vir_low::ast::expression::PredicateAccessPredicate, + position: vir_low::Position, + ) -> VerificationResult<()> { + let next_location = self.next_location(); + // let current_block = self.blocks.last_mut().unwrap(); + // let state = &mut current_block.state; + let (state, solver) = self.current_heap_and_egraph_state_mut(); + state.try_removing_predicate_instance(&predicate, next_location, solver)?; + self.add_heap_entry(HeapEntry::ExhalePredicate(predicate, position))?; + Ok(()) + } + + pub(super) fn heap_exhale_generic( + &mut self, + expression: vir_low::Expression, + position: vir_low::Position, + ) -> VerificationResult<()> { + // let current_block = self.blocks.last_mut().unwrap(); + // let state = &mut current_block.state; + let state = self.current_heap_state(); + for predicate_name in expression.collect_access_predicate_names() { + state.mark_predicate_instances_as_exhaled(&predicate_name); + } + self.add_heap_entry(HeapEntry::ExhaleGeneric(vir_low::ast::statement::Exhale { + expression, + position, + })) + } + + pub(super) fn heap_finalize_trace( + &self, + program: &ProgramContext, + ) -> VerificationResult { + debug!("Finalizing trace"); + let (state, solver) = self.current_heap_and_egraph_state(); + let view = self.heap_view(); + let last_block_id = view.last_block_id(); + let mut trace_finalizer = TraceFinalizer::new(state, solver, program); + self.finalize_trace_for_block(&mut trace_finalizer, view, last_block_id)?; + Ok(trace_finalizer.into_trace()) + } + + fn finalize_trace_for_block( + &self, + trace_finalizer: &mut TraceFinalizer, + view: ExecutionTraceHeapView, + block_id: usize, + ) -> VerificationResult<()> { + let block = view.get_block(block_id); + if let Some(parent_id) = block.parent() { + self.finalize_trace_for_block(trace_finalizer, view, parent_id)?; + } + trace_finalizer.add_variables(block.get_new_variables())?; + trace_finalizer.add_labels(block.get_new_labels())?; + for (entry_id, entry) in block.iter_entries().enumerate() { + trace_finalizer.add_entry(Location { block_id, entry_id }, entry)?; + } + Ok(()) + } +} + +// pub(super) struct ExecutionTraceHeap { +// blocks: Vec, +// } + +// impl ExecutionTraceHeap { +// // pub(super) fn new() -> Self { +// // let initial_block = ExecutionTraceBlock { +// // parent: None, +// // entries: Vec::new(), +// // state: HeapState::default(), +// // }; +// // Self { +// // blocks: vec![initial_block], +// // } +// // } + +// // fn add_entry(&mut self, entry: Entry) { +// // let current_block = self.blocks.last_mut().unwrap(); +// // current_block.entries.push(entry); +// // } + +// pub(super) fn comment( +// &mut self, +// statement: vir_low::ast::statement::Comment, +// ) -> VerificationResult<()> { +// self.add_entry(Entry::Comment(statement)); +// Ok(()) +// } + +// pub(super) fn label( +// &mut self, +// statement: vir_low::ast::statement::Label, +// ) -> VerificationResult<()> { +// self.add_entry(Entry::Label(statement)); +// Ok(()) +// } + +// pub(super) fn assume( +// &mut self, +// expression: vir_low::Expression, +// position: vir_low::Position, +// ) -> VerificationResult<()> { +// assert!(!position.is_default(), "assume {expression} with default position"); +// self.add_entry(Entry::Assume(vir_low::ast::statement::Assume { +// expression, +// position, +// })); +// Ok(()) +// } + +// pub(super) fn assert( +// &mut self, +// expression: vir_low::Expression, +// position: vir_low::Position, +// ) -> VerificationResult<()> { +// self.add_entry(Entry::Assert(vir_low::ast::statement::Assert { +// expression, +// position, +// })); +// Ok(()) +// } + +// fn next_location(&self) -> Location { +// Location { +// block_id: self.blocks.len() - 1, +// entry_id: self.blocks.last().unwrap().entries.len(), +// } +// } + +// pub(super) fn inhale_predicate( +// &mut self, +// predicate: vir_low::ast::expression::PredicateAccessPredicate, +// position: vir_low::Position, +// ) -> VerificationResult<()> { +// let next_location = self.next_location(); +// let current_block = self.blocks.last_mut().unwrap(); +// let state = &mut current_block.state; +// state.add_predicate_instance(&predicate, next_location); +// self.add_entry(Entry::InhalePredicate(predicate, position)); +// Ok(()) +// } + +// pub(super) fn inhale_generic( +// &mut self, +// expression: vir_low::Expression, +// position: vir_low::Position, +// ) -> VerificationResult<()> { +// let current_block = self.blocks.last_mut().unwrap(); +// let state = &mut current_block.state; +// for predicate_name in expression.collect_access_predicate_names() { +// state.mark_predicate_instances_as_inhaled(&predicate_name); +// } +// self.add_entry(Entry::InhaleGeneric(vir_low::ast::statement::Inhale { +// expression, +// position, +// })); +// Ok(()) +// } + +// pub(super) fn exhale_predicate( +// &mut self, +// predicate: vir_low::ast::expression::PredicateAccessPredicate, +// position: vir_low::Position, +// solver: &EGraphState, +// ) -> VerificationResult<()> { +// let next_location = self.next_location(); +// let current_block = self.blocks.last_mut().unwrap(); +// let state = &mut current_block.state; +// state.try_removing_predicate_instance(&predicate, next_location, solver)?; +// self.add_entry(Entry::ExhalePredicate(predicate, position)); +// Ok(()) +// } + +// pub(super) fn exhale_generic( +// &mut self, +// expression: vir_low::Expression, +// position: vir_low::Position, +// ) -> VerificationResult<()> { +// let current_block = self.blocks.last_mut().unwrap(); +// let state = &mut current_block.state; +// for predicate_name in expression.collect_access_predicate_names() { +// state.mark_predicate_instances_as_exhaled(&predicate_name); +// } +// self.add_entry(Entry::ExhaleGeneric(vir_low::ast::statement::Exhale { +// expression, +// position, +// })); +// Ok(()) +// } + +// pub(super) fn finalize_trace( +// &self, +// solver: &EGraphState, +// program: &ProgramContext, +// ) -> VerificationResult { +// let final_state = &self.blocks.last().unwrap().state; +// let mut trace_finalizer = TraceFinalizer::new(final_state, solver, program); +// self.finalize_trace_for_block(&mut trace_finalizer, self.blocks.len() - 1)?; +// Ok(trace_finalizer.into_trace()) +// } + +// fn finalize_trace_for_block( +// &self, +// trace_finalizer: &mut TraceFinalizer, +// block_id: usize, +// ) -> VerificationResult<()> { +// let block = &self.blocks[block_id]; +// if let Some(parent_id) = block.parent { +// self.finalize_trace_for_block(trace_finalizer, parent_id)?; +// } +// for (entry_id, entry) in block.entries.iter().enumerate() { +// trace_finalizer.add_entry(Location { block_id, entry_id }, entry)?; +// } +// Ok(()) +// } +// } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Location { + block_id: usize, + entry_id: usize, +} + +// struct ExecutionTraceBlock { +// /// The parent of this block. The root does not have a parent. +// parent: Option, +// entries: Vec, +// /// The last state. If the block is fully executed, it is the state after +// /// the last statement. +// state: HeapState, +// } + +// pub(super) struct Trace { +// statements: Vec, +// } + +// impl Trace { +// pub(super) fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { +// for statement in &self.statements { +// writeln!(writer, "{}", statement)?; +// } +// Ok(()) +// } + +// pub(super) fn into_statements(self) -> Vec { +// self.statements +// } +// } + +// struct TraceFinalizer<'a> { +// final_state: &'a HeapState, +// trace: Vec, +// predicate_snapshots: PredicateSnapshots, +// predicate_snapshots_at_label: BTreeMap, +// solver: &'a EGraphState, +// program: &'a ProgramContext<'a>, +// } + +// impl<'a> TraceFinalizer<'a> { +// fn new( +// final_state: &'a HeapState, +// solver: &'a EGraphState, +// program: &'a ProgramContext<'a>, +// ) -> Self { +// Self { +// final_state, +// trace: Vec::new(), +// predicate_snapshots: Default::default(), +// predicate_snapshots_at_label: Default::default(), +// solver, +// program, +// } +// } + +// fn into_trace(self) -> Trace { +// Trace { +// statements: self.trace, +// } +// } + +// fn add_entry(&mut self, location: Location, entry: &Entry) -> VerificationResult<()> { +// match entry { +// Entry::Comment(statement) => { +// self.trace +// .push(vir_low::Statement::Comment(statement.clone())); +// } +// Entry::Label(statement) => { +// self.save_state(statement.label.clone()); +// self.trace +// .push(vir_low::Statement::Label(statement.clone())); +// } +// Entry::InhalePredicate(predicate, position) => { +// if self.is_purified_inhale(location, predicate) { +// self.trace.push(vir_low::Statement::comment(format!( +// "purified out: {}", +// entry +// ))); +// self.predicate_snapshots +// .create_predicate_snapshot(predicate); +// } else { +// self.trace.push(vir_low::Statement::inhale( +// vir_low::Expression::PredicateAccessPredicate(predicate.clone()), +// *position, +// )); +// } +// } +// Entry::ExhalePredicate(predicate, position) => { +// if self.is_purified_exhale(location, predicate) { +// self.trace.push(vir_low::Statement::comment(format!( +// "purified out: {}", +// entry +// ))); +// self.predicate_snapshots +// .destroy_predicate_snapshot(predicate, self.solver)?; +// } else { +// self.trace.push(vir_low::Statement::exhale( +// vir_low::Expression::PredicateAccessPredicate(predicate.clone()), +// *position, +// )); +// } +// } +// Entry::InhaleGeneric(statement) => { +// let mut statement = statement.clone(); +// statement.expression = self.purify_snap_calls(statement.expression)?; +// self.trace.push(vir_low::Statement::Inhale(statement)); +// } +// Entry::ExhaleGeneric(statement) => { +// let mut statement = statement.clone(); +// statement.expression = self.purify_snap_calls(statement.expression)?; +// self.trace.push(vir_low::Statement::Exhale(statement)); +// } +// Entry::Assume(statement) => { +// let mut statement = statement.clone(); +// statement.expression = self.purify_snap_calls(statement.expression)?; +// self.trace.push(vir_low::Statement::Assume(statement)); +// } +// Entry::Assert(statement) => { +// let mut statement = statement.clone(); +// statement.expression = self.purify_snap_calls(statement.expression)?; +// self.trace.push(vir_low::Statement::Assert(statement)); +// } +// } +// Ok(()) +// } + +// fn is_purified_inhale( +// &self, +// location: Location, +// predicate: &vir_low::expression::PredicateAccessPredicate, +// ) -> bool { +// if let Some(predicate_state) = self.final_state.predicates.get(&predicate.name) { +// for predicate_instance in &predicate_state.instances { +// if predicate_instance.inhale_location == location { +// if let PredicateInstanceState::Exhaled(_) = predicate_instance.state { +// return true; +// } +// } +// } +// } +// false +// } + +// fn is_purified_exhale( +// &self, +// location: Location, +// predicate: &vir_low::expression::PredicateAccessPredicate, +// ) -> bool { +// if let Some(predicate_state) = self.final_state.predicates.get(&predicate.name) { +// for predicate_instance in &predicate_state.instances { +// if let PredicateInstanceState::Exhaled(exhale_location) = predicate_instance.state { +// if exhale_location == location { +// return true; +// } +// } +// } +// } +// false +// } + +// fn save_state(&mut self, label: String) { +// assert!(self +// .predicate_snapshots_at_label +// .insert(label, self.predicate_snapshots.clone()) +// .is_none()); +// } + +// fn purify_snap_calls( +// &mut self, +// expression: vir_low::Expression, +// ) -> VerificationResult { +// struct Purifier<'a> { +// predicate_snapshots: &'a PredicateSnapshots, +// predicate_snapshots_at_label: &'a BTreeMap, +// solver: &'a EGraphState, +// program: &'a ProgramContext<'a>, +// label: Option, +// } +// impl<'a> ExpressionFallibleFolder for Purifier<'a> { +// type Error = super::Error; + +// fn fallible_fold_func_app_enum( +// &mut self, +// func_app: vir_low::expression::FuncApp, +// ) -> Result { +// let func_app = self.fallible_fold_func_app(func_app)?; +// let function = self.program.get_function(&func_app.function_name); +// assert_eq!(function.parameters.len(), func_app.arguments.len()); +// match function.kind { +// vir_low::FunctionKind::MemoryBlockBytes => todo!(), +// vir_low::FunctionKind::CallerFor => todo!(), +// vir_low::FunctionKind::Snap => { +// if let Some(snapshot_variable) = +// self.resolve_snapshot(&func_app.function_name, &func_app.arguments)? +// { +// Ok(vir_low::Expression::local( +// snapshot_variable, +// func_app.position, +// )) +// } else { +// Ok(vir_low::Expression::FuncApp(func_app)) +// } +// } +// } +// } + +// fn fallible_fold_labelled_old( +// &mut self, +// mut labelled_old: vir_low::expression::LabelledOld, +// ) -> Result { +// std::mem::swap(&mut labelled_old.label, &mut self.label); +// labelled_old.base = self.fallible_fold_expression_boxed(labelled_old.base)?; +// std::mem::swap(&mut labelled_old.label, &mut self.label); +// Ok(labelled_old) +// } +// } +// impl<'a> Purifier<'a> { +// fn resolve_snapshot( +// &mut self, +// function_name: &str, +// arguments: &[vir_low::Expression], +// ) -> VerificationResult> { +// let predicate_snapshots = if let Some(label) = &self.label { +// self.predicate_snapshots_at_label.get(label).unwrap() +// } else { +// self.predicate_snapshots +// }; +// // FIXME: Do not use strings here. +// let predicate_name = +// function_name.replace("snap_owned_non_aliased$", "OwnedNonAliased$"); +// predicate_snapshots.find_snapshot(&predicate_name, arguments, self.solver) +// } +// } +// let mut purifier = Purifier { +// predicate_snapshots: &self.predicate_snapshots, +// predicate_snapshots_at_label: &self.predicate_snapshots_at_label, +// solver: self.solver, +// program: self.program, +// label: None, +// }; +// purifier.fallible_fold_expression(expression) +// } +// } diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/predicate_snapshots.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/predicate_snapshots.rs new file mode 100644 index 00000000000..f1978ad4fe7 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/predicate_snapshots.rs @@ -0,0 +1,118 @@ +use crate::encoder::middle::core_proof::transformations::symbolic_execution::{ + egg::EGraphState, program_context::ProgramContext, VerificationResult, +}; +use std::collections::BTreeMap; +use vir_crate::low::{self as vir_low}; + +use super::utils::arguments_match; + +#[derive(Default, Clone)] +pub(super) struct PredicateSnapshots { + snapshots: BTreeMap>, + variables: Vec, +} +impl PredicateSnapshots { + pub(super) fn create_predicate_snapshot( + &mut self, + program_context: &ProgramContext, + predicate: &vir_low::expression::PredicateAccessPredicate, + ) { + let predicate_snapshots = self.snapshots.entry(predicate.name.clone()).or_default(); + let snapshot_variable_name = + format!("snapshot${}${}", predicate.name, predicate_snapshots.len()); + let snapshot = if let Some(ty) = program_context.get_snapshot_type(&predicate.name) { + let snapshot = vir_low::VariableDecl::new(snapshot_variable_name, ty); + self.variables.push(snapshot.clone()); + PredicateSnapshotState::Inhaled(snapshot) + } else { + PredicateSnapshotState::NoSnapshot + }; + predicate_snapshots.push(PredicateSnapshot { + arguments: predicate.arguments.clone(), + snapshot, + }); + } + + pub(super) fn destroy_predicate_snapshot( + &mut self, + predicate: &vir_low::expression::PredicateAccessPredicate, + solver: &EGraphState, + ) -> VerificationResult<()> { + let predicate_snapshots = self.snapshots.get_mut(&predicate.name).unwrap(); + for predicate_snapshot in predicate_snapshots { + if predicate_snapshot.snapshot.is_not_exhaled() + && predicate_snapshot.matches(predicate, solver)? + { + predicate_snapshot.snapshot = PredicateSnapshotState::Exhaled; + } + } + Ok(()) + } + + pub(super) fn find_snapshot( + &self, + predicate_name: &str, + arguments: &[vir_low::Expression], + solver: &EGraphState, + ) -> VerificationResult> { + if let Some(predicate_snapshots) = self.snapshots.get(predicate_name) { + for predicate_snapshot in predicate_snapshots { + if let PredicateSnapshotState::Inhaled(snapshot) = &predicate_snapshot.snapshot { + if predicate_snapshot.matches_arguments(arguments, solver)? { + return Ok(Some(snapshot.clone())); + } + } + } + } + Ok(None) + } + + pub(super) fn into_variables(self) -> Vec { + self.variables + } +} + +#[derive(Clone)] +enum PredicateSnapshotState { + /// The snapshot is valid. + Inhaled(vir_low::VariableDecl), + /// The snapshot was exhaled and no longer valid. + Exhaled, + /// The predicate does not have a snapshot. + NoSnapshot, +} + +impl PredicateSnapshotState { + pub(super) fn is_not_exhaled(&self) -> bool { + matches!( + self, + PredicateSnapshotState::Inhaled(_) | PredicateSnapshotState::NoSnapshot + ) + } +} + +#[derive(Clone)] +pub(super) struct PredicateSnapshot { + /// Predicate arguments. + arguments: Vec, + /// None means that the corresponding predicate was exhaled. + snapshot: PredicateSnapshotState, +} + +impl PredicateSnapshot { + pub(super) fn matches( + &self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + solver: &EGraphState, + ) -> VerificationResult { + arguments_match(&self.arguments, &predicate.arguments, solver) + } + + pub(super) fn matches_arguments( + &self, + arguments: &[vir_low::Expression], + solver: &EGraphState, + ) -> VerificationResult { + arguments_match(&self.arguments, arguments, solver) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/state.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/state.rs new file mode 100644 index 00000000000..1f714e52e17 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/state.rs @@ -0,0 +1,153 @@ +use super::Location; +use crate::encoder::middle::core_proof::transformations::symbolic_execution::{ + egg::EGraphState, heap::utils::arguments_match, VerificationResult, +}; +use std::collections::BTreeMap; +use vir_crate::{ + common::display, + low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}, +}; + +#[derive(Default, Clone)] +pub(in super::super) struct HeapState { + /// A map from predicate names to their state. + predicates: BTreeMap, +} + +impl std::fmt::Display for HeapState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (predicate_name, predicate_state) in &self.predicates { + writeln!(f, "{}:", predicate_name)?; + for predicate_instance in &predicate_state.instances { + writeln!( + f, + " {} @ {:?}: {}, {:?}", + display::cjoin(&predicate_instance.arguments), + predicate_instance.inhale_location, + predicate_instance.permission_amount, + predicate_instance.state + )?; + } + } + Ok(()) + } +} + +impl HeapState { + pub(super) fn add_predicate_instance( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + inhale_location: Location, + ) { + let predicate_name = predicate.name.clone(); + let predicate_state = + self.predicates + .entry(predicate_name) + .or_insert_with(|| PredicateState { + instances: Vec::new(), + }); + let predicate_instance = PredicateInstance { + arguments: predicate.arguments.clone(), + permission_amount: (*predicate.permission).clone(), + inhale_location, + state: PredicateInstanceState::Fresh, + }; + predicate_state.instances.push(predicate_instance); + } + + pub(super) fn try_removing_predicate_instance( + &mut self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + exhale_location: Location, + solver: &EGraphState, + ) -> VerificationResult<()> { + if let Some(predicate_state) = self.predicates.get_mut(&predicate.name) { + for predicate_instance in &mut predicate_state.instances { + if matches!( + predicate_instance.state, + PredicateInstanceState::Fresh | PredicateInstanceState::SeenQPExhale + ) { + if predicate_instance.matches(predicate, solver)? { + predicate_instance.state = PredicateInstanceState::Exhaled(exhale_location); + } + } + } + } + Ok(()) + } + + pub(super) fn mark_predicate_instances_as_inhaled(&mut self, predicate_name: &str) { + if let Some(predicate_state) = self.predicates.get_mut(predicate_name) { + for predicate_instance in &mut predicate_state.instances { + if predicate_instance.state == PredicateInstanceState::SeenQPExhale { + predicate_instance.state = PredicateInstanceState::SeenQPInhale; + } + } + } + } + + pub(super) fn mark_predicate_instances_as_exhaled(&mut self, predicate_name: &str) { + if let Some(predicate_state) = self.predicates.get_mut(predicate_name) { + for predicate_instance in &mut predicate_state.instances { + if predicate_instance.state == PredicateInstanceState::Fresh { + predicate_instance.state = PredicateInstanceState::SeenQPExhale; + } + } + } + } + + pub(super) fn get_predicate(&self, predicate_name: &str) -> Option<&PredicateState> { + self.predicates.get(predicate_name) + } +} + +#[derive(Clone)] +pub(super) struct PredicateState { + instances: Vec, +} + +impl PredicateState { + pub(super) fn get_instances(&self) -> &[PredicateInstance] { + &self.instances + } +} + +#[derive(Clone)] +pub(super) struct PredicateInstance { + /// The arguments of the predicate instance. + pub(super) arguments: Vec, + pub(super) permission_amount: vir_low::Expression, + /// The location of the inhale statement that inhaled this predicate instance. + pub(super) inhale_location: Location, + /// The state of the predicate. + pub(super) state: PredicateInstanceState, +} + +impl PredicateInstance { + fn matches( + &self, + predicate: &vir_low::ast::expression::PredicateAccessPredicate, + solver: &EGraphState, + ) -> VerificationResult { + assert_eq!(self.arguments.len(), predicate.arguments.len()); + if !arguments_match(&self.arguments, &predicate.arguments, solver)? { + return Ok(false); + } + // for (self_arg, predicate_arg) in self.arguments.iter().zip(&predicate.arguments) { + // // if self_arg != predicate_arg { + // if !solver.is_equal(self_arg, predicate_arg)? { + // return Ok(false); + // } + // } + Ok(self.permission_amount == *predicate.permission) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub(super) enum PredicateInstanceState { + /// The predicate was inhaled and has not seen QP exhale yet. + Fresh, + SeenQPExhale, + SeenQPInhale, + Exhaled(Location), +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/utils.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/utils.rs new file mode 100644 index 00000000000..cf6f260b711 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap/utils.rs @@ -0,0 +1,20 @@ +use crate::encoder::middle::core_proof::transformations::symbolic_execution::{ + egg::EGraphState, VerificationResult, +}; +use vir_crate::{ + common::display, + low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}, +}; + +pub(super) fn arguments_match( + args1: &[vir_low::Expression], + args2: &[vir_low::Expression], + solver: &EGraphState, +) -> VerificationResult { + for (arg1, arg2) in args1.iter().zip(args2) { + if !solver.is_equal(arg1, arg2)? { + return Ok(false); + } + } + Ok(true) +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap_dependent.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap_dependent.rs new file mode 100644 index 00000000000..989d63a4341 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/heap_dependent.rs @@ -0,0 +1,171 @@ +use super::{ProcedureExecutor, VerificationResult}; +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator, UnaryOperationHelpers}, + graphviz::ToGraphviz, + position::Positioned, + }, + low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}, +}; + +impl<'a> ProcedureExecutor<'a> { + // pub(super) fn assume_heap_dependent( + // &mut self, + // assertion: vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // let assertion = self.instantiate_function_calls(assertion, position)?; + // self.assume_pure(assertion, position)?; + // Ok(()) + // } + + // pub(super) fn assert_heap_dependent( + // &mut self, + // assertion: vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // let assertion = self.instantiate_function_calls(assertion, position)?; + // self.assert_pure(assertion, position)?; + // Ok(()) + // } + + // pub(super) fn inhale_heap_dependent( + // &mut self, + // assertion: vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // let assertion = self.instantiate_function_calls(assertion, position)?; + // // unimplemented!(); + // Ok(()) + // } + + // pub(super) fn exhale_heap_dependent( + // &mut self, + // assertion: vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // let assertion = self.instantiate_function_calls(assertion, position)?; + // // unimplemented!(); + // Ok(()) + // } + + // fn instantiate_function_calls( + // &mut self, + // expression: vir_low::Expression, + // position: vir_crate::high::Position, + // ) -> VerificationResult { + // let mut instantiatior = FunctionCallInstantiator { + // procedure_executor: self, + // path_condition: Vec::new(), + // position, + // }; + // instantiatior.fallible_fold_expression(expression) + // } +} + +// struct FunctionCallInstantiator<'e, 'a> { +// procedure_executor: &'e mut ProcedureExecutor<'a>, +// path_condition: Vec, +// position: vir_low::Position, +// } + +// impl<'e, 'a: 'e> ExpressionFallibleFolder for FunctionCallInstantiator<'e, 'a> { +// type Error = Error; + +// fn fallible_fold_func_app_enum( +// &mut self, +// func_app: vir_low::expression::FuncApp, +// ) -> Result { +// let function = self.procedure_executor.functions[&func_app.function_name]; +// assert_eq!(function.parameters.len(), func_app.arguments.len()); +// let arguments = func_app +// .arguments +// .into_iter() +// .map(|argument| self.fallible_fold_expression(argument)) +// .collect::, _>>()?; +// assert!( +// self.path_condition.is_empty(), +// "implications should be desugared" +// ); +// // let path_condition = self.path_condition.iter().cloned().conjoin(); +// let replacements = function.parameters.iter().zip(arguments.iter()).collect(); +// let pres = function +// .pres +// .iter() +// .cloned() +// .conjoin() +// .substitute_variables(&replacements); +// let pres = self.fallible_fold_expression(pres)?; +// // let assert_precondition = vir_low::Expression::implies(path_condition, pres); +// self.procedure_executor +// .assert_heap_dependent(pres, self.position)?; +// // self.procedure_executor.assert_heap_dependent( +// // assert_precondition, +// // self.position, +// // )?; +// match function.kind { +// vir_low::FunctionKind::MemoryBlockBytes => todo!(), +// vir_low::FunctionKind::CallerFor => todo!(), +// vir_low::FunctionKind::Snap => { +// // FIXME +// // let predicate_name = self +// // .procedure_executor +// // .get_predicate_name_for_function(&func_app.function_name)?; +// // let heap_version = if let Some(current_state_label) = +// // self.current_state_label +// // { +// // self.heap_encoder +// // .get_heap_version_at_label(&predicate_name, current_state_label)? +// // } else { +// // self.heap_encoder +// // .get_current_heap_version_for(&predicate_name)? +// // }; +// // arguments.push(heap_version); +// // let heap_function_name = self +// // .heap_encoder +// // .get_heap_function_name_for(&predicate_name); +// // let return_type = self +// // .heap_encoder +// // .get_snapshot_type_for_predicate(&predicate_name) +// // .unwrap(); +// // Ok(vir_low::Expression::domain_function_call( +// // "HeapFunctions", +// // heap_function_name, +// // arguments, +// // return_type, +// // )) +// Ok(true.into()) +// } +// } +// } + +// fn fallible_fold_binary_op( +// &mut self, +// mut binary_op: vir_low::expression::BinaryOp, +// ) -> Result { +// binary_op.left = self.fallible_fold_expression_boxed(binary_op.left)?; +// if binary_op.op_kind == vir_low::BinaryOpKind::Implies { +// self.path_condition.push((*binary_op.left).clone()); +// } +// binary_op.right = self.fallible_fold_expression_boxed(binary_op.right)?; +// if binary_op.op_kind == vir_low::BinaryOpKind::Implies { +// self.path_condition.pop(); +// } +// Ok(binary_op) +// } + +// fn fallible_fold_conditional( +// &mut self, +// mut conditional: vir_low::expression::Conditional, +// ) -> Result { +// conditional.guard = self.fallible_fold_expression_boxed(conditional.guard)?; +// self.path_condition.push((*conditional.guard).clone()); +// conditional.then_expr = self.fallible_fold_expression_boxed(conditional.then_expr)?; +// self.path_condition.pop(); +// self.path_condition +// .push(vir_low::Expression::not((*conditional.guard).clone())); +// conditional.else_expr = self.fallible_fold_expression_boxed(conditional.else_expr)?; +// self.path_condition.pop(); +// Ok(conditional) +// } +// } diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/mod.rs new file mode 100644 index 00000000000..fed74605550 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/mod.rs @@ -0,0 +1,336 @@ +//! This module contains the symbolic execution engine that is used to purify +//! predicates in the Viper program. This module depends on `ErrorManager` and, +//! therefore, has to be in the `prusti-viper` crate. + +mod trace_builder; +mod egg; +mod statements; +mod heap_dependent; +mod pure; +mod variable_declarations; +mod heap; +mod trace; +mod program_context; + +use self::{egg::EGraphState, program_context::ProgramContext}; +use crate::encoder::{ + errors::SpannedEncodingResult, middle::core_proof::predicates::PredicateInfo, +}; +use log::debug; +use prusti_common::config; +use rustc_hash::FxHashMap; +use std::collections::BTreeMap; +use viper::VerificationError; +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator, UnaryOperationHelpers}, + graphviz::ToGraphviz, + position::Positioned, + }, + low::{ + self as vir_low, + expression::visitors::{ExpressionFallibleWalker, ExpressionWalker}, + }, +}; + +pub(in super::super) fn purify_with_symbolic_execution( + source_filename: &str, + program: vir_low::Program, + predicate_info: BTreeMap, +) -> SpannedEncodingResult { + let mut executor = Executor::new(); + let program = executor + .execute(source_filename, program, predicate_info) + .unwrap(); + Ok(program) +} + +#[derive(Debug)] +pub enum Error { + InconsistentState, +} + +pub type VerificationResult = std::result::Result; + +struct Executor { + verification_errors: Vec, +} + +struct ProcedureExecutor<'a> { + executor: &'a mut Executor, + source_filename: &'a str, + program_context: &'a ProgramContext<'a>, + // functions: FxHashMap, + // new_variables: VariableDeclarations, + continuations: Vec, + exhale_label_generator_counter: u64, + /// The execution trace showing in what order the statements were executed. + execution_trace_builder: trace_builder::ExecutionTraceBuilder, + // /// The execution trace containing information necessary for performing + // /// purification. + // execution_trace_heap: heap::ExecutionTraceHeap, + /// The original execution traces. + original_traces: Vec, + /// Traces in which purifiable predicates were purified. + final_traces: Vec, + // /// The execution trace showing the pure assert and assume statements. + // execution_trace_pure: trace::ExecutionTrace, + // current_egraph_state: EGraphState, +} + +#[derive(Debug)] +pub struct Continuation { + next_block_label: vir_low::Label, + parent_block_label: vir_low::Label, + condition: vir_low::Expression, +} + +impl Executor { + pub(crate) fn new() -> Self { + Self { + verification_errors: Vec::new(), + } + } + + pub(crate) fn execute( + &mut self, + source_filename: &str, + mut program: vir_low::Program, + predicate_info: BTreeMap, + ) -> VerificationResult { + let program_context = + ProgramContext::new(&program.functions, &program.predicates, predicate_info); + let mut new_procedures = Vec::new(); + for procedure in program.procedures { + let procedure_executor = + ProcedureExecutor::new(self, source_filename, &program_context); + procedure_executor.execute_procedure(procedure, &mut new_procedures)?; + } + program.procedures = new_procedures; + Ok(program) + } +} + +impl<'a> ProcedureExecutor<'a> { + fn new( + executor: &'a mut Executor, + source_filename: &'a str, + program_context: &'a ProgramContext<'a>, + ) -> Self { + Self { + executor, + source_filename, + continuations: Vec::new(), + // new_variables: Default::default(), + exhale_label_generator_counter: 0, + execution_trace_builder: trace_builder::ExecutionTraceBuilder::new(), + // execution_trace_heap: heap::ExecutionTraceHeap::new(), + // execution_trace_pure: trace::ExecutionTrace::new(), + // current_egraph_state: EGraphState::new(), + program_context, + original_traces: Vec::new(), + final_traces: Vec::new(), + } + } + + fn execute_procedure( + mut self, + procedure: vir_low::ProcedureDecl, + new_procedures: &mut Vec, + ) -> VerificationResult<()> { + debug!("Executing procedure {}", procedure.name); + let mut current_block = procedure.entry.clone(); + loop { + let block = procedure.basic_blocks.get(¤t_block).unwrap(); + self.execute_block(¤t_block, block)?; + if self + .execution_trace_builder + .current_egraph_state() + .is_inconsistent()? + { + self.finalize_trace()?; + if let Some(new_current_block) = self.next_continuation(procedure.position)? { + current_block = new_current_block; + continue; + } else { + break; + } + } + match &block.successor { + vir_low::Successor::Return => { + self.finalize_trace()?; + if let Some(new_current_block) = self.next_continuation(procedure.position)? { + current_block = new_current_block; + } else { + break; + } + } + vir_low::Successor::Goto(label) => current_block = label.clone(), + vir_low::Successor::GotoSwitch(targets) => { + let parent_block_label = current_block.clone(); + self.execution_trace_builder + .add_split_point(parent_block_label.clone())?; + // Since the jumps are evaluated one after another, we need + // to negate all the previous conditions when considering + // the new one. + let mut negated_conditions = Vec::new(); + let mut targets = targets.iter(); + let (condition, label) = targets.next().unwrap(); + self.assume_condition(condition.clone(), procedure.position)?; + current_block = label.clone(); + negated_conditions.push(UnaryOperationHelpers::not(condition.clone())); + for (condition, label) in targets { + let continuation = Continuation { + next_block_label: label.clone(), + parent_block_label: parent_block_label.clone(), + condition: vir_low::Expression::and( + negated_conditions.clone().into_iter().conjoin(), + condition.clone(), + ), + }; + self.continuations.push(continuation); + negated_conditions.push(UnaryOperationHelpers::not(condition.clone())); + } + } + } + } + if prusti_common::config::dump_debug_info() { + prusti_common::report::log::report_with_writer( + "graphviz_method_vir_symbex_original", + format!("{}.{}.dot", self.source_filename, procedure.name), + |writer| { + self.execution_trace_builder + .original_view() + .to_graphviz(writer) + .unwrap() + }, + ); + for (i, trace) in self.original_traces.iter().enumerate() { + prusti_common::report::log::report_with_writer( + "vir_symbex_original_traces", + format!("{}.{}.{}.vpr", self.source_filename, procedure.name, i), + |writer| trace.write(writer).unwrap(), + ); + } + prusti_common::report::log::report_with_writer( + "graphviz_method_vir_symbex_optimized", + format!("{}.{}.dot", self.source_filename, procedure.name), + |writer| { + self.execution_trace_builder + .heap_view() + .to_graphviz(writer) + .unwrap() + }, + ); + for (i, trace) in self.final_traces.iter().enumerate() { + prusti_common::report::log::report_with_writer( + "vir_symbex_optimized_traces", + format!("{}.{}.{}.vpr", self.source_filename, procedure.name, i), + |writer| trace.write(writer).unwrap(), + ); + } + } + if config::purify_with_symbolic_execution() { + for (i, trace) in self.final_traces.into_iter().enumerate() { + let new_procedure = trace.into_procedure(i, &procedure); + new_procedures.push(new_procedure); + } + } else { + for (i, trace) in self.original_traces.into_iter().enumerate() { + let new_procedure = trace.into_procedure(i, &procedure); + new_procedures.push(new_procedure); + } + } + Ok(()) + } + + fn next_continuation( + &mut self, + default_position: vir_low::Position, + ) -> VerificationResult> { + while let Some(continuation) = self.continuations.pop() { + debug!("Rolling back to {}", continuation.parent_block_label); + self.execution_trace_builder + .rollback_to_split_point(continuation.parent_block_label)?; + self.assume_condition(continuation.condition, default_position)?; + if self + .execution_trace_builder + .current_egraph_state() + .is_inconsistent()? + { + debug!("Inconsistent after rollback"); + self.execution_trace_builder.remove_last_block()?; + } else { + return Ok(Some(continuation.next_block_label)); + } + } + Ok(None) + } + + fn assume_condition( + &mut self, + condition: vir_low::Expression, + default_position: vir_low::Position, + ) -> VerificationResult<()> { + self.execution_trace_builder + .current_egraph_state() + .assume_heap_independent_conjuncts(&condition)?; + self.execution_trace_builder + .heap_assume(condition.clone(), default_position)?; + self.execution_trace_builder + .add_original_statement(vir_low::Statement::assume(condition, default_position))?; + Ok(()) + } + + fn execute_block( + &mut self, + current_block: &vir_low::Label, + block: &vir_low::BasicBlock, + ) -> VerificationResult<()> { + debug!("Executing block {}", current_block); + for statement in &block.statements { + self.execute_statement(current_block, statement)?; + } + Ok(()) + } + + fn finalize_trace(&mut self) -> VerificationResult<()> { + // let trace = self + // .execution_trace_heap + // .finalize_trace(&self.current_egraph_state, self.program_context)?; + // self.final_traces.push((trace, self.new_variables.take_variables())); + let (original_trace, final_trace) = self + .execution_trace_builder + .finalize_trace(self.program_context)?; + self.original_traces.push(original_trace); + self.final_traces.push(final_trace); + Ok(()) + } + + fn intern_function_arguments( + &mut self, + expression: &vir_low::Expression, + ) -> VerificationResult<()> { + struct Walker<'a> { + egraph: &'a mut EGraphState, + } + impl<'a> ExpressionFallibleWalker for Walker<'a> { + type Error = Error; + fn fallible_walk_func_app_enum( + &mut self, + func_app: &vir_low::expression::FuncApp, + ) -> VerificationResult<()> { + for argument in &func_app.arguments { + if argument.is_heap_independent() { + self.egraph.intern_term(argument)?; + } + } + self.fallible_walk_func_app(func_app) + } + } + let mut walker = Walker { + egraph: self.execution_trace_builder.current_egraph_state(), + }; + walker.fallible_walk_expression(expression) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/program_context.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/program_context.rs new file mode 100644 index 00000000000..153107465ce --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/program_context.rs @@ -0,0 +1,66 @@ +use crate::encoder::middle::core_proof::predicates::PredicateInfo; +use rustc_hash::FxHashMap; +use std::collections::BTreeMap; +use vir_crate::low::{self as vir_low}; + +pub(super) struct ProgramContext<'a> { + // FIXME: Use this: snapshot_functions_to_predicates: BTreeMap, + functions: FxHashMap, + predicates_to_snapshot_types: BTreeMap, + predicate_decls: FxHashMap, +} + +impl<'a> ProgramContext<'a> { + pub(super) fn new( + functions: &'a [vir_low::FunctionDecl], + predicate_decls: &'a [vir_low::PredicateDecl], + predicate_info: BTreeMap, + ) -> Self { + let mut predicates_to_snapshot_types = BTreeMap::new(); + for ( + predicate_name, + PredicateInfo { + snapshot_function_name: _, + snapshot_type, + }, + ) in predicate_info + { + predicates_to_snapshot_types.insert(predicate_name, snapshot_type); + } + Self { + predicates_to_snapshot_types, + functions: functions + .iter() + .map(|function| (function.name.clone(), function)) + .collect(), + predicate_decls: predicate_decls + .iter() + .map(|predicate| (predicate.name.clone(), predicate)) + .collect(), + } + } + + pub(super) fn get_function(&self, name: &str) -> &'a vir_low::FunctionDecl { + self.functions.get(name).unwrap() + } + + pub(super) fn get_snapshot_type(&self, predicate_name: &str) -> Option { + // FIXME: Code duplication with + // prusti-viper/src/encoder/middle/core_proof/transformations/custom_heap_encoding/heap_encoder/predicates.rs + let predicate = self.predicate_decls[predicate_name]; + match predicate.kind { + vir_low::PredicateKind::MemoryBlock => { + use vir_low::macros::*; + Some(ty!(Bytes)) + } + vir_low::PredicateKind::Owned => Some( + self.predicates_to_snapshot_types + .get(predicate_name) + .unwrap_or_else(|| unreachable!("predicate not found: {}", predicate_name)) + .clone(), + ), + vir_low::PredicateKind::WithoutSnapshotWhole + | vir_low::PredicateKind::WithoutSnapshotFrac => None, + } + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/pure.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/pure.rs new file mode 100644 index 00000000000..db649b75546 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/pure.rs @@ -0,0 +1,92 @@ +use super::{ProcedureExecutor, VerificationResult}; +use vir_crate::{ + common::{ + expression::{BinaryOperationHelpers, ExpressionIterator}, + graphviz::ToGraphviz, + position::Positioned, + }, + low::{self as vir_low, expression::visitors::ExpressionFallibleFolder}, +}; + +impl<'a> ProcedureExecutor<'a> { + // pub(super) fn comment(&mut self, comment: String) -> VerificationResult<()> { + // self.execution_trace_pure + // .add_statement(vir_low::Statement::comment(comment)) + // } + + // pub(super) fn assume_pure( + // &mut self, + // assertion: vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // self.assume_expression(&assertion, position)?; + // self.execution_trace_pure + // .add_statement(vir_low::Statement::assume(assertion, position))?; + // if self.current_egraph_state.is_inconsistent()? { + // self.execution_trace_pure + // .add_statement(vir_low::Statement::comment( + // "Reached inconsistent state.".to_string(), + // ))?; + // return Err(Error::InconsistentState); + // } + // Ok(()) + // } + + // pub(super) fn assert_pure( + // &mut self, + // assertion: vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // self.execution_trace_pure + // .add_statement(vir_low::Statement::assert(assertion, position))?; + // Ok(()) + // } + + // fn assume_expression( + // &mut self, + // expression: &vir_low::Expression, + // position: vir_low::Position, + // ) -> VerificationResult<()> { + // match expression { + // vir_low::Expression::Local(expression) => todo!("expression: {}", expression), + // vir_low::Expression::Field(expression) => todo!("expression: {}", expression), + // vir_low::Expression::LabelledOld(expression) => todo!("expression: {}", expression), + // vir_low::Expression::Constant(constant_expression) => { + // assert!(constant_expression.ty.is_bool()); + // self.current_egraph_state.assume(expression)?; + // } + // vir_low::Expression::MagicWand(expression) => todo!("expression: {}", expression), + // vir_low::Expression::PredicateAccessPredicate(expression) => { + // todo!("expression: {}", expression) + // } + // vir_low::Expression::FieldAccessPredicate(expression) => { + // todo!("expression: {}", expression) + // } + // vir_low::Expression::Unfolding(expression) => todo!("expression: {}", expression), + // vir_low::Expression::UnaryOp(expression) => todo!("expression: {}", expression), + // vir_low::Expression::BinaryOp(binary_op_expression) => { + // match binary_op_expression.op_kind { + // vir_low::BinaryOpKind::EqCmp => { + // self.current_egraph_state.assume_equal( + // &binary_op_expression.left, + // &binary_op_expression.right, + // )?; + // } + // vir_low::BinaryOpKind::And => todo!(), + // _ => self.current_egraph_state.assume(expression)?, + // } + // } + // vir_low::Expression::PermBinaryOp(_) => todo!(), + // vir_low::Expression::ContainerOp(_) => todo!(), + // vir_low::Expression::Conditional(_) => todo!(), + // vir_low::Expression::Quantifier(_) => todo!(), + // vir_low::Expression::LetExpr(_) => todo!(), + // vir_low::Expression::FuncApp(_) => todo!(), + // vir_low::Expression::DomainFuncApp(_) => { + // self.current_egraph_state.assume(expression)?; + // } + // vir_low::Expression::InhaleExhale(_) => todo!(), + // } + // Ok(()) + // } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/statements.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/statements.rs new file mode 100644 index 00000000000..658d602a37e --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/statements.rs @@ -0,0 +1,241 @@ +use super::{ProcedureExecutor, VerificationResult}; +use vir_crate::{ + common::expression::BinaryOperationHelpers, + low::{self as vir_low}, +}; + +impl<'a> ProcedureExecutor<'a> { + pub(super) fn execute_statement( + &mut self, + current_block: &vir_low::Label, + statement: &vir_low::Statement, + ) -> VerificationResult<()> { + self.execution_trace_builder + .add_original_statement(statement.clone())?; + match statement { + vir_low::Statement::Comment(statement) => { + self.execute_statement_comment(statement)?; + } + vir_low::Statement::Label(statement) => { + self.execute_statement_label(statement)?; + } + vir_low::Statement::LogEvent(statement) => { + self.execute_statement_log_event(statement)?; + } + vir_low::Statement::Assume(statement) => { + self.execute_statement_assume(statement)?; + } + vir_low::Statement::Assert(statement) => { + self.execute_statement_assert(statement)?; + } + vir_low::Statement::Inhale(statement) => { + self.execute_statement_inhale(statement)?; + } + vir_low::Statement::Exhale(statement) => { + self.execute_statement_exhale(statement)?; + } + vir_low::Statement::Fold(_) => todo!(), + vir_low::Statement::Unfold(_) => todo!(), + vir_low::Statement::ApplyMagicWand(_) => { + unreachable!(); + } + vir_low::Statement::MethodCall(_) => { + unreachable!(); + } + vir_low::Statement::Assign(statement) => { + self.execute_statement_assign(statement)?; + } + vir_low::Statement::Conditional(_) => { + unreachable!(); + } + } + Ok(()) + } + + fn execute_statement_comment( + &mut self, + statement: &vir_low::ast::statement::Comment, + ) -> VerificationResult<()> { + self.execution_trace_builder + .heap_comment(statement.clone())?; + // self.comment(statement.comment.clone())?; + Ok(()) + } + + fn execute_statement_label( + &mut self, + statement: &vir_low::ast::statement::Label, + ) -> VerificationResult<()> { + self.execution_trace_builder.heap_label(statement.clone())?; + Ok(()) + } + + fn execute_statement_log_event( + &mut self, + statement: &vir_low::ast::statement::LogEvent, + ) -> VerificationResult<()> { + // self.assume_pure(statement.expression.clone(), Default::default())?; + self.execution_trace_builder + .current_egraph_state() + .assume_heap_independent_conjuncts(&statement.expression)?; + self.execution_trace_builder + .heap_assume(statement.expression.clone(), statement.position)?; + Ok(()) + } + + fn execute_statement_assume( + &mut self, + statement: &vir_low::ast::statement::Assume, + ) -> VerificationResult<()> { + // self.assume_heap_dependent(statement.expression.clone(), statement.position)?; + self.intern_function_arguments(&statement.expression)?; + self.execution_trace_builder + .current_egraph_state() + .assume_heap_independent_conjuncts(&statement.expression)?; + self.execution_trace_builder + .heap_assume(statement.expression.clone(), statement.position)?; + Ok(()) + } + + fn execute_statement_assert( + &mut self, + statement: &vir_low::ast::statement::Assert, + ) -> VerificationResult<()> { + // self.assert_heap_dependent(statement.expression.clone(), statement.position)?; + self.intern_function_arguments(&statement.expression)?; + self.execution_trace_builder + .heap_assert(statement.expression.clone(), statement.position)?; + Ok(()) + } + + /// Returns true if all arguments are valid terms; that is they are heap + /// independent. + fn check_and_register_terms( + &mut self, + arguments: &[vir_low::Expression], + ) -> VerificationResult { + let mut all_arguments_heap_independent = true; + for argument in arguments { + if argument.is_heap_independent() { + self.execution_trace_builder + .current_egraph_state() + .intern_term(argument)?; + } else { + all_arguments_heap_independent = false; + } + } + Ok(all_arguments_heap_independent) + } + + fn execute_statement_inhale( + &mut self, + statement: &vir_low::ast::statement::Inhale, + ) -> VerificationResult<()> { + // self.inhale_heap_dependent(statement.expression.clone(), statement.position)?; + self.intern_function_arguments(&statement.expression)?; + self.execute_inhale(&statement.expression, statement.position)?; + Ok(()) + } + + fn execute_inhale( + &mut self, + expression: &vir_low::Expression, + position: vir_low::Position, + ) -> VerificationResult<()> { + if let vir_low::Expression::BinaryOp(expression) = expression { + if expression.op_kind == vir_low::BinaryOpKind::And { + self.execute_inhale(&expression.left, position)?; + self.execute_inhale(&expression.right, position)?; + return Ok(()); + } + } + if let vir_low::Expression::PredicateAccessPredicate(predicate) = expression { + if predicate.permission.is_full_permission() + && self.check_and_register_terms(&predicate.arguments)? + { + self.execution_trace_builder + .heap_inhale_predicate(predicate.clone(), position)?; + return Ok(()); + } + } + self.execution_trace_builder + .heap_inhale_generic(expression.clone(), position)?; + Ok(()) + } + + fn execute_statement_exhale( + &mut self, + statement: &vir_low::ast::statement::Exhale, + ) -> VerificationResult<()> { + self.intern_function_arguments(&statement.expression)?; + // self.exhale_heap_dependent(statement.expression.clone(), statement.position)?; + let exhale_label = format!("exhale_label${}", self.exhale_label_generator_counter); + self.exhale_label_generator_counter += 1; + self.execution_trace_builder + .heap_label(vir_low::ast::statement::Label { + label: exhale_label.clone(), + position: statement.position, + })?; + self.execution_trace_builder + .register_label(vir_low::Label::new(&exhale_label))?; + self.execute_exhale(&statement.expression, statement.position, &exhale_label)?; + Ok(()) + } + + fn execute_exhale( + &mut self, + expression: &vir_low::Expression, + position: vir_low::Position, + exhale_label: &str, + ) -> VerificationResult<()> { + if let vir_low::Expression::BinaryOp(expression) = expression { + if expression.op_kind == vir_low::BinaryOpKind::And { + self.execute_exhale(&expression.left, position, exhale_label)?; + self.execute_exhale(&expression.right, position, exhale_label)?; + return Ok(()); + } + } + if let vir_low::Expression::PredicateAccessPredicate(predicate) = expression { + if predicate.permission.is_full_permission() + && self.check_and_register_terms(&predicate.arguments)? + { + self.execution_trace_builder + .current_egraph_state() + .saturate()?; + self.execution_trace_builder + .heap_exhale_predicate(predicate.clone(), position)?; + return Ok(()); + } + } + let expression = expression.clone().wrap_in_old(exhale_label); + self.execution_trace_builder + .heap_exhale_generic(expression, position)?; + Ok(()) + } + + fn execute_statement_assign( + &mut self, + statement: &vir_low::ast::statement::Assign, + ) -> VerificationResult<()> { + assert!( + !statement.position.is_default(), + "{statement} has no position" + ); + assert!(statement.value.is_constant()); + let target_variable = self + .execution_trace_builder + .create_new_bool_variable_version(&statement.target.name)?; + // self.assume_heap_dependent( + // vir_low::Expression::equals(target_variable.into(), statement.value.clone()), + // statement.position, + // )?; + let expression = + vir_low::Expression::equals(target_variable.into(), statement.value.clone()); + self.execution_trace_builder + .current_egraph_state() + .assume_heap_independent_conjuncts(&expression)?; + self.execution_trace_builder + .heap_assume(expression, statement.position)?; + Ok(()) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace.rs new file mode 100644 index 00000000000..fe3bbf7de4a --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace.rs @@ -0,0 +1,44 @@ +use std::collections::BTreeMap; +use vir_crate::low::{self as vir_low}; + +pub(super) struct Trace { + pub statements: Vec, + pub variables: Vec, + pub labels: Vec, +} + +impl Trace { + pub(super) fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { + for statement in &self.statements { + writeln!(writer, "{}", statement)?; + } + Ok(()) + } + + pub(super) fn into_procedure( + self, + trace_index: usize, + original_procedure: &vir_low::ProcedureDecl, + ) -> vir_low::ProcedureDecl { + let entry = vir_low::Label::new("trace_start"); + let exit = vir_low::Label::new("trace_end"); + let entry_block = + vir_low::BasicBlock::new(self.statements, vir_low::Successor::Goto(exit.clone())); + let exit_block = vir_low::BasicBlock::new(Vec::new(), vir_low::Successor::Return); + let mut basic_blocks = BTreeMap::new(); + basic_blocks.insert(entry.clone(), entry_block); + basic_blocks.insert(exit.clone(), exit_block); + let mut locals = original_procedure.locals.clone(); + locals.extend(self.variables); + let mut custom_labels = original_procedure.custom_labels.clone(); + custom_labels.extend(self.labels); + vir_low::ProcedureDecl::new( + format!("{}$trace_{}", original_procedure.name, trace_index), + locals, + custom_labels, + entry, + exit, + basic_blocks, + ) + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/heap_view.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/heap_view.rs new file mode 100644 index 00000000000..82cb056161d --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/heap_view.rs @@ -0,0 +1,53 @@ +use super::{ExecutionTraceBlock, ExecutionTraceBuilder}; +use crate::encoder::middle::core_proof::transformations::symbolic_execution::heap::HeapEntry; +use vir_crate::low::{self as vir_low}; + +pub(in super::super) struct ExecutionTraceHeapView<'a> { + pub(super) trace: &'a ExecutionTraceBuilder, +} + +pub(in super::super) struct BlockView<'a> { + block: &'a ExecutionTraceBlock, +} + +impl<'a> ExecutionTraceHeapView<'a> { + pub(in super::super) fn iter_blocks(&self) -> impl Iterator> { + self.trace.blocks.iter().map(|block| BlockView { block }) + } + + pub(in super::super) fn block_count(&self) -> usize { + self.trace.blocks.len() + } + + pub(in super::super) fn last_block_id(&self) -> usize { + self.trace.blocks.len() - 1 + } + + pub(in super::super) fn last_block_entry_count(&self) -> usize { + self.trace.blocks.last().unwrap().heap_statements.len() + } + + pub(in super::super) fn get_block(&self, id: usize) -> BlockView<'a> { + BlockView { + block: &self.trace.blocks[id], + } + } +} + +impl<'a> BlockView<'a> { + pub(in super::super) fn iter_entries(&self) -> impl Iterator { + self.block.heap_statements.iter() + } + + pub(in super::super) fn parent(&self) -> Option { + self.block.parent + } + + pub(in super::super) fn get_new_variables(&self) -> &[vir_low::VariableDecl] { + &self.block.new_variables + } + + pub(in super::super) fn get_new_labels(&self) -> &[vir_low::Label] { + &self.block.new_labels + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/mod.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/mod.rs new file mode 100644 index 00000000000..58e4dab542f --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/mod.rs @@ -0,0 +1,223 @@ +mod original_view; +mod heap_view; + +use super::{ + egg::EGraphState, + heap::{HeapEntry, HeapState}, + program_context::ProgramContext, + trace::Trace, + VerificationResult, +}; +use rustc_hash::FxHashMap; +use std::collections::BTreeMap; +use vir_crate::low::{self as vir_low}; + +pub(super) use self::{ + heap_view::ExecutionTraceHeapView, original_view::ExecutionTraceOriginalView, +}; + +pub(super) struct ExecutionTraceBuilder { + blocks: Vec, + split_point_parents: BTreeMap, + variable_versions: FxHashMap, +} + +impl ExecutionTraceBuilder { + pub(super) fn new() -> Self { + // let initial_block = ExecutionTraceBlock { + // parent: None, + // original_statements: Vec::new(), + // }; + let initial_block = ExecutionTraceBlock::root(); + Self { + blocks: vec![initial_block], + split_point_parents: Default::default(), + variable_versions: Default::default(), + } + } + + fn current_block(&self) -> &ExecutionTraceBlock { + self.blocks.last().unwrap() + } + + fn current_block_mut(&mut self) -> &mut ExecutionTraceBlock { + self.blocks.last_mut().unwrap() + } + + pub(super) fn current_egraph_state(&mut self) -> &mut EGraphState { + &mut self.current_block_mut().egraph_state + } + + pub(super) fn current_heap_state(&mut self) -> &mut HeapState { + &mut self.current_block_mut().heap_state + } + + pub(super) fn current_heap_and_egraph_state(&self) -> (&HeapState, &EGraphState) { + let current_block = self.current_block(); + (¤t_block.heap_state, ¤t_block.egraph_state) + } + + pub(super) fn current_heap_and_egraph_state_mut( + &mut self, + ) -> (&mut HeapState, &mut EGraphState) { + let current_block = self.current_block_mut(); + ( + &mut current_block.heap_state, + &mut current_block.egraph_state, + ) + } + + pub(super) fn add_original_statement( + &mut self, + statement: vir_low::Statement, + ) -> VerificationResult<()> { + let current_block = self.current_block_mut(); + current_block.original_statements.push(statement); + Ok(()) + } + + pub(super) fn add_heap_entry(&mut self, entry: HeapEntry) -> VerificationResult<()> { + let current_block = self.current_block_mut(); + current_block.heap_statements.push(entry); + Ok(()) + } + + pub(super) fn add_split_point( + &mut self, + parent_block_label: vir_low::Label, + ) -> VerificationResult<()> { + let parent_id = self.blocks.len() - 1; + let new_block = ExecutionTraceBlock::from_parent(parent_id, self.current_block()); + self.blocks.push(new_block); + self.split_point_parents + .insert(parent_block_label, parent_id); + Ok(()) + } + + pub(super) fn rollback_to_split_point( + &mut self, + split_point_label: vir_low::Label, + ) -> VerificationResult<()> { + let parent_id = self.split_point_parents[&split_point_label]; + let parent = &self.blocks[parent_id]; + let new_block = ExecutionTraceBlock::from_parent(parent_id, parent); + self.blocks.push(new_block); + Ok(()) + } + + pub(super) fn original_view(&self) -> ExecutionTraceOriginalView { + ExecutionTraceOriginalView { trace: self } + } + + pub(super) fn heap_view(&self) -> ExecutionTraceHeapView { + ExecutionTraceHeapView { trace: self } + } + + pub(super) fn create_new_bool_variable_version( + &mut self, + variable_name: &str, + ) -> VerificationResult { + let version = self + .variable_versions + .entry(variable_name.to_string()) + .or_default(); + *version += 1; + let version = *version; + let variable = vir_low::VariableDecl::new( + format!("{}${}", variable_name, version), + vir_low::Type::Bool, + ); + self.current_block_mut() + .new_variables + .push(variable.clone()); + Ok(variable) + } + + pub(super) fn register_label(&mut self, label: vir_low::Label) -> VerificationResult<()> { + self.current_block_mut().new_labels.push(label); + Ok(()) + } + + pub(super) fn finalize_trace( + &self, + program: &ProgramContext, + ) -> VerificationResult<(Trace, Trace)> { + let mut original_trace = Trace { + statements: Vec::new(), + variables: Vec::new(), + labels: Vec::new(), + }; + self.finalize_original_trace_for_block(&mut original_trace, self.blocks.len() - 1)?; + let final_trace = self.heap_finalize_trace(program)?; + Ok((original_trace, final_trace)) + } + + fn finalize_original_trace_for_block( + &self, + trace: &mut Trace, + block_id: usize, + ) -> VerificationResult<()> { + let block = &self.blocks[block_id]; + if let Some(parent_id) = block.parent { + self.finalize_original_trace_for_block(trace, parent_id)?; + } + for statement in &block.original_statements { + trace.statements.push(statement.clone()); + } + Ok(()) + } + + /// Removes the last block from the trace. This method should be used only + /// when the last method is a freshly added unreachable branch. + pub(super) fn remove_last_block(&mut self) -> VerificationResult<()> { + let last_block = self.blocks.pop().unwrap(); + assert_eq!(last_block.original_statements.len(), 1); + assert_eq!(last_block.heap_statements.len(), 1); + Ok(()) + } +} + +struct ExecutionTraceBlock { + /// The parent of this block. The root does not have a parent. + parent: Option, + /// New variables declared while executing the trace. + new_variables: Vec, + /// New labels declared while executing the trace. + new_labels: Vec, + /// Original statements that were executed in the trace. + original_statements: Vec, + /// Statements that make the heap operations more explicit. + heap_statements: Vec, + /// The last heap state. If the block is fully executed, it is the state + /// after the last statement. + heap_state: HeapState, + /// The last e-graph state. If the block is fully executed, it is the state + /// after the last statement. + egraph_state: EGraphState, +} + +impl ExecutionTraceBlock { + fn root() -> Self { + Self { + parent: None, + new_variables: Vec::new(), + new_labels: Vec::new(), + original_statements: Vec::new(), + heap_statements: Vec::new(), + heap_state: HeapState::default(), + egraph_state: EGraphState::new(), + } + } + + fn from_parent(parent_id: usize, parent: &Self) -> Self { + Self { + parent: Some(parent_id), + new_variables: Vec::new(), + new_labels: Vec::new(), + original_statements: Vec::new(), + heap_statements: Vec::new(), + heap_state: parent.heap_state.clone(), + egraph_state: parent.egraph_state.clone(), + } + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/original_view.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/original_view.rs new file mode 100644 index 00000000000..0d7a79c4e78 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/trace_builder/original_view.rs @@ -0,0 +1,32 @@ +use super::ExecutionTraceBuilder; +use vir_crate::{ + common::graphviz::{escape_html, Graph, ToGraphviz}, + low::{self as vir_low}, +}; + +pub(in super::super) struct ExecutionTraceOriginalView<'a> { + pub(super) trace: &'a ExecutionTraceBuilder, +} + +impl<'a> ToGraphviz for ExecutionTraceOriginalView<'a> { + fn to_graph(&self) -> Graph { + let mut graph = Graph::with_columns(&["statement"]); + for (block_id, block) in self.trace.blocks.iter().enumerate() { + let mut node_builder = graph.create_node(format!("block{}", block_id)); + for statement in &block.original_statements { + let statement_string = match statement { + vir_low::Statement::Comment(statement) => { + format!("{}", escape_html(statement)) + } + _ => escape_html(statement.to_string()), + }; + node_builder.add_row_sequence(vec![statement_string]); + } + node_builder.build(); + if let Some(parent) = block.parent { + graph.add_regular_edge(format!("block{}", parent), format!("block{}", block_id)); + } + } + graph + } +} diff --git a/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/variable_declarations.rs b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/variable_declarations.rs new file mode 100644 index 00000000000..ac9499a45b6 --- /dev/null +++ b/prusti-viper/src/encoder/middle/core_proof/transformations/symbolic_execution/variable_declarations.rs @@ -0,0 +1,39 @@ +// use super::VerificationResult; +// use rustc_hash::{FxHashMap, FxHashSet}; +// use vir_crate::low::{self as vir_low}; + +// #[derive(Default)] +// pub(super) struct VariableDeclarations { +// variable_versions: FxHashMap, +// variables: FxHashSet, +// } + +// impl VariableDeclarations { +// fn create_variable( +// &mut self, +// variable_name: &str, +// ty: vir_low::Type, +// version: u64, +// ) -> VerificationResult { +// let variable = vir_low::VariableDecl::new(format!("{}${}", variable_name, version), ty); +// self.variables.insert(variable.clone()); +// Ok(variable) +// } + +// pub(super) fn create_new_variable_version( +// &mut self, +// variable_name: &str, +// ) -> VerificationResult { +// let version = self +// .variable_versions +// .entry(variable_name.to_string()) +// .or_default(); +// *version += 1; +// let version = *version; +// self.create_variable(variable_name, vir_low::Type::Bool, version) +// } + +// pub(super) fn take_variables(&mut self) -> Vec { +// std::mem::take(&mut self.variables).into_iter().collect() +// } +// } diff --git a/prusti-viper/src/encoder/middle/core_proof/type_layouts/interface.rs b/prusti-viper/src/encoder/middle/core_proof/type_layouts/interface.rs index 90c7c1bac67..ab15a5d37aa 100644 --- a/prusti-viper/src/encoder/middle/core_proof/type_layouts/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/type_layouts/interface.rs @@ -3,7 +3,9 @@ use crate::encoder::{ high::{type_layouts::HighTypeLayoutsEncoderInterface, types::HighTypeEncoderInterface}, middle::core_proof::{ lowerer::Lowerer, - snapshots::{IntoBuiltinMethodSnapshot, IntoProcedureSnapshot, IntoSnapshot}, + snapshots::{ + IntoBuiltinMethodSnapshot, IntoProcedureSnapshot, IntoSnapshot, SnapshotValuesInterface, + }, }, }; use vir_crate::{ @@ -23,6 +25,14 @@ pub(in super::super) trait TypeLayoutsInterface { ty: &vir_mid::Type, generics: &impl WithConstArguments, ) -> SpannedEncodingResult; + /// The size multiplied by `repetitions`. + fn encode_type_size_expression_repetitions( + &mut self, + ty: &vir_mid::Type, + generics: &impl WithConstArguments, + repetitions: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult; fn encode_type_padding_size_expression( &mut self, ty: &vir_mid::Type, @@ -57,6 +67,24 @@ impl<'p, 'v: 'p, 'tcx: 'v> TypeLayoutsInterface for Lowerer<'p, 'v, 'tcx> { ); size.to_builtin_method_snapshot(self) } + fn encode_type_size_expression_repetitions( + &mut self, + ty: &vir_mid::Type, + generics: &impl WithConstArguments, + repetitions: vir_low::Expression, + position: vir_low::Position, + ) -> SpannedEncodingResult { + let size = self.encode_type_size_expression2(ty, generics)?; + let size_type = self.size_type_mid()?; + self.construct_binary_op_snapshot( + vir_mid::BinaryOpKind::Mul, + &size_type, + &size_type, + repetitions, + size, + position, + ) + } fn encode_type_padding_size_expression( &mut self, ty: &vir_mid::Type, diff --git a/prusti-viper/src/encoder/middle/core_proof/types/interface.rs b/prusti-viper/src/encoder/middle/core_proof/types/interface.rs index aa2b046f5dd..01db9906d39 100644 --- a/prusti-viper/src/encoder/middle/core_proof/types/interface.rs +++ b/prusti-viper/src/encoder/middle/core_proof/types/interface.rs @@ -15,11 +15,14 @@ use crate::encoder::{ high::types::HighTypeEncoderInterface, middle::core_proof::{ addresses::AddressesInterface, + footprint::FootprintInterface, lowerer::{DomainsLowererInterface, Lowerer}, snapshots::{ - IntoPureSnapshot, IntoSnapshot, SnapshotAdtsInterface, SnapshotDomainsInterface, - SnapshotValidityInterface, SnapshotValuesInterface, + IntoFramedPureSnapshot, IntoPureBoolExpression, IntoPureSnapshot, IntoSnapshot, + IntoSnapshotLowerer, SnapshotAdtsInterface, SnapshotDomainsInterface, + SnapshotValidityInterface, SnapshotValuesInterface, ValidityAssertionToSnapshot, }, + type_layouts::TypeLayoutsInterface, }, }; use prusti_common::config; @@ -31,7 +34,7 @@ use vir_crate::{ identifier::WithIdentifier, }, low::{self as vir_low}, - middle as vir_mid, + middle::{self as vir_mid}, }; #[derive(Default)] @@ -63,6 +66,11 @@ trait Private { parameters: Vec, evaluation_result: vir_low::Expression, ) -> SpannedEncodingResult<()>; + // fn purify_structural_invariant( + // &mut self, + // structural_invariant: Vec, + // field_count: usize, + // ) -> SpannedEncodingResult>; } impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { @@ -130,8 +138,39 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { field.ty.to_snapshot(self)?, )); } + let parameters_with_validity = decl.fields.len(); + let invariant = if let Some(invariant) = &decl.structural_invariant { + let deref_fields = self.structural_invariant_to_deref_fields(invariant)?; + for (_, name, ty) in &deref_fields { + parameters.push(vir_low::VariableDecl::new(name, ty.clone())); + } + let mut validity_assertion_encoder = + ValidityAssertionToSnapshot::new(deref_fields); + // let invariant = self.structural_invariant_to_pure_expression( + // invariant.clone(), + // ty, + // decl, + // &mut parameters, + // )?; + let mut conjuncts = Vec::new(); + for expression in invariant { + conjuncts.push( + validity_assertion_encoder + .expression_to_snapshot(self, expression, true)?, + ); + // conjuncts.push(expression.to_pure_bool_expression(self)?); + } + conjuncts.into_iter().conjoin() //.remove_acc_predicates() + } else { + true.into() + }; self.register_struct_constructor(&domain_name, parameters.clone())?; - self.encode_validity_axioms_struct(&domain_name, parameters, true.into())?; + self.encode_validity_axioms_struct_with_invariant( + &domain_name, + parameters, + parameters_with_validity, + invariant, + )?; } vir_mid::TypeDecl::Enum(decl) => { let mut variants = Vec::new(); @@ -159,27 +198,42 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { vir_mid::TypeDecl::Pointer(decl) => { self.ensure_type_definition(&decl.target_type)?; let address_type = self.address_type()?; - self.register_constant_constructor(&domain_name, address_type.clone())?; - self.encode_validity_axioms_primitive(&domain_name, address_type, true.into())?; + let mut parameters = vec![vir_low::VariableDecl::new("address", address_type)]; + if decl.target_type.is_slice() { + let len_type = self.size_type()?; + parameters.push(vir_low::VariableDecl::new("len", len_type)); + } + self.register_struct_constructor(&domain_name, parameters.clone())?; + // self.register_constant_constructor(&domain_name, address_type.clone())?; + // self.encode_validity_axioms_primitive(&domain_name, address_type, true.into())?; + self.encode_validity_axioms_struct(&domain_name, parameters)?; } - vir_mid::TypeDecl::Reference(reference) => { - self.ensure_type_definition(&reference.target_type)?; - let target_type = reference.target_type.to_snapshot(self)?; - if reference.uniqueness.is_unique() { - let parameters = vars! { + vir_mid::TypeDecl::Reference(decl) => { + self.ensure_type_definition(&decl.target_type)?; + let target_type = decl.target_type.to_snapshot(self)?; + if decl.uniqueness.is_unique() { + let mut parameters = vars! { address: Address, target_current: {target_type.clone()}, target_final: {target_type} }; + if decl.target_type.is_slice() { + let len_type = self.size_type()?; + parameters.push(vir_low::VariableDecl::new("len", len_type)); + } self.register_struct_constructor(&domain_name, parameters.clone())?; - self.encode_validity_axioms_struct(&domain_name, parameters, true.into())?; + self.encode_validity_axioms_struct(&domain_name, parameters)?; } else { - let parameters = vars! { + let mut parameters = vars! { address: Address, target_current: {target_type.clone()} }; + if decl.target_type.is_slice() { + let len_type = self.size_type()?; + parameters.push(vir_low::VariableDecl::new("len", len_type)); + } self.register_struct_constructor(&domain_name, parameters.clone())?; - self.encode_validity_axioms_struct(&domain_name, parameters, true.into())?; + self.encode_validity_axioms_struct(&domain_name, parameters)?; let no_alloc_parameters = vars! { target_current: {target_type} }; self.register_alternative_constructor_with_injectivity_axioms( &domain_name, @@ -187,17 +241,24 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { true, no_alloc_parameters.clone(), )?; + let parameters_with_validity = no_alloc_parameters.len(); self.encode_validity_axioms_struct_alternative_constructor( &domain_name, "no_alloc", no_alloc_parameters, + parameters_with_validity, true.into(), )?; } } vir_mid::TypeDecl::Never => { self.register_struct_constructor(&domain_name, Vec::new())?; - self.encode_validity_axioms_struct(&domain_name, Vec::new(), false.into())?; + self.encode_validity_axioms_struct_with_invariant( + &domain_name, + Vec::new(), + 0, + false.into(), + )?; } _ => unimplemented!("type: {:?}", type_decl), }; @@ -283,6 +344,87 @@ impl<'p, 'v: 'p, 'tcx: 'v> Private for Lowerer<'p, 'v, 'tcx> { } Ok(()) } + + // fn purify_structural_invariant( + // &mut self, + // structural_invariant: Vec, + // field_count: usize, + // ) -> SpannedEncodingResult> { + + // // TODO: Create deref fields in vir_high together with a required + // // structural invariant that links their values? Probably does not work + // // because I need different treatment in predicate and snapshot + // // encoders. + + // // TODO: Maybe a better idea would be to have code that computes a + // // footprint of an expression? Then I could also use it for pure + // // functions. + + // struct Purifier<'l, 'p, 'v, 'tcx> { + // lowerer: &'l mut Lowerer<'p, 'v, 'tcx>, + // field_count: usize, + // } + // impl<'l, 'p, 'v, 'tcx> vir_mid::visitors::ExpressionFolder for Purifier<'l, 'p, 'v, 'tcx> { + // fn fold_acc_predicate_enum( + // &mut self, + // acc_predicate: vir_mid::AccPredicate, + // ) -> vir_mid::Expression { + // match *acc_predicate.predicate { + // vir_mid::Predicate::LifetimeToken(_) => { + // unimplemented!() + // } + // vir_mid::Predicate::MemoryBlockStack(_) + // | vir_mid::Predicate::MemoryBlockStackDrop(_) + // | vir_mid::Predicate::MemoryBlockHeap(_) + // | vir_mid::Predicate::MemoryBlockHeapDrop(_) => true.into(), + // vir_mid::Predicate::OwnedNonAliased(predicate) => { + // match predicate.place { + // vir_mid::Expression::Deref(vir_mid::Deref { + // base: + // box vir_mid::Expression::Field(vir_mid::Field { + // box base, + // field, + // .. + // }), + // ty, + // position, + // }) => { + // // let parameter = vir_mid::VariableDecl::new( + // // format!("{}$deref", field.name), + // // ty, + // // ); + // let app = vir_mid::Expression::builtin_func_app( + // vir_mid::BuiltinFunc::IsValid, + // Vec::new(), + // vec![ + // vir_mid::Expression::field( + // base, + // vir_mid::FieldDecl { + // name: format!("{}$deref", field.name), + // index: self.field_count, + // ty, + // }, + // position, + // )], + // vir_mid::Type::Bool, + // position, + // ); + // self.field_count += 1; + // app + // // self.lowerer.encode_snapshot_valid_call_for_type(parameter.into(), ty)? + // } + // _ => unimplemented!(), + // } + // } + // } + // } + // } + // let mut purifier = Purifier { lowerer: self, field_count }; + // Ok(structural_invariant + // .into_iter() + // .map(|expression| purifier.fold_expression(expression)) + // .collect()) + // } } pub(in super::super) trait TypesInterface { diff --git a/prusti-viper/src/encoder/middle/core_proof/utils/place_domain_encoder.rs b/prusti-viper/src/encoder/middle/core_proof/utils/place_domain_encoder.rs index d5d85b9b1ec..76e9fcc19df 100644 --- a/prusti-viper/src/encoder/middle/core_proof/utils/place_domain_encoder.rs +++ b/prusti-viper/src/encoder/middle/core_proof/utils/place_domain_encoder.rs @@ -23,6 +23,11 @@ pub(in super::super) trait PlaceExpressionDomainEncoder { lowerer: &mut Lowerer, arg: vir_low::Expression, ) -> SpannedEncodingResult; + fn encode_labelled_old( + &mut self, + expression: &vir_mid::expression::LabelledOld, + lowerer: &mut Lowerer, + ) -> SpannedEncodingResult; fn encode_array_index_axioms( &mut self, base_type: &vir_mid::Type, @@ -82,6 +87,9 @@ pub(in super::super) trait PlaceExpressionDomainEncoder { *position, )? } + vir_mid::Expression::LabelledOld(expression) => { + self.encode_labelled_old(expression, lowerer)? + } x => unimplemented!("{}", x), }; Ok(result) diff --git a/prusti-viper/src/encoder/mir/contracts/contracts.rs b/prusti-viper/src/encoder/mir/contracts/contracts.rs index 55477369c19..9b80f332577 100644 --- a/prusti-viper/src/encoder/mir/contracts/contracts.rs +++ b/prusti-viper/src/encoder/mir/contracts/contracts.rs @@ -109,7 +109,7 @@ impl ProcedureContractGeneric { &'a self, env: &'a Environment<'tcx>, substs: SubstsRef<'tcx>, - ) -> Option<(LocalDefId, SubstsRef<'tcx>)> { + ) -> Option<(DefId, SubstsRef<'tcx>)> { match self.specification.terminates { typed::SpecificationItem::Empty => None, typed::SpecificationItem::Inherent(t) | typed::SpecificationItem::Refined(_, t) => { diff --git a/prusti-viper/src/encoder/mir/errors/interface.rs b/prusti-viper/src/encoder/mir/errors/interface.rs index 180c08dacaf..0c4373f0cc7 100644 --- a/prusti-viper/src/encoder/mir/errors/interface.rs +++ b/prusti-viper/src/encoder/mir/errors/interface.rs @@ -25,12 +25,12 @@ pub(crate) trait ErrorInterface { error_ctxt: ErrorCtxt, ) -> vir_high::Position; fn set_surrounding_error_context( - &mut self, + &self, position: vir_high::Position, error_ctxt: ErrorCtxt, ) -> vir_high::Position; fn set_surrounding_error_context_for_expression( - &mut self, + &self, expression: vir_high::Expression, default_position: vir_high::Position, error_ctxt: ErrorCtxt, @@ -91,7 +91,7 @@ impl<'v, 'tcx: 'v> ErrorInterface for super::super::super::Encoder<'v, 'tcx> { new_position.into() } fn set_surrounding_error_context( - &mut self, + &self, position: vir_high::Position, error_ctxt: ErrorCtxt, ) -> vir_high::Position { @@ -104,14 +104,14 @@ impl<'v, 'tcx: 'v> ErrorInterface for super::super::super::Encoder<'v, 'tcx> { /// 1. `default_position` if `position.is_default()`. /// 2. With surrounding error context otherwise. fn set_surrounding_error_context_for_expression( - &mut self, + &self, expression: vir_high::Expression, default_position: vir_high::Position, error_ctxt: ErrorCtxt, ) -> vir_high::Expression { assert!(!default_position.is_default()); struct Visitor<'p, 'v: 'p, 'tcx: 'v> { - encoder: &'p mut super::super::super::Encoder<'v, 'tcx>, + encoder: &'p super::super::super::Encoder<'v, 'tcx>, default_position: vir_high::Position, error_ctxt: ErrorCtxt, } diff --git a/prusti-viper/src/encoder/mir/places/interface.rs b/prusti-viper/src/encoder/mir/places/interface.rs index 3690cf45413..2b5dc3158ac 100644 --- a/prusti-viper/src/encoder/mir/places/interface.rs +++ b/prusti-viper/src/encoder/mir/places/interface.rs @@ -188,7 +188,9 @@ impl<'v, 'tcx: 'v> PlacesEncoderInterface<'tcx> for super::super::super::Encoder .with_span(declaration_span)?; if parent_type.is_union() { // We treat union fields as variants. - let union_decl = self.encode_type_def_high(&parent_type)?.unwrap_union(); + let union_decl = self + .encode_type_def_high(&parent_type, false)? + .unwrap_union(); let variant = &union_decl.variants[field.index()]; let variant_index: vir_high::ty::VariantIndex = variant.name.clone().into(); let variant_type = parent_type.variant(variant_index.clone()); diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/builtin_function_encoder.rs b/prusti-viper/src/encoder/mir/procedures/encoder/builtin_function_encoder.rs index adc3be26a3b..9f1066376ab 100644 --- a/prusti-viper/src/encoder/mir/procedures/encoder/builtin_function_encoder.rs +++ b/prusti-viper/src/encoder/mir/procedures/encoder/builtin_function_encoder.rs @@ -78,12 +78,12 @@ impl<'p, 'v, 'tcx> BuiltinFuncAppEncoder<'p, 'v, 'tcx> for super::ProcedureEncod size, ); block_builder.add_statement(encoder.encoder.set_statement_error_ctxt( - vir_high::Statement::exhale_no_pos(target_memory_block), + vir_high::Statement::exhale_predicate_no_pos(target_memory_block), span, ErrorCtxt::ProcedureCall, encoder.def_id, )?); - let inhale_statement = vir_high::Statement::inhale_no_pos( + let inhale_statement = vir_high::Statement::inhale_predicate_no_pos( vir_high::Predicate::owned_non_aliased_no_pos(encoded_target_place.clone()), ); block_builder.add_statement(encoder.encoder.set_statement_error_ctxt( @@ -199,6 +199,9 @@ impl<'p, 'v, 'tcx> BuiltinFuncAppEncoder<'p, 'v, 'tcx> for super::ProcedureEncod unimplemented!(); } } + "prusti_contracts::prusti_take_lifetime" => { + make_builtin_call(self, block_builder, vir_high::BuiltinFunc::TakeLifetime)? + } "prusti_contracts::Int::new" => { make_builtin_call(self, block_builder, vir_high::BuiltinFunc::NewInt)? } diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/check_mode_converters.rs b/prusti-viper/src/encoder/mir/procedures/encoder/check_mode_converters.rs new file mode 100644 index 00000000000..e665e7299e7 --- /dev/null +++ b/prusti-viper/src/encoder/mir/procedures/encoder/check_mode_converters.rs @@ -0,0 +1,228 @@ +use super::ProcedureEncoder; +use crate::encoder::{ + errors::{SpannedEncodingError, SpannedEncodingResult}, + mir::{specifications::SpecificationsInterface, types::MirTypeEncoderInterface}, + Encoder, +}; + +use vir_crate::{ + common::{ + check_mode::CheckMode, + expression::{BinaryOperationHelpers, ExpressionIterator}, + position::Positioned, + }, + high::{self as vir_high, operations::ty::Typed}, +}; + +impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { + /// Convert expression to the one usable for the current check mode: + /// + /// * For `Both` and `Specifications`: keep the expression unchanged. + /// * For `CoreProof` keep only the raw pointer dereferences because we need + /// to check that they are framed. + /// + /// If `disallow_permissions` is true, then checks that the expression does + /// not contain accesibility predicates. + pub(super) fn convert_expression_to_check_mode( + &mut self, + expression: vir_high::Expression, + disallow_permissions: bool, + framing_variables: &[vir_high::VariableDecl], + ) -> SpannedEncodingResult> { + if disallow_permissions && !expression.is_pure() { + let span = self + .encoder + .error_manager() + .position_manager() + .get_span(expression.position().into()) + .cloned() + .unwrap(); + return Err(SpannedEncodingError::incorrect( + "only unsafe functions can use permissions in their contracts", + span, + )); + } + match self.check_mode { + CheckMode::MemorySafety => { + // Unsafe functions are checked with `CheckMode::UnsafeSafety`. For all + // other functions it is forbidden to have accessibility + // predicates in their contracts. + assert!(disallow_permissions); + // Framing will be checked with `CheckMode::MemorySafetyWithFunctional`. + Ok(None) + } + CheckMode::MemorySafetyWithFunctional => { + // Unsafe functions are checked with `CheckMode::UnsafeSafety`. For all + // other functions it is forbidden to have accessibility + // predicates in their contracts. + assert!(disallow_permissions); + // Framing is checked automatically by the encoding. + Ok(Some(expression)) + } + CheckMode::PurificationFunctional => { + // Unsafe functions are checked with `CheckMode::UnsafeSafety`. For all + // other functions it is forbidden to have accessibility + // predicates in their contracts. + assert!(disallow_permissions); + Ok(Some(expression)) + } + CheckMode::PurificationSoudness => { + // Check comment for `CheckMode::PurificationFunctional`. + assert!(disallow_permissions); + // Even though we forbid accessibility predicates in safe + // functions, we may still have raw pointers in specifications + // that are framed by type invariants. + let dereferenced_places = expression.collect_guarded_dereferenced_places(); + if dereferenced_places.is_empty() { + Ok(None) + } else { + let framing_places: Vec = framing_variables + .iter() + .map(|var| var.clone().into()) + .collect(); + let check = construct_framing_assertion( + self.encoder, + dereferenced_places, + &framing_places, + )?; + Ok(Some(check)) + } + } + CheckMode::UnsafeSafety => { + // Framing is checked automatically by the encoding. + Ok(Some(expression)) + } + } + } + + pub(super) fn convert_expression_to_check_mode_call_site( + &mut self, + expression: vir_high::Expression, + is_unsafe: bool, + framing_arguments: &[vir_high::Expression], + ) -> SpannedEncodingResult> { + match self.check_mode { + CheckMode::MemorySafety => { + if is_unsafe { + // We are calling an unsafe function from a safe one. + Ok(Some(expression)) + } else { + Ok(None) + } + } + CheckMode::MemorySafetyWithFunctional + | CheckMode::PurificationFunctional + | CheckMode::UnsafeSafety => Ok(Some(expression)), + CheckMode::PurificationSoudness => { + let dereferenced_places = expression.collect_guarded_dereferenced_places(); + let check = if dereferenced_places.is_empty() { + if is_unsafe { + Some(expression) + } else { + None + } + } else { + let check = construct_framing_assertion( + self.encoder, + dereferenced_places, + framing_arguments, + )?; + if is_unsafe { + Some(vir_high::Expression::and(expression, check)) + } else { + Some(check) + } + }; + Ok(check) + } + } + } +} + +fn construct_framing_assertion( + encoder: &mut Encoder, + dereferenced_places: Vec<(vir_high::Expression, vir_high::Expression)>, + framing_places: &[vir_high::Expression], +) -> SpannedEncodingResult { + let type_invariant_framing_places = + construct_type_invariant_framing_places(encoder, framing_places)?; + let mut type_invariant_framed_places = Vec::new(); + for (guard, place) in dereferenced_places { + if is_framed(&place, &type_invariant_framing_places) { + let function = vir_high::Expression::builtin_func_app( + vir_high::BuiltinFunc::EnsureOwnedPredicate, + Vec::new(), + vec![place.clone()], + vir_high::Type::Bool, + place.position(), + ); + let check = vir_high::Expression::implies(guard, function); + type_invariant_framed_places.push(check); + } else { + let span = encoder + .error_manager() + .position_manager() + .get_span(place.position().into()) + .cloned() + .unwrap(); + return Err(SpannedEncodingError::incorrect( + "the place must be framed by permissions", + span, + )); + } + } + Ok(type_invariant_framed_places.into_iter().conjoin()) +} + +fn construct_type_invariant_framing_places( + encoder: &mut Encoder, + framing_places: &[vir_high::Expression], +) -> SpannedEncodingResult> { + let mut type_invariant_framing_places = Vec::new(); + for framing_place in framing_places { + if framing_place.get_type().is_struct() { + let type_decl = encoder + .encode_type_def_high(framing_place.get_type(), true)? + .unwrap_struct(); + if let Some(invariants) = type_decl.structural_invariant { + for expression in invariants { + let expression = expression.replace_self(framing_place); + type_invariant_framing_places.extend(expression.collect_owned_places()); + } + } + } + } + Ok(type_invariant_framing_places) +} + +fn is_framed( + place: &vir_high::Expression, + type_invariant_framing_places: &[vir_high::Expression], +) -> bool { + for framing_place in type_invariant_framing_places { + if is_framed_rec(framing_place, place, type_invariant_framing_places) { + return true; + } + } + false +} + +fn is_framed_rec( + framing_place: &vir_high::Expression, + place: &vir_high::Expression, + type_invariant_framing_places: &[vir_high::Expression], +) -> bool { + if framing_place == place { + if let Some(pointer_place) = place.get_last_dereferenced_pointer() { + is_framed(pointer_place, type_invariant_framing_places) + } else { + true + } + } else if place.is_deref() { + false + } else if let Some(parent) = place.get_parent_ref() { + is_framed_rec(framing_place, parent, type_invariant_framing_places) + } else { + true + } +} diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/lifetimes.rs b/prusti-viper/src/encoder/mir/procedures/encoder/lifetimes.rs index d989380f42d..43f8b98bd2d 100644 --- a/prusti-viper/src/encoder/mir/procedures/encoder/lifetimes.rs +++ b/prusti-viper/src/encoder/mir/procedures/encoder/lifetimes.rs @@ -1100,10 +1100,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> LifetimesEncoder<'tcx> for ProcedureEncoder<'p, 'v, ' permission_amount: vir_high::Expression, ) -> SpannedEncodingResult { self.encoder.set_statement_error_ctxt( - vir_high::Statement::inhale_no_pos(vir_high::Predicate::lifetime_token_no_pos( - lifetime_const, - permission_amount, - )), + vir_high::Statement::inhale_predicate_no_pos( + vir_high::Predicate::lifetime_token_no_pos(lifetime_const, permission_amount), + ), self.mir.span, ErrorCtxt::LifetimeInhale, self.def_id, @@ -1136,10 +1135,9 @@ impl<'p, 'v: 'p, 'tcx: 'v> LifetimesEncoder<'tcx> for ProcedureEncoder<'p, 'v, ' permission_amount: vir_high::Expression, ) -> SpannedEncodingResult { self.encoder.set_statement_error_ctxt( - vir_high::Statement::exhale_no_pos(vir_high::Predicate::lifetime_token_no_pos( - lifetime_const, - permission_amount, - )), + vir_high::Statement::exhale_predicate_no_pos( + vir_high::Predicate::lifetime_token_no_pos(lifetime_const, permission_amount), + ), self.mir.span, ErrorCtxt::LifetimeExhale, self.def_id, diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs b/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs index 6d1bd330662..b84e2d20dd8 100644 --- a/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs +++ b/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs @@ -18,6 +18,7 @@ use crate::encoder::{ spans::SpanInterface, specifications::SpecificationsInterface, type_layouts::MirTypeLayoutsEncoderInterface, + types::MirTypeEncoderInterface, }, mir_encoder::PRECONDITION_LABEL, Encoder, @@ -57,6 +58,7 @@ use vir_crate::{ }; mod builtin_function_encoder; +mod check_mode_converters; mod elaborate_drops; mod ghost; mod initialisation; @@ -66,20 +68,28 @@ mod scc; pub mod specification_blocks; mod termination; +#[derive(Debug)] +pub(super) enum ProcedureEncodingKind { + Regular, + PostconditionFrameCheck, +} + pub(super) fn encode_procedure<'v, 'tcx: 'v>( encoder: &mut Encoder<'v, 'tcx>, def_id: DefId, check_mode: CheckMode, + encoding_kind: ProcedureEncodingKind, ) -> SpannedEncodingResult { let procedure = Procedure::new(encoder.env(), def_id); let tcx = encoder.env().tcx(); let (mir, lifetimes) = self::elaborate_drops::elaborate_drops(encoder, def_id, &procedure)?; let mir = &mir; // Mark body as immutable. + let is_unsafe_function = encoder.env().query.is_unsafe_function(def_id); let move_env = self::initialisation::create_move_data_param_env(tcx, mir, def_id); let init_data = InitializationData::new(tcx, mir, &move_env); let locals_without_explicit_allocation: BTreeSet<_> = mir.vars_and_temps_iter().collect(); let specification_blocks = - SpecificationBlocks::build(encoder.env().query, mir, &procedure, true); + SpecificationBlocks::build(encoder.env().query, mir, Some(&procedure), true); let initialization = compute_definitely_initialized(def_id, mir, encoder.env().tcx()); let allocation = compute_definitely_allocated(def_id, mir); let lifetime_count = lifetimes.lifetime_count(); @@ -95,6 +105,7 @@ pub(super) fn encode_procedure<'v, 'tcx: 'v>( encoder, def_id, check_mode, + is_unsafe_function, procedure: &procedure, mir, init_data, @@ -105,7 +116,7 @@ pub(super) fn encode_procedure<'v, 'tcx: 'v>( specification_blocks, specification_block_encoding: Default::default(), loop_invariant_encoding: Default::default(), - check_panics: config::check_panics() && check_mode != CheckMode::CoreProof, + check_panics: config::check_panics() && check_mode.check_specifications(), locals_without_explicit_allocation, used_locals: Default::default(), fresh_id_generator: 0, @@ -118,6 +129,12 @@ pub(super) fn encode_procedure<'v, 'tcx: 'v>( reborrow_lifetimes_to_remove_for_block, current_basic_block, termination_variable: None, + encoding_kind, + opened_reference_place_permissions: Default::default(), + opened_reference_witnesses: Default::default(), + user_named_lifetimes: Default::default(), + manually_managed_places: Default::default(), + stashed_ranges: Default::default(), }; procedure_encoder.encode() } @@ -126,6 +143,8 @@ struct ProcedureEncoder<'p, 'v: 'p, 'tcx: 'v> { encoder: &'p mut Encoder<'v, 'tcx>, def_id: DefId, check_mode: CheckMode, + encoding_kind: ProcedureEncodingKind, + is_unsafe_function: bool, procedure: &'p Procedure<'tcx>, mir: &'p mir::Body<'tcx>, init_data: InitializationData<'p, 'tcx>, @@ -158,30 +177,55 @@ struct ProcedureEncoder<'p, 'v: 'p, 'tcx: 'v> { reborrow_lifetimes_to_remove_for_block: BTreeMap>, current_basic_block: Option, termination_variable: Option, + /// A map from opened reference place to the corresponding permission + /// variable. + opened_reference_place_permissions: + BTreeMap>, + /// A map from opened reference witnesses to the corresponding places and lifetimes. + opened_reference_witnesses: + BTreeMap, + /// The lifetimes extracted by the user by using `take_lifetime!` macro. + user_named_lifetimes: BTreeMap, + /// Places that are manually managed by the user and for which we should not + /// automatically generate open/close/fold/unfold statements. + manually_managed_places: BTreeSet, + /// Information about stashed ranges with a given name: `(pointer, + /// start_index, end_index)`. + stashed_ranges: BTreeMap< + String, + ( + vir_high::Expression, + vir_high::Expression, + vir_high::Expression, + ), + >, } impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { fn encode(&mut self) -> SpannedEncodingResult { self.pure_sanity_checks()?; let name = format!( - "{}${}", + "{}${}${:?}", self.encoder.encode_item_name(self.def_id), - self.check_mode + self.check_mode, + self.encoding_kind, ); let (allocate_parameters, deallocate_parameters) = self.encode_parameters()?; let (allocate_returns, deallocate_returns) = self.encode_returns()?; self.lifetime_token_permission = Some(self.fresh_ghost_variable("lifetime_token_perm_amount", vir_high::Type::MPerm)); - let (assume_preconditions, assert_postconditions) = match self.check_mode { - CheckMode::CoreProof => { - // Unsafe functions will come with CheckMode::Both because they - // are allowed to have preconditions. - (Vec::new(), Vec::new()) - } - CheckMode::Both | CheckMode::Specifications => { - self.encode_functional_specifications()? - } - }; + let (assume_preconditions, assert_postconditions) = + self.encode_functional_specifications()?; + // match self.check_mode { + // CheckMode::CoreProof => { + // // Unsafe functions will come with CheckMode::Both because they + // // are allowed to have preconditions. + // (Vec::new(), Vec::new()) + // } + // CheckMode::Both | CheckMode::Specifications => { + // self.encode_functional_specifications()? + // } + // }; let (assume_lifetime_preconditions, assert_lifetime_postconditions) = self.encode_lifetime_specifications()?; let termination_initialization = self.encode_termination_initialization()?; @@ -201,9 +245,22 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { post_statements.extend(deallocate_parameters); post_statements.extend(deallocate_returns); post_statements.extend(assert_lifetime_postconditions); - let mut procedure_builder = - ProcedureBuilder::new(name, self.check_mode, pre_statements, post_statements); - self.encode_body(&mut procedure_builder)?; + let procedure_position = + self.encoder + .register_error(self.mir.span, ErrorCtxt::Unexpected, self.def_id); + let mut procedure_builder = ProcedureBuilder::new( + name, + self.check_mode, + procedure_position, + pre_statements, + post_statements, + ); + match self.encoding_kind { + ProcedureEncodingKind::Regular => self.encode_body(&mut procedure_builder)?, + ProcedureEncodingKind::PostconditionFrameCheck => { + self.encode_postcondition_frame_check(&mut procedure_builder)?; + } + } self.encode_implicit_allocations(&mut procedure_builder)?; Ok(procedure_builder.build()) } @@ -256,7 +313,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { )]; for mir_arg in self.mir.args_iter() { let parameter = self.encode_local(mir_arg)?; - let alloc_statement = vir_high::Statement::inhale_no_pos( + let alloc_statement = vir_high::Statement::inhale_predicate_no_pos( vir_high::Predicate::owned_non_aliased_no_pos(parameter.clone().into()), ); allocation.push(self.encoder.set_surrounding_error_context_for_statement( @@ -266,7 +323,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { )?); let mir_type = self.encoder.get_local_type(self.mir, mir_arg)?; let size = self.encoder.encode_type_size_expression(mir_type)?; - let dealloc_statement = vir_high::Statement::exhale_no_pos( + let dealloc_statement = vir_high::Statement::exhale_predicate_no_pos( vir_high::Predicate::memory_block_stack_no_pos(parameter.clone().into(), size), ); deallocation.push(self.encoder.set_surrounding_error_context_for_statement( @@ -285,17 +342,16 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { let mir_type = self.encoder.get_local_type(self.mir, mir::RETURN_PLACE)?; let size = self.encoder.encode_type_size_expression(mir_type)?; let alloc_statement = self.encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::inhale_no_pos(vir_high::Predicate::memory_block_stack_no_pos( - return_local.clone().into(), - size, - )), + vir_high::Statement::inhale_predicate_no_pos( + vir_high::Predicate::memory_block_stack_no_pos(return_local.clone().into(), size), + ), return_local.position, ErrorCtxt::UnexpectedStorageLive, )?; let dealloc_statement = self.encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::exhale_no_pos(vir_high::Predicate::owned_non_aliased_no_pos( - return_local.clone().into(), - )), + vir_high::Statement::exhale_predicate_no_pos( + vir_high::Predicate::owned_non_aliased_no_pos(return_local.clone().into()), + ), return_local.position, ErrorCtxt::UnexpectedStorageDead, )?; @@ -384,24 +440,35 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { "Assume functional preconditions.".to_string(), )]; let mut arguments: Vec = Vec::new(); + let mut framing_variables = Vec::new(); for local in self.mir.args_iter() { - arguments.push(self.encode_local(local)?.into()); + let parameter = self.encode_local(local)?; + framing_variables.push(parameter.variable.clone()); + arguments.push(parameter.into()); } for expression in self.encode_precondition_expressions(&procedure_contract, substs, &arguments)? { - let assume_statement = self.encoder.set_statement_error_ctxt( - vir_high::Statement::assume_no_pos(expression), - mir_span, - ErrorCtxt::UnexpectedAssumeMethodPrecondition, - self.def_id, - )?; - preconditions.push(assume_statement); + if let Some(expression) = self.convert_expression_to_check_mode( + expression, + !self.is_unsafe_function, + &framing_variables, + )? { + let inhale_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::inhale_expression_no_pos(expression), + mir_span, + ErrorCtxt::UnexpectedAssumeMethodPrecondition, + self.def_id, + )?; + preconditions.push(inhale_statement); + } } let mut postconditions = vec![vir_high::Statement::comment( "Assert functional postconditions.".to_string(), )]; - let result: vir_high::Expression = self.encode_local(mir::RETURN_PLACE)?.into(); + let result_variable = self.encode_local(mir::RETURN_PLACE)?; + framing_variables.push(result_variable.variable.clone()); + let result: vir_high::Expression = result_variable.into(); for expression in self.encode_postcondition_expressions( &procedure_contract, substs, @@ -409,13 +476,19 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { &result, PRECONDITION_LABEL, )? { - let assert_statement = self.encoder.set_statement_error_ctxt( - vir_high::Statement::assert_no_pos(expression), - mir_span, - ErrorCtxt::AssertMethodPostcondition, - self.def_id, - )?; - postconditions.push(assert_statement); + if let Some(expression) = self.convert_expression_to_check_mode( + expression, + !self.is_unsafe_function, + &framing_variables, + )? { + let exhale_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::exhale_expression_no_pos(expression), + mir_span, + ErrorCtxt::AssertMethodPostcondition, + self.def_id, + )?; + postconditions.push(exhale_statement); + } } Ok((preconditions, postconditions)) } @@ -438,14 +511,14 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { ); procedure_builder.add_alloc_statement( self.encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::inhale_no_pos(predicate.clone()), + vir_high::Statement::inhale_predicate_no_pos(predicate.clone()), encoded_local.position, ErrorCtxt::UnexpectedStorageLive, )?, ); procedure_builder.add_dealloc_statement( self.encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::exhale_no_pos(predicate.clone()), + vir_high::Statement::exhale_predicate_no_pos(predicate.clone()), encoded_local.position, ErrorCtxt::UnexpectedStorageLive, )?, @@ -490,6 +563,142 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { Ok(()) } + fn encode_postcondition_frame_check( + &mut self, + procedure_builder: &mut ProcedureBuilder, + ) -> SpannedEncodingResult<()> { + // FIXME: code duplication with encode_function_call. + let entry_label = vir_high::BasicBlockId::new("label_entry".to_string()); + let mut block_builder = procedure_builder.create_basic_block_builder(entry_label.clone()); + block_builder.set_successor_exit(SuccessorExitKind::Return); + let location = mir::Location { + block: 0usize.into(), + statement_index: 0, + }; + let span = self.mir.span; + let called_def_id = self.def_id; + let call_substs = self.encoder.env().query.identity_substs(called_def_id); + let args: Vec<_> = self + .mir + .args_iter() + .map(|arg| mir::Operand::Move(arg.into())) + .collect(); + let target_place_local = mir::RETURN_PLACE; + let destination: mir::Place = target_place_local.into(); + // let target = Some(1usize.into()); + // let cleanup = Some(1usize.into()); + + let is_unsafe = self.encoder.env().query.is_unsafe_function(called_def_id); + + // self.encode_function_call(&mut block_builder, location, span, called_def_id, call_substs, &args, destination, &target, &cleanup)?; + + let old_label = self.fresh_old_label(); + block_builder.add_statement(self.encoder.set_statement_error_ctxt( + vir_high::Statement::old_label_no_pos(old_label.clone()), + span, + ErrorCtxt::ProcedureCall, + self.def_id, + )?); + + let mut arguments = Vec::new(); + for arg in &args { + arguments.push( + self.encoder + .encode_operand_high(self.mir, arg, span) + .with_span(span)?, + ); + let encoded_arg = self.encode_statement_operand(location, arg)?; + let statement = vir_high::Statement::consume_no_pos(encoded_arg); + block_builder.add_statement(self.encoder.set_statement_error_ctxt( + statement, + span, + ErrorCtxt::ProcedureCall, + self.def_id, + )?); + } + + let procedure_contract = self + .encoder + .get_mir_procedure_contract_for_call(self.def_id, called_def_id, call_substs) + .with_span(span)?; + + let precondition_expressions = + self.encode_precondition_expressions(&procedure_contract, call_substs, &arguments)?; + for expression in precondition_expressions { + if let Some(expression) = + self.convert_expression_to_check_mode_call_site(expression, is_unsafe, &arguments)? + { + let exhale_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::exhale_expression_no_pos(expression), + span, + ErrorCtxt::ExhaleMethodPrecondition, + self.def_id, + )?; + block_builder.add_statement(exhale_statement); + } + } + + let position = self.register_error(location, ErrorCtxt::ProcedureCall); + let encoded_target_place = self + .encode_place(destination, None)? + .set_default_position(position); + let postcondition_expressions = self.encode_postcondition_expressions( + &procedure_contract, + call_substs, + arguments.clone(), + &encoded_target_place, + &old_label, + )?; + let size = self.encoder.encode_type_size_expression( + self.encoder.get_local_type(self.mir, target_place_local)?, + )?; + let target_memory_block = + vir_high::Predicate::memory_block_stack_no_pos(encoded_target_place.clone(), size); + block_builder.add_statement(self.encoder.set_statement_error_ctxt( + vir_high::Statement::exhale_predicate_no_pos(target_memory_block), + span, + ErrorCtxt::ProcedureCall, + self.def_id, + )?); + let statement = vir_high::Statement::inhale_predicate_no_pos( + vir_high::Predicate::owned_non_aliased_no_pos(encoded_target_place.clone()), + ); + block_builder.add_statement(self.encoder.set_statement_error_ctxt( + statement, + span, + ErrorCtxt::ProcedureCall, + self.def_id, + )?); + let result_place = vec![encoded_target_place.clone()]; + for expression in postcondition_expressions { + if let Some(expression) = self.convert_expression_to_check_mode_call_site( + expression, + is_unsafe, + &result_place, + )? { + let inhale_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::inhale_expression_no_pos(expression), + span, + ErrorCtxt::MethodPostconditionFraming, + self.def_id, + )?; + block_builder.add_statement(inhale_statement); + } + } + + let assume_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::assume_no_pos(false.into()), + span, + ErrorCtxt::UnexpectedAssumeEndMethodPostconditionFraming, + self.def_id, + )?; + block_builder.add_statement(assume_statement); + + block_builder.build(); + procedure_builder.set_entry(entry_label); + Ok(()) + } + fn encode_basic_block( &mut self, procedure_builder: &mut ProcedureBuilder, @@ -598,7 +807,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { block_builder.add_statement(self.set_statement_error( location, ErrorCtxt::UnexpectedStorageLive, - vir_high::Statement::inhale_no_pos(memory_block), + vir_high::Statement::inhale_predicate_no_pos(memory_block), )?); let memory_block_drop = self .encoder @@ -606,7 +815,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { block_builder.add_statement(self.set_statement_error( location, ErrorCtxt::UnexpectedStorageLive, - vir_high::Statement::inhale_no_pos(memory_block_drop), + vir_high::Statement::inhale_predicate_no_pos(memory_block_drop), )?); } mir::StatementKind::StorageDead(local) => { @@ -617,7 +826,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { block_builder.add_statement(self.set_statement_error( location, ErrorCtxt::UnexpectedStorageDead, - vir_high::Statement::exhale_no_pos(memory_block), + vir_high::Statement::exhale_predicate_no_pos(memory_block), )?); let memory_block_drop = self .encoder @@ -625,7 +834,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { block_builder.add_statement(self.set_statement_error( location, ErrorCtxt::UnexpectedStorageDead, - vir_high::Statement::exhale_no_pos(memory_block_drop), + vir_high::Statement::exhale_predicate_no_pos(memory_block_drop), )?); } mir::StatementKind::Assign(box (target, source)) => { @@ -672,7 +881,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { mir::Rvalue::Ref(region, borrow_kind, place) => { let is_reborrow = place .iter_projections() - .filter(|(_ref, projection)| projection == &mir::ProjectionElem::Deref) + .filter(|(place, projection)| { + projection == &mir::ProjectionElem::Deref + && place.ty(self.mir, self.encoder.env().tcx()).ty.is_ref() + }) .last(); let uniquness = match borrow_kind { mir::BorrowKind::Mut { .. } => vir_high::ty::Uniqueness::Unique, @@ -738,7 +950,20 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { vir_high::Statement::assign_no_pos(encoded_target, encoded_rvalue), )?); } - // mir::Rvalue::Cast(CastKind, Operand<'tcx>, Ty<'tcx>), + mir::Rvalue::Cast(_kind, operand, ty) => { + let encoded_operand = self.encode_statement_operand(location, operand)?; + let ty = self.encoder.encode_type_high(*ty)?; + let encoded_rvalue = vir_high::Rvalue::cast(encoded_operand, ty); + block_builder.add_statement(self.set_statement_error( + location, + ErrorCtxt::Assign, + vir_high::Statement::assign_no_pos(encoded_target, encoded_rvalue), + )?); + // self.encode_assign_cast(block_builder, location, encoded_target, *kind, operand, *ty)?; + // TODO: For raw pointers do nothing because we care only about + // the type of the target. + // unimplemented!("kind={kind:?} operand={operand:?} ty={ty:?}"); + } mir::Rvalue::BinaryOp(op, box (left, right)) => { let encoded_left = self.encode_statement_operand(location, left)?; let encoded_right = self.encode_statement_operand(location, right)?; @@ -779,11 +1004,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { mir::Rvalue::Discriminant(place) => { let encoded_place = self.encode_place(*place, None)?; - let deref_base = encoded_place.get_dereference_base().cloned(); - let source_permission = self.encode_open_reference( + // let deref_base = encoded_place.get_dereference_base().cloned(); + let source_permission = self.encode_automatic_open_reference( block_builder, location, - &deref_base, + // &deref_base, encoded_place.clone(), )?; @@ -797,10 +1022,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { vir_high::Statement::assign_no_pos(encoded_target, encoded_rvalue), )?); - self.encode_close_reference( + self.encode_automatic_close_reference( block_builder, location, - &deref_base, + // &deref_base, encoded_place, source_permission, )?; @@ -903,95 +1128,157 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { Ok(()) } + fn is_manually_managed(&self, place: &vir_high::Expression) -> bool { + for manual_place in &self.manually_managed_places { + if place.has_prefix(manual_place) { + return true; + } + } + false + } + fn encode_close_reference( &mut self, - block_builder: &mut BasicBlockBuilder, location: mir::Location, deref_base: &Option, place: vir_high::Expression, permission: Option, - ) -> SpannedEncodingResult<()> { + ) -> SpannedEncodingResult> { + let mut statement = None; if let Some(base) = deref_base { - if let vir_high::ty::Type::Reference(vir_high::ty::Reference { - lifetime, - uniqueness, - .. - }) = base.get_type() - { - if *uniqueness == vir_high::ty::Uniqueness::Unique { - block_builder.add_statement(self.set_statement_error( - location, - ErrorCtxt::CloseMutRef, - vir_high::Statement::close_mut_ref_no_pos( - lifetime.clone(), - self.lifetime_token_fractional_permission(self.lifetime_count), - place, - ), - )?); - } else { - block_builder.add_statement(self.set_statement_error( - location, - ErrorCtxt::CloseFracRef, - vir_high::Statement::close_frac_ref_no_pos( - lifetime.clone(), - self.lifetime_token_fractional_permission(self.lifetime_count), - place, - permission.unwrap(), - ), - )?); + match base.get_type() { + vir_high::ty::Type::Reference(vir_high::ty::Reference { + lifetime, + uniqueness, + .. + }) => { + if *uniqueness == vir_high::ty::Uniqueness::Unique { + statement = Some(self.set_statement_error( + location, + ErrorCtxt::CloseMutRef, + vir_high::Statement::close_mut_ref_no_pos( + lifetime.clone(), + self.lifetime_token_fractional_permission(self.lifetime_count), + place, + ), + )?); + } else { + statement = Some(self.set_statement_error( + location, + ErrorCtxt::CloseFracRef, + vir_high::Statement::close_frac_ref_no_pos( + lifetime.clone(), + self.lifetime_token_fractional_permission(self.lifetime_count), + place, + permission.unwrap(), + ), + )?); + } } - } else { - unreachable!(); - }; + vir_high::ty::Type::Pointer(_) => {} + _ => unreachable!(), + } + } + Ok(statement) + } + + fn encode_automatic_close_reference( + &mut self, + block_builder: &mut BasicBlockBuilder, + location: mir::Location, + place: vir_high::Expression, + permission: Option, + ) -> SpannedEncodingResult<()> { + if self.is_manually_managed(&place) { + return Ok(()); + } + let deref_base = place.get_dereference_base().cloned(); + let statement = self.encode_close_reference(location, &deref_base, place, permission)?; + if let Some(statement) = statement { + block_builder.add_statement(statement); } Ok(()) } fn encode_open_reference( &mut self, - block_builder: &mut BasicBlockBuilder, location: mir::Location, deref_base: &Option, place: vir_high::Expression, - ) -> SpannedEncodingResult> { + ) -> SpannedEncodingResult<(Option, Option)> { let mut variable = None; + let mut statement = None; if let Some(base) = deref_base { - if let vir_high::ty::Type::Reference(vir_high::ty::Reference { - lifetime, - uniqueness, - .. - }) = base.get_type() - { - if *uniqueness == vir_high::ty::Uniqueness::Unique { - block_builder.add_statement(self.set_statement_error( - location, - ErrorCtxt::OpenMutRef, - vir_high::Statement::open_mut_ref_no_pos( - lifetime.clone(), - self.lifetime_token_fractional_permission(self.lifetime_count), - place, - ), - )?); - } else { - let permission = - self.fresh_ghost_variable("tmp_frac_ref_perm", vir_high::Type::MPerm); - variable = Some(permission.clone()); - block_builder.add_statement(self.set_statement_error( - location, - ErrorCtxt::OpenFracRef, - vir_high::Statement::open_frac_ref_no_pos( - lifetime.clone(), - permission, - self.lifetime_token_fractional_permission(self.lifetime_count), - place, - ), - )?); + match base.get_type() { + vir_high::ty::Type::Reference(vir_high::ty::Reference { + lifetime, + uniqueness, + .. + }) => { + if *uniqueness == vir_high::ty::Uniqueness::Unique { + statement = Some(self.set_statement_error( + location, + ErrorCtxt::OpenMutRef, + vir_high::Statement::open_mut_ref_no_pos( + lifetime.clone(), + self.lifetime_token_fractional_permission(self.lifetime_count), + place, + ), + )?); + } else { + let permission = + self.fresh_ghost_variable("tmp_frac_ref_perm", vir_high::Type::MPerm); + variable = Some(permission.clone()); + statement = Some(self.set_statement_error( + location, + ErrorCtxt::OpenFracRef, + vir_high::Statement::open_frac_ref_no_pos( + lifetime.clone(), + permission, + self.lifetime_token_fractional_permission(self.lifetime_count), + place, + ), + )?); + } } - } else { - unreachable!("place: {} deref_base: {:?}", place, deref_base); + vir_high::ty::Type::Pointer(_) => { + // Note: if the dereferenced place is behind a raw pointer + // and reference, we require the user to manually open the + // reference. + } + _ => unreachable!("place: {} deref_base: {:?}", place, base), } }; - Ok(variable) + Ok((variable, statement)) + } + + fn encode_automatic_open_reference( + &mut self, + block_builder: &mut BasicBlockBuilder, + location: mir::Location, + // deref_base: &Option, + place: vir_high::Expression, + ) -> SpannedEncodingResult> { + if self.is_manually_managed(&place) { + return Ok(None); + } + let deref_place = place.get_dereference_base().cloned(); + let (variable, statement) = + self.encode_open_reference(location, &deref_place, place.clone())?; + if let Some(statement) = statement { + block_builder.add_statement(statement); + } + if variable.is_some() { + Ok(variable) + } else { + // Check whether the place was manually opened. + for (opened_place, variable) in &self.opened_reference_place_permissions { + if place.has_prefix(opened_place) { + return Ok(variable.clone()); + } + } + Ok(None) + } } fn encode_assign_operand( @@ -1003,11 +1290,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { ) -> SpannedEncodingResult<()> { let span = self.encoder.get_span_of_location(self.mir, location); - let deref_base = encoded_target.get_dereference_base().cloned(); - let target_permission = self.encode_open_reference( + // let deref_base = encoded_target.get_dereference_base().cloned(); + let target_permission = self.encode_automatic_open_reference( block_builder, location, - &deref_base, + // &deref_base, encoded_target.clone(), )?; match operand { @@ -1033,11 +1320,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { "{encoded_source} is not place (encoded from: {source:?}" ); - let deref_base = encoded_source.get_dereference_base().cloned(); - let source_permission = self.encode_open_reference( + // let deref_base = encoded_source.get_dereference_base().cloned(); + let source_permission = self.encode_automatic_open_reference( block_builder, location, - &deref_base, + // &deref_base, encoded_source.clone(), )?; @@ -1051,10 +1338,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { ), )?); - self.encode_close_reference( + self.encode_automatic_close_reference( block_builder, location, - &deref_base, + // &deref_base, encoded_source, source_permission, )?; @@ -1075,10 +1362,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { } } - self.encode_close_reference( + self.encode_automatic_close_reference( block_builder, location, - &deref_base, + // &deref_base, encoded_target, target_permission, )?; @@ -1086,6 +1373,19 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { Ok(()) } + // fn encode_assign_cast( + // &mut self, + // block_builder: &mut BasicBlockBuilder, + // location: mir::Location, + // encoded_target: vir_crate::high::Expression, + // kind: mir::CastKind, + // operand: &mir::Operand<'tcx>, + // ty: ty::Ty<'tcx>, + // ) -> SpannedEncodingResult<()> { + // let span = self.encoder.get_span_of_location(self.mir, location); + // match ty {} + // } + fn encode_statement_operand( &mut self, location: mir::Location, @@ -1487,6 +1787,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { let query = self.encoder.env().query; let (called_def_id, call_substs) = query.resolve_method_call(self.def_id, called_def_id, call_substs); + let is_unsafe = query.is_unsafe_function(called_def_id); // find static lifetime to exhale let mut lifetimes_to_exhale_inhale: Vec = Vec::new(); @@ -1604,18 +1905,32 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { )?; } - for expression in - self.encode_precondition_expressions(&procedure_contract, call_substs, &arguments)? - { - let assert_statement = self.encoder.set_statement_error_ctxt( - vir_high::Statement::assert_no_pos(expression), + let precondition_expressions = + self.encode_precondition_expressions(&procedure_contract, call_substs, &arguments)?; + let has_no_precondition = precondition_expressions.is_empty(); + for expression in precondition_expressions { + if let Some(expression) = + self.convert_expression_to_check_mode_call_site(expression, is_unsafe, &arguments)? + { + let exhale_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::exhale_expression_no_pos(expression), + span, + ErrorCtxt::ExhaleMethodPrecondition, + self.def_id, + )?; + block_builder.add_statement(exhale_statement); + } + } + + let is_pure = self.encoder.is_pure(called_def_id, Some(call_substs)); + if !is_pure && self.check_mode.is_purification_group() { + let heap_havoc_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::heap_havoc_no_pos(), span, ErrorCtxt::ExhaleMethodPrecondition, self.def_id, )?; - if self.check_mode != CheckMode::CoreProof { - block_builder.add_statement(assert_statement); - } + block_builder.add_statement(heap_havoc_statement); } if self.encoder.env().query.is_closure(called_def_id) { @@ -1646,7 +1961,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { size, ); block_builder.add_statement(self.encoder.set_statement_error_ctxt( - vir_high::Statement::exhale_no_pos(target_memory_block.clone()), + vir_high::Statement::exhale_predicate_no_pos(target_memory_block.clone()), span, ErrorCtxt::ProcedureCall, self.def_id, @@ -1657,7 +1972,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { let mut post_call_block_builder = block_builder.create_basic_block_builder(fresh_destination_label.clone()); post_call_block_builder.set_successor_jump(vir_high::Successor::Goto(target_label)); - let statement = vir_high::Statement::inhale_no_pos( + let statement = vir_high::Statement::inhale_predicate_no_pos( vir_high::Predicate::owned_non_aliased_no_pos(encoded_target_place.clone()), ); post_call_block_builder.add_statement(self.encoder.set_statement_error_ctxt( @@ -1687,18 +2002,26 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { self.encode_lft_for_block(*target_block, location, &mut post_call_block_builder)?; + let result_place = vec![encoded_target_place.clone()]; for expression in postcondition_expressions { - let assume_statement = self.encoder.set_statement_error_ctxt( - vir_high::Statement::assume_no_pos(expression), - span, - ErrorCtxt::UnexpectedAssumeMethodPostcondition, - self.def_id, - )?; - if self.check_mode != CheckMode::CoreProof { - post_call_block_builder.add_statement(assume_statement); + if let Some(expression) = self.convert_expression_to_check_mode_call_site( + expression, + is_unsafe || + // If we have no precondition, then we can soundly + // allways include the function postcondition. + has_no_precondition, + &result_place, + )? { + let inhale_statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::inhale_expression_no_pos(expression), + span, + ErrorCtxt::UnexpectedAssumeMethodPostcondition, + self.def_id, + )?; + post_call_block_builder.add_statement(inhale_statement); } } - if self.encoder.is_pure(called_def_id, Some(call_substs)) + if is_pure && !self.encoder.env().callee_reaches_caller( self.def_id, called_def_id, @@ -1733,9 +2056,42 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { ErrorCtxt::UnexpectedAssumeMethodPostcondition, self.def_id, )?; - if self.check_mode != CheckMode::CoreProof { + if self.check_mode.check_specifications() || + // If we have no precondition, then we can soundly + // allways include the function postcondition. + has_no_precondition + { post_call_block_builder.add_statement(assume_statement); } + } else { + // // FIXME: We do this because extern specs do not support primitive + // // types. + // let func_name = self.encoder.env().name.get_unique_item_name(called_def_id); + // if func_name.starts_with("std::ptr::mut_ptr::::is_null") + // || func_name.starts_with("core::std::ptr::mut_ptr::::is_null") { + // let type_arguments = self + // .encoder + // .encode_generic_arguments_high(called_def_id, call_substs) + // .with_span(span)?; + // let expression = vir_high::Expression::equals( + // encoded_target_place, + // vir_high::Expression::builtin_func_app_no_pos( + // vir_high::BuiltinFunc::IsNull, + // type_arguments, + // arguments, + // vir_high::Type::Bool, + // ), + // ); + // let assume_statement = self.encoder.set_statement_error_ctxt( + // vir_high::Statement::assume_no_pos(expression), + // span, + // ErrorCtxt::UnexpectedAssumeMethodPostcondition, + // self.def_id, + // )?; + // if self.check_mode != CheckMode::CoreProof { + // post_call_block_builder.add_statement(assume_statement); + // } + // } } post_call_block_builder.build(); @@ -1747,7 +2103,8 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { cleanup_block_builder .set_successor_jump(vir_high::Successor::Goto(encoded_cleanup_block)); - let statement = vir_high::Statement::inhale_no_pos(target_memory_block); + let statement = + vir_high::Statement::inhale_predicate_no_pos(target_memory_block); cleanup_block_builder.add_statement(self.encoder.set_statement_error_ctxt( statement, span, @@ -1994,7 +2351,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { self.def_id, )?; - if self.check_mode != CheckMode::CoreProof { + if self.check_mode.check_specifications() { encoded_statements.push(assert_stmt); } @@ -2041,7 +2398,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { self.encoder .set_statement_error_ctxt(stmt, span, error_ctxt, self.def_id)?; - if self.check_mode != CheckMode::CoreProof { + if self.check_mode.check_specifications() { encoded_statements.push(stmt); } @@ -2074,6 +2431,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { Ok(false) } + // TODO: Move this function to a separate file and extract nested functions. fn try_encode_specification_function_call( &mut self, bb: mir::BasicBlock, @@ -2081,6 +2439,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { encoded_statements: &mut Vec, ) -> SpannedEncodingResult { let span = self.encoder.get_mir_terminator_span(block.terminator()); + let location = mir::Location { + block: bb, + statement_index: block.statements.len(), + }; match &block.terminator().kind { mir::TerminatorKind::Call { func: mir::Operand::Constant(box mir::Constant { literal, .. }), @@ -2094,10 +2456,35 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { if let ty::TyKind::FnDef(def_id, _substs) = literal.ty().kind() { let full_called_function_name = self.encoder.env().name.get_absolute_item_name(*def_id); - match full_called_function_name.as_ref() { - "prusti_contracts::prusti_set_union_active_field" => { - assert_eq!(args.len(), 1); - let argument_place = if let mir::Operand::Move(place) = args[0] { + enum ArgKind { + Place(vir_high::Expression), + String(String), + } + fn extract_args<'p, 'v: 'p, 'tcx: 'v>( + mir: &mir::Body<'tcx>, + args: &[mir::Operand<'tcx>], + block: &mir::BasicBlockData<'tcx>, + encoder: &mut ProcedureEncoder<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult> { + // assert_eq!(args.len(), 1); + let mut encoded_args = Vec::new(); + for arg in args { + eprintln!("arg: {:?}", arg); + match arg { + mir::Operand::Move(_) => {} + mir::Operand::Constant(constant) => { + // FIXME: There should be a proper way of doing this. + let value = format!("{:?}", constant); + let value = + value.trim_start_matches("const \"").trim_end_matches("\""); + encoded_args.push(ArgKind::String(value.to_string())); + continue; // FIXME: Do proper control flow. + } + _ => { + unreachable!() + } + } + let argument_place = if let mir::Operand::Move(place) = arg { place } else { unreachable!() @@ -2105,15 +2492,31 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { // Find the place whose address was stored in the argument by // iterating backwards through statements. let mut statement_index = block.statements.len() - 1; - let union_variant_place = loop { + let place = loop { if let Some(statement) = block.statements.get(statement_index) { - if let mir::StatementKind::Assign(box ( - target_place, - mir::Rvalue::AddressOf(_, union_variant_place), - )) = &statement.kind + eprintln!("statement: {:?}", statement); + if let mir::StatementKind::Assign(box (target_place, rvalue)) = + &statement.kind { - if *target_place == argument_place { - break union_variant_place; + if target_place == argument_place { + match rvalue { + mir::Rvalue::AddressOf(_, place) => { + break encoder.encode_place(*place, None)?; + } + mir::Rvalue::Use(operand) => { + break encoder + .encoder + .encode_operand_high( + mir, + operand, + statement.source_info.span, + ) + .with_span(statement.source_info.span)?; + } + _ => { + unimplemented!("rvalue: {:?}", rvalue); + } + } } } statement_index -= 1; @@ -2121,8 +2524,65 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { unreachable!(); } }; - let encoded_variant_place = - self.encode_place(*union_variant_place, None)?; + encoded_args.push(ArgKind::Place(place)); + } + Ok(encoded_args) + } + fn extract_places<'p, 'v: 'p, 'tcx: 'v>( + mir: &mir::Body<'tcx>, + args: &[mir::Operand<'tcx>], + block: &mir::BasicBlockData<'tcx>, + encoder: &mut ProcedureEncoder<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult> { + let places = extract_args(mir, args, block, encoder)? + .into_iter() + .map(|arg| match arg { + ArgKind::Place(place) => place, + ArgKind::String(_) => unreachable!(), + }) + .collect(); + Ok(places) + } + fn extract_place<'p, 'v: 'p, 'tcx: 'v>( + mir: &mir::Body<'tcx>, + args: &[mir::Operand<'tcx>], + block: &mir::BasicBlockData<'tcx>, + encoder: &mut ProcedureEncoder<'p, 'v, 'tcx>, + ) -> SpannedEncodingResult { + assert_eq!(args.len(), 1); + Ok(extract_places(mir, args, block, encoder)?.pop().unwrap()) + } + match full_called_function_name.as_ref() { + "prusti_contracts::prusti_set_union_active_field" => { + assert_eq!(args.len(), 1); + // assert_eq!(args.len(), 1); + // let argument_place = if let mir::Operand::Move(place) = args[0] { + // place + // } else { + // unreachable!() + // }; + // // Find the place whose address was stored in the argument by + // // iterating backwards through statements. + // let mut statement_index = block.statements.len() - 1; + // let union_variant_place = loop { + // if let Some(statement) = block.statements.get(statement_index) { + // if let mir::StatementKind::Assign(box ( + // target_place, + // mir::Rvalue::AddressOf(_, union_variant_place), + // )) = &statement.kind + // { + // if *target_place == argument_place { + // break union_variant_place; + // } + // } + // statement_index -= 1; + // } else { + // unreachable!(); + // } + // }; + // let encoded_variant_place = + // self.encode_place(*union_variant_place, None)?; + let encoded_variant_place = extract_place(self.mir, args, block, self)?; let statement = self.encoder.set_statement_error_ctxt( vir_high::Statement::set_union_variant_no_pos( encoded_variant_place, @@ -2135,6 +2595,563 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { encoded_statements.push(statement); Ok(true) } + "prusti_contracts::prusti_manually_manage" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + assert!(self.manually_managed_places.insert(encoded_place)); + Ok(true) + } + "prusti_contracts::prusti_pack_place" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::pack_no_pos( + encoded_place, + vir_high::PredicateKind::Owned, + ), + span, + ErrorCtxt::Pack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_unpack_place" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::unpack_no_pos( + encoded_place, + vir_high::PredicateKind::Owned, + ), + span, + ErrorCtxt::Unpack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_pack_ref_place" => { + assert_eq!(args.len(), 2); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + assert!(encoded_args.is_empty()); + let lifetime = self + .user_named_lifetimes + .get(&lifetime_name) + .unwrap() + .clone(); + // let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::pack_no_pos( + place, + vir_high::PredicateKind::frac_ref(lifetime), + ), + span, + ErrorCtxt::Pack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_unpack_ref_place" => { + assert_eq!(args.len(), 2); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + assert!(encoded_args.is_empty()); + let lifetime = self + .user_named_lifetimes + .get(&lifetime_name) + .unwrap() + .clone(); + // let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::unpack_no_pos( + place, + vir_high::PredicateKind::frac_ref(lifetime), + ), + span, + ErrorCtxt::Unpack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_pack_mut_ref_place" => { + assert_eq!(args.len(), 2); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + assert!(encoded_args.is_empty()); + let lifetime = self + .user_named_lifetimes + .get(&lifetime_name) + .unwrap() + .clone(); + // let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::pack_no_pos( + place, + vir_high::PredicateKind::unique_ref(lifetime), + ), + span, + ErrorCtxt::Pack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_unpack_mut_ref_place" => { + assert_eq!(args.len(), 2); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + assert!(encoded_args.is_empty()); + let lifetime = self + .user_named_lifetimes + .get(&lifetime_name) + .unwrap() + .clone(); + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::unpack_no_pos( + place, + vir_high::PredicateKind::unique_ref(lifetime), + ), + span, + ErrorCtxt::Unpack, + self.def_id, + )?; + // let encoded_place = extract_place(self.mir, args, block, self)?; + // let statement = self.encoder.set_statement_error_ctxt( + // vir_high::Statement::unpack_no_pos( + // encoded_place, + // vir_high::PredicateKind::UniqueRef, + // ), + // span, + // ErrorCtxt::Unpack, + // self.def_id, + // )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_take_lifetime" => { + assert_eq!(args.len(), 2); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + assert!(encoded_args.is_empty()); + let vir_high::ty::Type::Reference(ref_type) = place.get_type() else { + unimplemented!("FIXME: A proper error message."); + }; + let lifetime = ref_type.lifetime.clone(); + assert!(self + .user_named_lifetimes + .insert(lifetime_name, lifetime) + .is_none()); + Ok(true) + } + "prusti_contracts::prusti_join_place" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::join_no_pos(encoded_place), + span, + ErrorCtxt::Pack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_join_range" => { + assert_eq!(args.len(), 3); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::Place(end_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(start_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(pointer) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::join_range_no_pos( + pointer.clone(), + start_index.clone(), + end_index.clone(), + ), + span, + ErrorCtxt::JoinRange, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_split_place" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::split_no_pos(encoded_place), + span, + ErrorCtxt::Unpack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_split_range" => { + assert_eq!(args.len(), 3); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::Place(end_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(start_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(pointer) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::split_range_no_pos( + pointer.clone(), + start_index.clone(), + end_index.clone(), + ), + span, + ErrorCtxt::SplitRange, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_stash_range" => { + assert_eq!(args.len(), 4); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(stash_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(end_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(start_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(pointer) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + encoded_statements.push(vir_high::Statement::old_label( + stash_name.clone(), + self.encoder.register_error( + span, + ErrorCtxt::StashRange, + self.def_id, + ), + )); + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::stash_range_no_pos( + pointer.clone(), + start_index.clone(), + end_index.clone(), + stash_name.clone(), + ), + span, + ErrorCtxt::StashRange, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + encoded_statements.push(vir_high::Statement::old_label( + format!("{}$post", stash_name), + self.encoder.register_error( + span, + ErrorCtxt::StashRange, + self.def_id, + ), + )); + assert!(self + .stashed_ranges + .insert(stash_name, (pointer, start_index, end_index)) + .is_none()); + Ok(true) + } + "prusti_contracts::prusti_restore_stash_range" => { + assert_eq!(args.len(), 3); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(stash_name) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(new_start_index) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let ArgKind::Place(new_pointer) = encoded_args.pop().unwrap() else { + unreachable!("Wrong function parameters?"); + }; + let (old_pointer, old_start_index, old_end_index) = + self.stashed_ranges.get(&stash_name).unwrap().clone(); + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::stash_range_restore_no_pos( + old_pointer, + old_start_index, + old_end_index, + stash_name, + new_pointer, + new_start_index, + ), + span, + ErrorCtxt::RestoreStashRange, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_close_ref_place" => { + assert_eq!(args.len(), 1); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(witness) = encoded_args.pop().unwrap() else { + unreachable!() + }; + assert!(encoded_args.is_empty()); + // FIXME: These should actually remove the + // witnesses. However, since specification blocks + // are processed before all other blocks, the state + // cannot be easily transfered. A proper solution + // would be to check whether the state that uses the + // opened permission is dominated by the statement + // that opens the reference. Alternatively, we could + // have annotations that specify which permission + // amount to use for copy statements. Another + // alternative (probably the easiest) would be to + // make a static analysis that inserts the right + // permission amount into the copy statement. + let (place, lifetime) = self + .opened_reference_witnesses + .get(&witness) + .expect("FIXME: a proper error message"); + let variable = self + .opened_reference_place_permissions + .get(&place) + .expect("FIXME: A proper error message"); + // let deref_base = place.get_last_dereferenced_reference().cloned(); + // let statement = self.encode_close_reference( + // location, + // &deref_base, + // place.clone(), + // variable.clone(), + // )?; + let statement = self.set_statement_error( + location, + ErrorCtxt::CloseFracRef, + vir_high::Statement::close_frac_ref_no_pos( + lifetime.clone(), + self.lifetime_token_fractional_permission(self.lifetime_count), + place.clone(), + variable.clone().unwrap(), + ), + )?; + encoded_statements.push(statement); + // encoded_statements.push(statement.expect( + // "FIXME: A proper error message for closing not a reference", + // )); + Ok(true) + } + "prusti_contracts::prusti_open_ref_place" => { + assert_eq!(args.len(), 3); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(witness) = encoded_args.pop().unwrap() else { + unreachable!() + }; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!() + }; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!() + }; + assert!(encoded_args.is_empty()); + let lifetime = self + .user_named_lifetimes + .get(&lifetime_name) + .unwrap() + .clone(); + let permission = self + .fresh_ghost_variable("tmp_frac_ref_perm", vir_high::Type::MPerm); + let variable = Some(permission.clone()); + let statement = self.set_statement_error( + location, + ErrorCtxt::OpenFracRef, + vir_high::Statement::open_frac_ref_no_pos( + lifetime.clone(), + permission, + self.lifetime_token_fractional_permission(self.lifetime_count), + place.clone(), + ), + )?; + + // let deref_place = place.get_last_dereferenced_reference().cloned(); + // let (variable, statement) = + // self.encode_open_reference(location, &deref_place, place.clone())?; + encoded_statements.push(statement); + assert!(self + .opened_reference_place_permissions + .insert(place.clone(), variable) + .is_none()); + assert!(self + .opened_reference_witnesses + .insert(witness, (place, lifetime.clone())) + .is_none()); + Ok(true) + } + "prusti_contracts::prusti_close_mut_ref_place" => { + assert_eq!(args.len(), 1); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(witness) = encoded_args.pop().unwrap() else { + unreachable!() + }; + assert!(encoded_args.is_empty()); + // FIXME: These should actually remove the + // witnesses. However, since specification blocks + // are processed before all other blocks, the state + // cannot be easily transfered. A proper solution + // would be to check whether the state that uses the + // opened permission is dominated by the statement + // that opens the reference. Alternatively, we could + // have annotations that specify which permission + // amount to use for copy statements. Another + // alternative (probably the easiest) would be to + // make a static analysis that inserts the right + // permission amount into the copy statement. + let (place, lifetime) = self + .opened_reference_witnesses + .get(&witness) + .expect("FIXME: a proper error message"); + // let variable = self + // .opened_reference_place_permissions + // .get(&place) + // .expect("FIXME: A proper error message"); + // let deref_base = place.get_last_dereferenced_reference().cloned(); + let statement = self.set_statement_error( + location, + ErrorCtxt::CloseMutRef, + vir_high::Statement::close_mut_ref_no_pos( + lifetime.clone(), + self.lifetime_token_fractional_permission(self.lifetime_count), + place.clone(), + ), + )?; + // let statement = self.encode_close_reference( + // location, + // &deref_base, + // place.clone(), + // variable.clone(), + // )?; + // encoded_statements.push(statement.expect( + // "FIXME: A proper error message for closing not a reference", + // )); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_open_mut_ref_place" => { + assert_eq!(args.len(), 3); + let mut encoded_args = extract_args(self.mir, args, block, self)?; + let ArgKind::String(witness) = encoded_args.pop().unwrap() else { + unreachable!() + }; + let ArgKind::Place(place) = encoded_args.pop().unwrap() else { + unreachable!() + }; + let ArgKind::String(lifetime_name) = encoded_args.pop().unwrap() else { + unreachable!() + }; + assert!(encoded_args.is_empty()); + // let lifetime = self + // .user_named_lifetimes + // .get(&lifetime_name) + // .unwrap() + // .clone(); + let Some(lifetime) = self + .user_named_lifetimes + .get(&lifetime_name) + .cloned() else { + return Err(SpannedEncodingError::incorrect( + format!("Lifetime name `{}` not defined", lifetime_name), span)); + }; + let statement = self.set_statement_error( + location, + ErrorCtxt::OpenMutRef, + vir_high::Statement::open_mut_ref_no_pos( + lifetime.clone(), + self.lifetime_token_fractional_permission(self.lifetime_count), + place.clone(), + ), + )?; + encoded_statements.push(statement); + assert!(self + .opened_reference_place_permissions + .insert(place.clone(), None) + .is_none()); + assert!(self + .opened_reference_witnesses + .insert(witness, (place, lifetime.clone())) + .is_none()); + Ok(true) + } + "prusti_contracts::prusti_forget_initialization" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::forget_initialization_no_pos(encoded_place), + span, + ErrorCtxt::ForgetInitialization, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } + "prusti_contracts::prusti_restore_place" => { + assert_eq!(args.len(), 2); + let mut encoded_places = extract_places(self.mir, args, block, self)?; + let restored_place = encoded_places.pop().unwrap(); + let borrowing_place = encoded_places.pop().unwrap(); + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::restore_raw_borrowed_no_pos( + borrowing_place, + restored_place, + ), + span, + ErrorCtxt::RestoreRawBorrowed, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } _ => unreachable!(), } } else { @@ -2144,4 +3161,15 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { _ => unreachable!("block: {:?}", bb), } } + + fn is_pure(&self, def_id: DefId, substs: Option>) -> bool { + self.encoder.is_pure(def_id, substs) + // || { + // // FIXME: We do this because extern specs do not support primitive + // // types. + // let func_name = self.encoder.env().name.get_unique_item_name(def_id); + // func_name.starts_with("std::ptr::mut_ptr::::is_null") + // || func_name.starts_with("core::std::ptr::mut_ptr::::is_null") + // } + } } diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/specification_blocks.rs b/prusti-viper/src/encoder/mir/procedures/encoder/specification_blocks.rs index eb258f5547a..0a8cfa90507 100644 --- a/prusti-viper/src/encoder/mir/procedures/encoder/specification_blocks.rs +++ b/prusti-viper/src/encoder/mir/procedures/encoder/specification_blocks.rs @@ -32,7 +32,7 @@ impl SpecificationBlocks { pub fn build<'tcx>( env_query: EnvQuery<'tcx>, body: &mir::Body<'tcx>, - procedure: &Procedure<'tcx>, + procedure: Option<&Procedure<'tcx>>, collect_loop_invariants: bool, ) -> Self { // Blocks that contain closures marked with `#[spec_only]` attributes. @@ -77,11 +77,13 @@ impl SpecificationBlocks { } // Collect loop invariant blocks. - let loop_info = procedure.loop_info(); - let predecessors = body.basic_blocks.predecessors(); let mut loop_invariant_blocks = BTreeMap::<_, LoopInvariantBlocks>::new(); let mut loop_spec_blocks_flat = BTreeSet::new(); if collect_loop_invariants { + let loop_info = procedure + .expect("procedure needs to be Some when collect_loop_invariants is true") + .loop_info(); + let predecessors = body.basic_blocks.predecessors(); // We use reverse_postorder here because we need to make sure that we // preserve the order of invariants in which they were specified by the // user. @@ -176,7 +178,7 @@ impl SpecificationBlocks { self.specification_entry_blocks.iter().cloned() } - pub(super) fn is_specification_block(&self, bb: mir::BasicBlock) -> bool { + pub fn is_specification_block(&self, bb: mir::BasicBlock) -> bool { self.specification_blocks.contains(&bb) } diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/termination.rs b/prusti-viper/src/encoder/mir/procedures/encoder/termination.rs index 06376864ac8..f1ac8d4153c 100644 --- a/prusti-viper/src/encoder/mir/procedures/encoder/termination.rs +++ b/prusti-viper/src/encoder/mir/procedures/encoder/termination.rs @@ -49,7 +49,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> super::ProcedureEncoder<'p, 'v, 'tcx> { })?; let expression = self.encoder.encode_assertion_high( - expr.to_def_id(), + expr, None, arguments, None, @@ -80,7 +80,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> super::ProcedureEncoder<'p, 'v, 'tcx> { arguments.push(self.encode_local(local)?.into()); } - if self.encoder.terminates(self.def_id, None) && self.check_mode != CheckMode::CoreProof { + if self.encoder.terminates(self.def_id, None) && self.check_mode.check_specifications() { let termination_expr = self.encode_termination_expression( &procedure_contract, mir_span, diff --git a/prusti-viper/src/encoder/mir/procedures/interface.rs b/prusti-viper/src/encoder/mir/procedures/interface.rs index 68e42152cb9..b2a7932fff3 100644 --- a/prusti-viper/src/encoder/mir/procedures/interface.rs +++ b/prusti-viper/src/encoder/mir/procedures/interface.rs @@ -1,6 +1,9 @@ use crate::encoder::{ errors::SpannedEncodingResult, - mir::{procedures::passes, spans::SpanInterface}, + mir::{ + procedures::{encoder::ProcedureEncodingKind, passes}, + spans::SpanInterface, + }, }; use prusti_rustc_interface::{hir::def_id::DefId, middle::mir, span::Span}; use rustc_hash::FxHashMap; @@ -17,7 +20,7 @@ pub(crate) trait MirProcedureEncoderInterface<'tcx> { &mut self, proc_def_id: DefId, check_mode: CheckMode, - ) -> SpannedEncodingResult; + ) -> SpannedEncodingResult>; fn get_span_of_location(&self, mir: &mir::Body<'tcx>, location: mir::Location) -> Span; } @@ -26,17 +29,34 @@ impl<'v, 'tcx: 'v> MirProcedureEncoderInterface<'tcx> for super::super::super::E &mut self, proc_def_id: DefId, check_mode: CheckMode, - ) -> SpannedEncodingResult { - let procedure = super::encoder::encode_procedure(self, proc_def_id, check_mode)?; + ) -> SpannedEncodingResult> { + let procedure = super::encoder::encode_procedure( + self, + proc_def_id, + check_mode, + ProcedureEncodingKind::Regular, + )?; let procedure = passes::run_passes(self, procedure)?; + let mut procedures = Vec::new(); + if check_mode.check_core_proof() { + let postcondition_check = super::encoder::encode_procedure( + self, + proc_def_id, + check_mode, + ProcedureEncodingKind::PostconditionFrameCheck, + )?; + procedures.push(postcondition_check); + } assert!( self.mir_procedure_encoder_state .encoded_procedure_def_ids .insert(procedure.name.clone(), (proc_def_id, check_mode)) .is_none(), - "The procedure was encoed twice: {proc_def_id:?}" + "The procedure was encoded twice: {:?}", + proc_def_id ); - Ok(procedure) + procedures.push(procedure); + Ok(procedures) } fn get_span_of_location(&self, mir: &mir::Body<'tcx>, location: mir::Location) -> Span { self.get_mir_location_span(mir, location) diff --git a/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs b/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs index d50a34c4e08..22d6dcccfd9 100644 --- a/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs +++ b/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs @@ -25,7 +25,7 @@ pub(in super::super) fn propagate_assertions_back<'v, 'tcx: 'v>( can_be_soundly_skipped = match &block.statements[statement_index] { vir_high::Statement::Comment(_) | vir_high::Statement::OldLabel(_) - | vir_high::Statement::Inhale(vir_high::Inhale { + | vir_high::Statement::InhalePredicate(vir_high::InhalePredicate { predicate: vir_high::Predicate::LifetimeToken(_) | vir_high::Predicate::MemoryBlockStack(_) @@ -34,10 +34,12 @@ pub(in super::super) fn propagate_assertions_back<'v, 'tcx: 'v>( | vir_high::Predicate::MemoryBlockHeapDrop(_), position: _, }) - | vir_high::Statement::Exhale(_) + | vir_high::Statement::ExhalePredicate(_) + | vir_high::Statement::ExhaleExpression(_) | vir_high::Statement::Consume(_) | vir_high::Statement::Havoc(_) | vir_high::Statement::GhostHavoc(_) + | vir_high::Statement::HeapHavoc(_) | vir_high::Statement::Assert(_) | vir_high::Statement::MovePlace(_) | vir_high::Statement::CopyPlace(_) @@ -59,7 +61,19 @@ pub(in super::super) fn propagate_assertions_back<'v, 'tcx: 'v>( | vir_high::Statement::CloseMutRef(_) | vir_high::Statement::CloseFracRef(_) | vir_high::Statement::BorShorten(_) => true, - vir_high::Statement::Assume(_) | vir_high::Statement::Inhale(_) => false, + vir_high::Statement::Pack(_) + | vir_high::Statement::Unpack(_) + | vir_high::Statement::Join(_) + | vir_high::Statement::JoinRange(_) + | vir_high::Statement::Split(_) + | vir_high::Statement::SplitRange(_) + | vir_high::Statement::ForgetInitialization(_) + | vir_high::Statement::RestoreRawBorrowed(_) + | vir_high::Statement::Assume(_) + | vir_high::Statement::InhalePredicate(_) + | vir_high::Statement::InhaleExpression(_) + | vir_high::Statement::StashRange(_) + | vir_high::Statement::StashRangeRestore(_) => false, vir_high::Statement::LoopInvariant(_) => unreachable!(), }; } diff --git a/prusti-viper/src/encoder/mir/procedures/passes/loop_desugaring.rs b/prusti-viper/src/encoder/mir/procedures/passes/loop_desugaring.rs index 287de540068..de4309e5924 100644 --- a/prusti-viper/src/encoder/mir/procedures/passes/loop_desugaring.rs +++ b/prusti-viper/src/encoder/mir/procedures/passes/loop_desugaring.rs @@ -71,7 +71,7 @@ pub(in super::super) fn desugar_loops<'v, 'tcx: 'v>( )); for assertion in &loop_invariant.functional_specifications { let statement = encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::assert_no_pos(assertion.clone()), + vir_high::Statement::exhale_expression_no_pos(assertion.clone()), loop_invariant.position, ErrorCtxt::AssertLoopInvariantOnEntry, )?; @@ -122,7 +122,7 @@ pub(in super::super) fn desugar_loops<'v, 'tcx: 'v>( for assertion in loop_invariant.functional_specifications { let statement = encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::assume_no_pos(assertion), + vir_high::Statement::inhale_expression_no_pos(assertion), loop_invariant.position, ErrorCtxt::UnexpectedAssumeLoopInvariantOnEntry, )?; @@ -198,7 +198,7 @@ fn duplicate_blocks<'v, 'tcx: 'v>( let loop_invariant = block.statements.pop().unwrap().unwrap_loop_invariant(); for assertion in loop_invariant.functional_specifications { let statement = encoder.set_surrounding_error_context_for_statement( - vir_high::Statement::assert_no_pos(assertion), + vir_high::Statement::exhale_expression_no_pos(assertion), loop_invariant.position, ErrorCtxt::AssertLoopInvariantAfterIteration, )?; diff --git a/prusti-viper/src/encoder/mir/pure/interpreter/interpreter_high.rs b/prusti-viper/src/encoder/mir/pure/interpreter/interpreter_high.rs index 08da906ed5e..40884d1793d 100644 --- a/prusti-viper/src/encoder/mir/pure/interpreter/interpreter_high.rs +++ b/prusti-viper/src/encoder/mir/pure/interpreter/interpreter_high.rs @@ -17,6 +17,7 @@ use crate::encoder::{ casts::CastsEncoderInterface, generics::MirGenericsEncoderInterface, places::PlacesEncoderInterface, + procedures::encoder::specification_blocks::SpecificationBlocks, pure::{ interpreter::BackwardMirInterpreter, PureEncodingContext, PureFunctionEncoderInterface, SpecificationEncoderInterface, @@ -50,6 +51,8 @@ pub(in super::super) struct ExpressionBackwardInterpreter<'p, 'v: 'p, 'tcx: 'v> encoder: &'p Encoder<'v, 'tcx>, /// MIR of the pure function being encoded. mir: &'p mir::Body<'tcx>, + /// The specification blocks used in the pure function. + specification_blocks: SpecificationBlocks, /// MirEncoder of the pure function being encoded. mir_encoder: MirEncoder<'p, 'v, 'tcx>, /// How panics are handled depending on the encoding context. @@ -73,9 +76,12 @@ impl<'p, 'v: 'p, 'tcx: 'v> ExpressionBackwardInterpreter<'p, 'v, 'tcx> { caller_def_id: DefId, substs: SubstsRef<'tcx>, ) -> Self { + let specification_blocks = + SpecificationBlocks::build(encoder.env().query, mir, None, false); Self { encoder, mir, + specification_blocks, mir_encoder: MirEncoder::new(encoder, mir, def_id), pure_encoding_context, caller_def_id, @@ -322,8 +328,19 @@ impl<'p, 'v: 'p, 'tcx: 'v> ExpressionBackwardInterpreter<'p, 'v, 'tcx> { let expr = vir_high::Expression::constructor_no_pos(ty, arguments); state.substitute_value(&encoded_lhs, expr); } + mir::Rvalue::AddressOf(_, place) => { + let encoded_place = self.encoder.encode_place_high(self.mir, *place, None)?; + let ty = self + .encoder + .encode_type_of_place_high(self.mir, *place) + .with_span(span)?; + let expr = vir_high::Expression::addr_of_no_pos( + encoded_place, + vir_high::Type::pointer(ty), + ); + state.substitute_value(&encoded_lhs, expr); + } mir::Rvalue::ThreadLocalRef(..) - | mir::Rvalue::AddressOf(..) | mir::Rvalue::ShallowInitBox(..) | mir::Rvalue::NullaryOp(..) => { return Err(SpannedEncodingError::unsupported( @@ -653,6 +670,77 @@ impl<'p, 'v: 'p, 'tcx: 'v> ExpressionBackwardInterpreter<'p, 'v, 'tcx> { } match proc_name { + "prusti_contracts::prusti_own" => { + assert_eq!(encoded_args.len(), 1); + let place = encoded_args[0].clone(); + let position = place.position(); + let encoded_rhs = vir_high::Expression::acc_predicate( + vir_high::Predicate::owned_non_aliased(place, position), + position, + ); + subst_with(encoded_rhs) + } + "prusti_contracts::prusti_own_range" => { + assert_eq!(encoded_args.len(), 3); + let address = encoded_args[0].clone(); + let start = encoded_args[1].clone(); + let end = encoded_args[2].clone(); + let position = address.position(); + let encoded_rhs = vir_high::Expression::acc_predicate( + vir_high::Predicate::owned_range(address, start, end, position), + position, + ); + subst_with(encoded_rhs) + } + "prusti_contracts::prusti_raw" => { + assert_eq!(encoded_args.len(), 2); + let address = encoded_args[0].clone(); + let size = encoded_args[1].clone(); + let position = address.position(); + let encoded_rhs = vir_high::Expression::acc_predicate( + vir_high::Predicate::memory_block_heap(address, size, position), + position, + ); + subst_with(encoded_rhs) + } + "prusti_contracts::prusti_raw_range" => { + assert_eq!(encoded_args.len(), 4); + let address = encoded_args[0].clone(); + let size = encoded_args[1].clone(); + let start = encoded_args[2].clone(); + let end = encoded_args[3].clone(); + let position = address.position(); + let encoded_rhs = vir_high::Expression::acc_predicate( + vir_high::Predicate::memory_block_heap_range( + address, size, start, end, position, + ), + position, + ); + subst_with(encoded_rhs) + } + "prusti_contracts::prusti_raw_dealloc" => { + assert_eq!(encoded_args.len(), 2); + let address = encoded_args[0].clone(); + let size = encoded_args[1].clone(); + let position = address.position(); + let encoded_rhs = vir_high::Expression::acc_predicate( + vir_high::Predicate::memory_block_heap_drop(address, size, position), + position, + ); + subst_with(encoded_rhs) + } + "prusti_contracts::prusti_unpacking" => { + assert_eq!(encoded_args.len(), 2); + let place = encoded_args[0].clone(); + let body = encoded_args[1].clone(); + let position = place.position(); + let encoded_rhs = vir_high::Expression::unfolding( + vir_high::Predicate::owned_non_aliased(place, position), + body, + position, + ); + subst_with(encoded_rhs) + } "prusti_contracts::old" => { let argument = encoded_args.last().cloned().unwrap(); let position = argument.position(); @@ -739,6 +827,18 @@ impl<'p, 'v: 'p, 'tcx: 'v> ExpressionBackwardInterpreter<'p, 'v, 'tcx> { .map(Some), } } + "std::ptr::mut_ptr::::is_null" => { + assert_eq!(encoded_args.len(), 1); + builtin((IsNull, vir_high::Type::Bool)) + } + "std::mem::size_of" => { + assert_eq!(encoded_args.len(), 0); + builtin((Size, vir_high::Type::Int(vir_high::ty::Int::Usize))) + } + "std::mem::align_of" => { + assert_eq!(encoded_args.len(), 0); + builtin((Align, vir_high::Type::Int(vir_high::ty::Int::Usize))) + } // Prusti-specific syntax // TODO: check we are in a spec function @@ -873,7 +973,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> BackwardMirInterpreter<'tcx> fn apply_terminator( &self, - _bb: mir::BasicBlock, + bb: mir::BasicBlock, terminator: &mir::Terminator<'tcx>, states: FxHashMap, ) -> Result { @@ -951,6 +1051,14 @@ impl<'p, 'v: 'p, 'tcx: 'v> BackwardMirInterpreter<'tcx> func: mir::Operand::Constant(box mir::Constant { literal, .. }), .. } => { + if self.specification_blocks.is_specification_block(bb) { + trace!("Skipping call terminator because inside a specification block"); + if let Some(target) = target { + return Ok(states[target].clone()); + } else { + unimplemented!(); + } + } self.apply_call_terminator(args, *destination, target, literal.ty(), states, span)? } @@ -1048,6 +1156,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> BackwardMirInterpreter<'tcx> state: &mut Self::State, ) -> Result<(), Self::Error> { trace!("apply_statement {:?}, state: {}", statement, state); + if self.specification_blocks.is_specification_block(bb) { + trace!("Skipping statement because inside a specification block"); + return Ok(()); + } let span = statement.source_info.span; let location = mir::Location { block: bb, diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/cleaner.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/cleaner.rs new file mode 100644 index 00000000000..37f5fae7a08 --- /dev/null +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/cleaner.rs @@ -0,0 +1,215 @@ +use crate::encoder::{ + errors::{SpannedEncodingError, SpannedEncodingResult}, + Encoder, +}; +use prusti_rustc_interface::span::Span; +use vir_crate::{ + common::{expression::SyntacticEvaluation, position::Positioned}, + high::{ + self as vir_high, + visitors::{ + default_fallible_fold_acc_predicate, default_fallible_fold_binary_op, + default_fallible_fold_unfolding, ExpressionFallibleFolder, + }, + }, +}; + +/// When encoding an assertion we sometimes get strange artefacts as a result of +/// using procedural macros. This functions removes them. +pub(super) fn clean_encoding_result<'p, 'v: 'p, 'tcx: 'v>( + encoder: &'p Encoder<'v, 'tcx>, + expression: vir_high::Expression, + span: Span, +) -> SpannedEncodingResult { + let _position = expression.position(); + let mut cleaner = Cleaner { encoder, span }; + + // TODO: Check that permission is never negated. + let expression = cleaner.fallible_fold_expression(expression)?; + + Ok(expression) +} + +struct Cleaner<'p, 'v: 'p, 'tcx: 'v> { + encoder: &'p Encoder<'v, 'tcx>, + span: Span, +} + +fn peel_addr_of(place: vir_high::Expression) -> vir_high::Expression { + match place { + vir_high::Expression::AddrOf(vir_high::AddrOf { base, .. }) => *base, + _ => { + unreachable!("mut be addr_of: {}", place) + } + } +} + +fn clean_acc_predicate(predicate: vir_high::Predicate) -> vir_high::Predicate { + match predicate { + vir_high::Predicate::OwnedNonAliased(mut predicate) => { + predicate.place = peel_addr_of(predicate.place); + vir_high::Predicate::OwnedNonAliased(predicate) + } + // vir_high::Predicate::OwnedNonAliased(vir_high::OwnedNonAliased { + // place: vir_high::Expression::AddrOf(vir_high::AddrOf { base, .. }), position + // }) => { + // vir_high::Predicate::owned_non_aliased(*base, position) + // } + vir_high::Predicate::MemoryBlockHeap(mut predicate) => { + predicate.address = peel_addr_of(predicate.address); + vir_high::Predicate::MemoryBlockHeap(predicate) + } + vir_high::Predicate::MemoryBlockHeapRange(mut predicate) => { + predicate.address = peel_addr_of(predicate.address); + vir_high::Predicate::MemoryBlockHeapRange(predicate) + } + vir_high::Predicate::MemoryBlockHeapDrop(mut predicate) => { + predicate.address = peel_addr_of(predicate.address); + vir_high::Predicate::MemoryBlockHeapDrop(predicate) + } + vir_high::Predicate::OwnedRange(mut predicate) => { + predicate.address = peel_addr_of(predicate.address); + vir_high::Predicate::OwnedRange(predicate) + } + _ => unimplemented!("{:?}", predicate), + } +} + +impl<'p, 'v: 'p, 'tcx: 'v> ExpressionFallibleFolder for Cleaner<'p, 'v, 'tcx> { + type Error = SpannedEncodingError; + + fn fallible_fold_acc_predicate( + &mut self, + mut acc_predicate: vir_high::AccPredicate, + ) -> Result { + // let predicate = match *acc_predicate.predicate { + // vir_high::Predicate::OwnedNonAliased(mut predicate) => { + // predicate.place = peel_addr_of(predicate.place); + // vir_high::Predicate::OwnedNonAliased(predicate) + // } + // // vir_high::Predicate::OwnedNonAliased(vir_high::OwnedNonAliased { + // // place: vir_high::Expression::AddrOf(vir_high::AddrOf { base, .. }), position + // // }) => { + // // vir_high::Predicate::owned_non_aliased(*base, position) + // // } + // vir_high::Predicate::MemoryBlockHeap(mut predicate) => { + // predicate.address = peel_addr_of(predicate.address); + // vir_high::Predicate::MemoryBlockHeap(predicate) + // } + // vir_high::Predicate::MemoryBlockHeapDrop(mut predicate) => { + // predicate.address = peel_addr_of(predicate.address); + // vir_high::Predicate::MemoryBlockHeapDrop(predicate) + // } + // _ => unimplemented!("{:?}", acc_predicate), + // }; + let predicate = clean_acc_predicate(*acc_predicate.predicate); + acc_predicate.predicate = Box::new(predicate); + + // if let box vir_high::Expression::AddrOf(vir_high::AddrOf { base, .. }) = acc_predicate.place + // { + // acc_predicate.place = base; + // } else { + // unreachable!("{:?}", acc_predicate); + // }; + default_fallible_fold_acc_predicate(self, acc_predicate) + } + + fn fallible_fold_unfolding( + &mut self, + mut unfolding: vir_high::Unfolding, + ) -> Result { + let predicate = clean_acc_predicate(*unfolding.predicate); + unfolding.predicate = Box::new(predicate); + default_fallible_fold_unfolding(self, unfolding) + } + + fn fallible_fold_conditional_enum( + &mut self, + conditional: vir_high::Conditional, + ) -> Result { + let conditional = self.fallible_fold_conditional(conditional)?; + let expression = match conditional { + _ if conditional.guard.is_true() => *conditional.then_expr, + _ if conditional.guard.is_false() => *conditional.else_expr, + vir_high::Conditional { + guard: + box vir_high::Expression::UnaryOp(vir_high::UnaryOp { + op_kind: vir_high::UnaryOpKind::Not, + argument: guard, + .. + }), + then_expr, + else_expr, + position, + } if then_expr.is_false() || then_expr.is_true() => { + // This happens due to short-circuiting in Rust. + if then_expr.is_false() { + vir_high::Expression::BinaryOp(vir_high::BinaryOp { + op_kind: vir_high::BinaryOpKind::And, + left: guard, + right: else_expr, + position, + }) + } else if then_expr.is_true() { + if !guard.is_pure() { + return Err(SpannedEncodingError::incorrect( + "permission predicates can be only in positive positions", + self.span, + )); + } + vir_high::Expression::BinaryOp(vir_high::BinaryOp { + op_kind: vir_high::BinaryOpKind::Implies, + left: guard, + right: else_expr, + position, + }) + } else { + unreachable!(); + } + } + _ if conditional.else_expr.is_true() => { + // Clean up stuff generated by `own!` expansion. + if !conditional.guard.is_pure() { + unimplemented!("TODO: A proper error message: {conditional}") + } + vir_high::Expression::BinaryOp(vir_high::BinaryOp { + op_kind: vir_high::BinaryOpKind::Implies, + left: conditional.guard, + right: conditional.then_expr, + position: conditional.position, + }) + } + _ => { + if !conditional.guard.is_pure() { + unimplemented!("TODO: A proper error message: {conditional}") + } + return Ok(vir_high::Expression::Conditional(conditional)); + } + }; + // let expression = + // vir_high::Expression::BinaryOp(default_fallible_fold_binary_op(self, binary_op)?); + // let expression = if conditional.else_expr.is_true() { + // let binary_op = ; + // vir_high::Expression::BinaryOp(default_fallible_fold_binary_op(self, binary_op)?) + // } else { + // }; + Ok(expression) + } + + fn fallible_fold_binary_op( + &mut self, + binary_op: vir_high::BinaryOp, + ) -> Result { + if binary_op.op_kind != vir_high::BinaryOpKind::And && !binary_op.left.is_pure() { + unimplemented!("TODO: A proper error message.") + } + if !matches!( + binary_op.op_kind, + vir_high::BinaryOpKind::And | vir_high::BinaryOpKind::Implies + ) && !binary_op.right.is_pure() + { + unimplemented!("TODO: A proper error message.") + } + default_fallible_fold_binary_op(self, binary_op) + } +} diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_high.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_high.rs index 2daa9e166d1..09e2751b8ee 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_high.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_high.rs @@ -98,16 +98,17 @@ pub(super) fn encode_pure_expression<'p, 'v: 'p, 'tcx: 'v>( parent_def_id, substs, ); + let span = encoder.env().query.get_def_span(proc_def_id); let state = run_backward_interpretation(&mir, &interpreter)?.ok_or_else(|| { SpannedEncodingError::incorrect( - format!("procedure {proc_def_id:?} contains a loop"), - encoder.env().query.get_def_span(proc_def_id), + format!("procedure {:?} contains a loop", proc_def_id), + span, ) })?; let body = state.into_expr().ok_or_else(|| { SpannedEncodingError::internal( - format!("failed to encode function's body: {proc_def_id:?}"), - encoder.env().query.get_def_span(proc_def_id), + format!("failed to encode function's body: {:?}", proc_def_id), + span, ) })?; debug!( @@ -116,6 +117,7 @@ pub(super) fn encode_pure_expression<'p, 'v: 'p, 'tcx: 'v>( ); // FIXME: Traverse the encoded function and check that all used types are // Copy. Doing this before encoding causes too many false positives. + let body = super::cleaner::clean_encoding_result(encoder, body, span)?; Ok(body) } diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/mod.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/mod.rs index ab45f165594..422695dce1d 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/mod.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/mod.rs @@ -6,6 +6,7 @@ //! Encoders of pure functions. +mod cleaner; mod interface; mod encoder_high; mod encoder_poly; diff --git a/prusti-viper/src/encoder/mir/pure/specifications/encoder_high.rs b/prusti-viper/src/encoder/mir/pure/specifications/encoder_high.rs index 5a3a5564be4..4c0fecbc697 100644 --- a/prusti-viper/src/encoder/mir/pure/specifications/encoder_high.rs +++ b/prusti-viper/src/encoder/mir/pure/specifications/encoder_high.rs @@ -75,15 +75,31 @@ pub(super) fn inline_spec_item_high<'tcx>( parent_def_id: DefId, substs: SubstsRef<'tcx>, ) -> SpannedEncodingResult { + assert_eq!( + substs.len(), + encoder.env().query.identity_substs(def_id).len() + ); + let mir = encoder .env() .body - .get_spec_body(def_id, substs, parent_def_id); + .get_expression_body(def_id, substs, parent_def_id); assert_eq!( mir.arg_count, target_args.len() + usize::from(target_return.is_some()), "def_id: {def_id:?}" ); + + // let mir = encoder + // .env() + // .body + // .get_spec_body(def_id, substs, parent_def_id); + // assert_eq!( + // mir.arg_count, + // target_args.len() + if target_return.is_some() { 1 } else { 0 }, + // "def_id: {:?}", + // def_id + // ); let mir_encoder = MirEncoder::new(encoder, &mir, def_id); let mut body_replacements = vec![]; for (arg_idx, arg_local) in mir.args_iter().enumerate() { diff --git a/prusti-viper/src/encoder/mir/specifications/interface.rs b/prusti-viper/src/encoder/mir/specifications/interface.rs index 87c88c29917..c7d53bfe237 100644 --- a/prusti-viper/src/encoder/mir/specifications/interface.rs +++ b/prusti-viper/src/encoder/mir/specifications/interface.rs @@ -140,6 +140,7 @@ impl<'v, 'tcx: 'v> SpecificationsInterface<'tcx> for super::super::super::Encode || func_name.starts_with("prusti_contracts::prusti_contracts::Seq") || func_name.starts_with("prusti_contracts::prusti_contracts::Ghost") || func_name.starts_with("prusti_contracts::prusti_contracts::Int") + // || func_name.starts_with("prusti_contracts::prusti_contracts::prusti_own") { pure = true; } diff --git a/prusti-viper/src/encoder/mir/types/encoder.rs b/prusti-viper/src/encoder/mir/types/encoder.rs index 80468ee6aac..6dc0c258bce 100644 --- a/prusti-viper/src/encoder/mir/types/encoder.rs +++ b/prusti-viper/src/encoder/mir/types/encoder.rs @@ -6,9 +6,10 @@ use super::{helpers::compute_discriminant_values, interface::MirTypeEncoderInterface}; use crate::encoder::{ - errors::{EncodingResult, SpannedEncodingError, SpannedEncodingResult, WithSpan}, + errors::{EncodingResult, ErrorCtxt, SpannedEncodingError, SpannedEncodingResult, WithSpan}, mir::{ - constants::ConstantsEncoderInterface, generics::MirGenericsEncoderInterface, + constants::ConstantsEncoderInterface, errors::ErrorInterface, + generics::MirGenericsEncoderInterface, pure::SpecificationEncoderInterface, specifications::SpecificationsInterface, types::helpers::compute_discriminant_ranges, }, Encoder, @@ -62,10 +63,7 @@ impl<'p, 'v, 'r: 'v, 'tcx: 'v> TypeEncoder<'p, 'v, 'tcx> { self.encoder.compute_array_len(size) } - pub fn encode_type( - self, - const_arguments: &[vir::Expression], - ) -> SpannedEncodingResult { + pub fn encode_type(self) -> SpannedEncodingResult { debug!("Encode type '{:?}'", self.ty); // self.encode_polymorphic_predicate_use() let lifetimes = self.encoder.get_lifetimes_from_type_high(self.ty)?; @@ -190,22 +188,26 @@ impl<'p, 'v, 'r: 'v, 'tcx: 'v> TypeEncoder<'p, 'v, 'tcx> { ty::TyKind::Str => vir::Type::Str, ty::TyKind::Array(elem_ty, size) => { - let (array_len, tail): (_, &[vir::Expression]) = - if let Some((array_len, tail)) = const_arguments.split_first() { - (array_len.clone(), tail) - } else { - let array_len: usize = self - .compute_array_len(*size) - .with_span(self.get_definition_span())? - .try_into() - .unwrap(); - (array_len.into(), &[]) - }; + // let (array_len, tail): (_, &[vir::Expression]) = + // if let Some((array_len, tail)) = const_arguments.split_first() { + // (array_len.clone(), tail) + // } else { + // let array_len: usize = self + // .compute_array_len(*size) + // .with_span(self.get_definition_span())? + // .try_into() + // .unwrap(); + // (array_len.into(), &[]) + // }; + let array_len: usize = self + .compute_array_len(*size) + .with_span(self.get_definition_span())? + .try_into() + .unwrap(); let lifetimes = self.encoder.get_lifetimes_from_type_high(*elem_ty)?; vir::Type::array( - vir::ty::ConstGenericArgument::new(Some(Box::new(array_len))), - self.encoder - .encode_type_high_with_const_arguments(*elem_ty, tail)?, + vir::ty::ConstGenericArgument::new(Some(Box::new(array_len.into()))), + self.encoder.encode_type_high(*elem_ty)?, lifetimes, ) } @@ -338,7 +340,11 @@ impl<'p, 'v, 'r: 'v, 'tcx: 'v> TypeEncoder<'p, 'v, 'tcx> { } } - pub fn encode_type_def_high(self) -> SpannedEncodingResult { + pub fn encode_type_def_high( + self, + ty: &vir::Type, + with_invariant: bool, + ) -> SpannedEncodingResult { debug!("Encode type predicate '{:?}'", self.ty); let type_decl = match self.ty.kind() { ty::TyKind::Bool => vir::TypeDecl::bool(), @@ -442,7 +448,9 @@ impl<'p, 'v, 'r: 'v, 'tcx: 'v> TypeEncoder<'p, 'v, 'tcx> { }), "prusti_contracts::Ghost" => { if let ty::subst::GenericArgKind::Type(ty) = substs[0].unpack() { - Self::new(self.encoder, ty).encode_type_def_high()? + let encoded_type = Self::new(self.encoder, ty).encode_type()?; + Self::new(self.encoder, ty) + .encode_type_def_high(&encoded_type, with_invariant)? } else { unreachable!("no type parameter given for Ghost") } @@ -462,7 +470,7 @@ impl<'p, 'v, 'r: 'v, 'tcx: 'v> TypeEncoder<'p, 'v, 'tcx> { ) } ty::TyKind::Adt(adt_def, substs) => { - encode_adt_def(self.encoder, *adt_def, substs, None)? + encode_adt_def(self.encoder, ty, *adt_def, substs, None, with_invariant)? } ty::TyKind::Never => vir::TypeDecl::never(), ty::TyKind::Param(param_ty) => { @@ -696,6 +704,8 @@ fn encode_variant<'v, 'tcx: 'v>( name: String, substs: ty::subst::SubstsRef<'tcx>, variant: &ty::VariantDef, + mut structural_invariant: Option>, + def_id: Option, ) -> SpannedEncodingResult { let tcx = encoder.env().tcx(); let mut fields = Vec::new(); @@ -708,18 +718,87 @@ fn encode_variant<'v, 'tcx: 'v>( } let lifetimes = encoder.get_lifetimes_from_substs(substs)?; let const_parameters = encoder.get_const_parameters_from_substs(substs)?; - let variant = vir::type_decl::Struct::new(name, lifetimes, const_parameters, fields); + let position = if let Some(def_id) = def_id { + let span = encoder.env().query.get_def_span(def_id); + let position = encoder + .error_manager() + .register_error(span, ErrorCtxt::TypeInvariantDefinition, def_id) + .into(); + if let Some(structural_invariant) = &mut structural_invariant { + for expression in std::mem::take(structural_invariant) { + structural_invariant.push(encoder.set_surrounding_error_context_for_expression( + expression, + position, + ErrorCtxt::TypeInvariantDefinition, + )); + } + } + position + } else { + Default::default() + }; + let variant = vir::type_decl::Struct::new_with_pos( + name, + lifetimes, + const_parameters, + structural_invariant, + fields, + position, + ); Ok(variant) } +fn encode_structural_invariant<'v, 'tcx: 'v>( + encoder: &Encoder<'v, 'tcx>, + ty: &vir::Type, + substs: ty::subst::SubstsRef<'tcx>, + did: DefId, +) -> SpannedEncodingResult>> { + let invariant = if let Some(specs) = encoder.get_type_specs(did) { + match &specs.structural_invariant { + prusti_interface::specs::typed::SpecificationItem::Empty => None, + prusti_interface::specs::typed::SpecificationItem::Inherent(invs) => { + Some( + invs.iter() + .map(|inherent_def_id| { + encoder.encode_assertion_high( + *inherent_def_id, + None, + &[vir::Expression::self_variable(ty.clone())], + None, + // true, + *inherent_def_id, + substs, + ) + }) + .collect::, _>>()?, + ) + } + _ => todo!(), + // TODO(inv): handle invariant inheritance + } + } else { + None + }; + Ok(invariant) +} + +/// `with_invariant` is used to break infinite recursion. pub(super) fn encode_adt_def<'v, 'tcx>( encoder: &Encoder<'v, 'tcx>, + ty: &vir::Type, adt_def: ty::AdtDef<'tcx>, substs: ty::subst::SubstsRef<'tcx>, variant_index: Option, + with_invariant: bool, ) -> SpannedEncodingResult { let lifetimes = encoder.get_lifetimes_from_substs(substs)?; let const_parameters = encoder.get_const_parameters_from_substs(substs)?; + let structural_invariant = if with_invariant { + encode_structural_invariant(encoder, ty, substs, adt_def.did())? + } else { + None + }; let tcx = encoder.env().tcx(); if adt_def.is_box() { debug!("ADT {:?} is a box", adt_def); @@ -730,7 +809,9 @@ pub(super) fn encode_adt_def<'v, 'tcx>( encode_box_name(), lifetimes, const_parameters, + structural_invariant, vec![field], + Default::default(), )) } else if adt_def.is_struct() { debug!("ADT {:?} is a struct", adt_def); @@ -738,10 +819,21 @@ pub(super) fn encode_adt_def<'v, 'tcx>( let name = encode_struct_name(encoder, adt_def.did()); let variant = adt_def.non_enum_variant(); Ok(vir::TypeDecl::Struct(encode_variant( - encoder, name, substs, variant, + encoder, + name, + substs, + variant, + structural_invariant, + Some(adt_def.did()), )?)) } else if adt_def.is_union() { debug!("ADT {:?} is a union", adt_def); + if structural_invariant.is_some() { + return Err(SpannedEncodingError::unsupported( + "Structural invariants are not supported on unions", + encoder.env().query.get_def_span(adt_def.did()), + )); + } if !config::unsafe_core_proof() { return Err(SpannedEncodingError::unsupported( "unions are not supported", @@ -765,6 +857,7 @@ pub(super) fn encode_adt_def<'v, 'tcx>( field_name, lifetimes.clone(), const_parameters.clone(), + None, vec![encoded_field], ); variants.push(variant); @@ -781,6 +874,12 @@ pub(super) fn encode_adt_def<'v, 'tcx>( )) } else if adt_def.is_enum() { debug!("ADT {:?} is an enum", adt_def); + if structural_invariant.is_some() { + return Err(SpannedEncodingError::unsupported( + "Structural invariants are not supported on enums", + encoder.env().query.get_def_span(adt_def.did()), + )); + } let name = encode_enum_name(encoder, adt_def.did()); let num_variants = adt_def.variants().len(); debug!("ADT {:?} is enum with {} variants", adt_def, num_variants); @@ -788,7 +887,14 @@ pub(super) fn encode_adt_def<'v, 'tcx>( // FIXME: Currently fold-unfold assumes that everything that // has only a single variant is a struct. let variant = &adt_def.variants()[0usize.into()]; - vir::TypeDecl::Struct(encode_variant(encoder, name, substs, variant)?) + vir::TypeDecl::Struct(encode_variant( + encoder, + name, + substs, + variant, + None, + Default::default(), + )?) } else if let Some(_variant_index) = variant_index { // let variant = &adt_def.variants()[variant_index]; // vir::TypeDecl::Struct(encode_variant(encoder, name, substs, variant)?) @@ -801,7 +907,8 @@ pub(super) fn encode_adt_def<'v, 'tcx>( let mut variants = Vec::new(); for variant in adt_def.variants() { let name = variant.ident(tcx).to_string(); - let encoded_variant = encode_variant(encoder, name, substs, variant)?; + let encoded_variant = + encode_variant(encoder, name, substs, variant, None, Default::default())?; variants.push(encoded_variant); } let mir_discriminant_type = match adt_def.repr().discr_type() { diff --git a/prusti-viper/src/encoder/mir/types/interface.rs b/prusti-viper/src/encoder/mir/types/interface.rs index 982ed6a6259..2fea3b751cf 100644 --- a/prusti-viper/src/encoder/mir/types/interface.rs +++ b/prusti-viper/src/encoder/mir/types/interface.rs @@ -55,11 +55,11 @@ pub(crate) trait MirTypeEncoderInterface<'tcx> { ty: ty::Ty<'tcx>, ) -> SpannedEncodingResult>; fn encode_type_high(&self, ty: ty::Ty<'tcx>) -> SpannedEncodingResult; - fn encode_type_high_with_const_arguments( - &self, - ty: ty::Ty<'tcx>, - const_arguments: &[vir_high::Expression], - ) -> SpannedEncodingResult; + // fn encode_type_high_with_const_arguments( + // &self, + // ty: ty::Ty<'tcx>, + // const_arguments: &[vir_high::Expression], + // ) -> SpannedEncodingResult; fn encode_place_type_high(&self, ty: mir::tcx::PlaceTy<'tcx>) -> EncodingResult; fn encode_enum_variant_index_high( @@ -76,13 +76,14 @@ pub(crate) trait MirTypeEncoderInterface<'tcx> { fn encode_type_def_high( &self, ty: &vir_high::Type, + with_invariant: bool, ) -> SpannedEncodingResult; - fn encode_adt_def( - &self, - adt_def: ty::AdtDef<'tcx>, - substs: ty::subst::SubstsRef<'tcx>, - variant_index: Option, - ) -> SpannedEncodingResult; + // fn encode_adt_def( + // &self, + // adt_def: ty::AdtDef<'tcx>, + // substs: ty::subst::SubstsRef<'tcx>, + // variant_index: Option, + // ) -> SpannedEncodingResult; fn encode_type_bounds_high( &self, var: &vir_high::Expression, @@ -123,7 +124,7 @@ impl<'v, 'tcx: 'v> MirTypeEncoderInterface<'tcx> for super::super::super::Encode use_span: Option, declaration_span: Span, ) -> SpannedEncodingResult { - let type_decl = self.encode_type_def_high(ty)?; + let type_decl = self.encode_type_def_high(ty, false)?; let primary_span = if let Some(use_span) = use_span { use_span } else { @@ -223,15 +224,6 @@ impl<'v, 'tcx: 'v> MirTypeEncoderInterface<'tcx> for super::super::super::Encode Ok(const_parameters) } fn encode_type_high(&self, ty: ty::Ty<'tcx>) -> SpannedEncodingResult { - // FIXME: Remove encode_type_high_with_const_arguments because it is a - // failed attempt. - self.encode_type_high_with_const_arguments(ty, &[]) - } - fn encode_type_high_with_const_arguments( - &self, - ty: ty::Ty<'tcx>, - const_arguments: &[vir_high::Expression], - ) -> SpannedEncodingResult { if !self .mir_type_encoder_state .encoded_types @@ -239,7 +231,7 @@ impl<'v, 'tcx: 'v> MirTypeEncoderInterface<'tcx> for super::super::super::Encode .contains_key(ty.kind()) { let type_encoder = TypeEncoder::new(self, ty); - let encoded_type = type_encoder.encode_type(const_arguments)?; + let encoded_type = type_encoder.encode_type()?; assert!(self .mir_type_encoder_state .encoded_types @@ -342,6 +334,7 @@ impl<'v, 'tcx: 'v> MirTypeEncoderInterface<'tcx> for super::super::super::Encode fn encode_type_def_high( &self, ty: &vir_high::Type, + with_invariant: bool, ) -> SpannedEncodingResult { if !self .mir_type_encoder_state @@ -357,12 +350,15 @@ impl<'v, 'tcx: 'v> MirTypeEncoderInterface<'tcx> for super::super::super::Encode lifetimes, }) => { let encoded_enum = self - .encode_type_def_high(&vir_high::Type::enum_( - name.clone(), - arguments.clone(), - None, - lifetimes.clone(), - ))? + .encode_type_def_high( + &vir_high::Type::enum_( + name.clone(), + arguments.clone(), + None, + lifetimes.clone(), + ), + with_invariant, + )? .unwrap_enum(); vir_high::TypeDecl::Struct(encoded_enum.into_variant(&variant.index).unwrap()) } @@ -373,37 +369,45 @@ impl<'v, 'tcx: 'v> MirTypeEncoderInterface<'tcx> for super::super::super::Encode lifetimes, }) => { let encoded_union = self - .encode_type_def_high(&vir_high::Type::union_( - name.clone(), - arguments.clone(), - None, - lifetimes.clone(), - ))? + .encode_type_def_high( + &vir_high::Type::union_( + name.clone(), + arguments.clone(), + None, + lifetimes.clone(), + ), + with_invariant, + )? .unwrap_union(); vir_high::TypeDecl::Struct(encoded_union.into_variant(&variant.index).unwrap()) } _ => { let original_ty = self.decode_type_high(ty); let type_encoder = TypeEncoder::new(self, original_ty); - type_encoder.encode_type_def_high()? + type_encoder.encode_type_def_high(ty, with_invariant)? } }; - self.mir_type_encoder_state - .encoded_type_decls - .borrow_mut() - .insert(ty.clone(), encoded_type); + if with_invariant { + // Cache only the fully encoded version. + self.mir_type_encoder_state + .encoded_type_decls + .borrow_mut() + .insert(ty.clone(), encoded_type); + } else { + return Ok(encoded_type); + } } let encoded_type = self.mir_type_encoder_state.encoded_type_decls.borrow()[ty].clone(); Ok(encoded_type) } - fn encode_adt_def( - &self, - adt_def: ty::AdtDef<'tcx>, - substs: ty::subst::SubstsRef<'tcx>, - variant_index: Option, - ) -> SpannedEncodingResult { - super::encoder::encode_adt_def(self, adt_def, substs, variant_index) - } + // fn encode_adt_def( + // &self, + // adt_def: ty::AdtDef<'tcx>, + // substs: ty::subst::SubstsRef<'tcx>, + // variant_index: Option, + // ) -> SpannedEncodingResult { + // super::encoder::encode_adt_def(self, ty, adt_def, substs, variant_index) + // } fn encode_type_bounds_high( &self, var: &vir_high::Expression, diff --git a/prusti-viper/src/encoder/procedure_encoder.rs b/prusti-viper/src/encoder/procedure_encoder.rs index cbbd0767652..3c8930397c2 100644 --- a/prusti-viper/src/encoder/procedure_encoder.rs +++ b/prusti-viper/src/encoder/procedure_encoder.rs @@ -149,7 +149,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { let init_info = InitInfo::new(mir, tcx, proc_def_id, &mir_encoder) .with_default_span(procedure.get_span())?; - let specification_blocks = SpecificationBlocks::build(encoder.env().query, mir, procedure, false); + let specification_blocks = SpecificationBlocks::build(encoder.env().query, mir, None, false); let cfg_method = vir::CfgMethod::new( // method name diff --git a/prusti-viper/src/encoder/typed/to_middle/expression.rs b/prusti-viper/src/encoder/typed/to_middle/expression.rs index d75408523f3..fa9f7051fc0 100644 --- a/prusti-viper/src/encoder/typed/to_middle/expression.rs +++ b/prusti-viper/src/encoder/typed/to_middle/expression.rs @@ -1,7 +1,9 @@ use crate::encoder::errors::SpannedEncodingError; use vir_crate::{ middle as vir_mid, - middle::operations::{TypedToMiddleExpressionLowerer, TypedToMiddleType}, + middle::operations::{ + TypedToMiddleExpressionLowerer, TypedToMiddlePredicate, TypedToMiddleType, + }, typed as vir_typed, }; @@ -65,4 +67,11 @@ impl<'v, 'tcx> TypedToMiddleExpressionLowerer for crate::encoder::Encoder<'v, 't index: variant_index.index, }) } + + fn typed_to_middle_expression_predicate( + &self, + predicate: vir_typed::Predicate, + ) -> Result { + predicate.typed_to_middle_predicate(self) + } } diff --git a/prusti-viper/src/encoder/typed/to_middle/statement.rs b/prusti-viper/src/encoder/typed/to_middle/statement.rs index 052d9ca9b73..4ab06f1441f 100644 --- a/prusti-viper/src/encoder/typed/to_middle/statement.rs +++ b/prusti-viper/src/encoder/typed/to_middle/statement.rs @@ -110,4 +110,46 @@ impl<'v, 'tcx> TypedToMiddleStatementLowerer for crate::encoder::Encoder<'v, 'tc ) -> Result { unreachable!("ObtainMutRef statement cannot be lowered"); } + + fn typed_to_middle_statement_statement_unpack( + &self, + _statement: vir_typed::Unpack, + ) -> Result { + unreachable!("Unpack statement cannot be lowered"); + } + + fn typed_to_middle_statement_statement_pack( + &self, + _statement: vir_typed::Pack, + ) -> Result { + unreachable!("Pack statement cannot be lowered"); + } + + fn typed_to_middle_statement_statement_forget_initialization( + &self, + _statement: vir_typed::ForgetInitialization, + ) -> Result { + unreachable!("ForgetInitialization statement cannot be lowered"); + } + + fn typed_to_middle_statement_statement_split( + &self, + _statement: vir_typed::Split, + ) -> Result { + unreachable!("Split statement cannot be lowered"); + } + + fn typed_to_middle_statement_statement_join( + &self, + _statement: vir_typed::Join, + ) -> Result { + unreachable!("Join statement cannot be lowered"); + } + + // fn typed_to_middle_statement_statement_restore( + // &self, + // _statement: vir_typed::Restore, + // ) -> Result { + // unreachable!("Restore statement cannot be lowered"); + // } } diff --git a/prusti-viper/src/encoder/typed/to_middle/type_decl.rs b/prusti-viper/src/encoder/typed/to_middle/type_decl.rs index b60bb88f41c..2895a20c736 100644 --- a/prusti-viper/src/encoder/typed/to_middle/type_decl.rs +++ b/prusti-viper/src/encoder/typed/to_middle/type_decl.rs @@ -75,4 +75,11 @@ impl<'v, 'tcx> TypedToMiddleTypeDeclLowerer for crate::encoder::Encoder<'v, 'tcx vir_typed::ty::EnumSafety::Union => vir_mid::ty::EnumSafety::Union, }) } + + fn typed_to_middle_type_decl_position( + &self, + position: vir_typed::Position, + ) -> Result { + Ok(position) + } } diff --git a/prusti-viper/src/verifier.rs b/prusti-viper/src/verifier.rs index 73b5f371d62..5a3d57a1123 100644 --- a/prusti-viper/src/verifier.rs +++ b/prusti-viper/src/verifier.rs @@ -221,8 +221,8 @@ fn verify_programs(env: &Environment, programs: Vec) let program_name = program.get_name().to_string(); let check_mode = program.get_check_mode(); // Prepend the Rust file name to the program. - program.set_name(format!("{rust_program_name}_{program_name}")); - let backend = if check_mode == CheckMode::Specifications { + program.set_name(format!("{}_{}", rust_program_name, program_name)); + let backend = if check_mode == CheckMode::PurificationFunctional { config::verify_specifications_backend() } else { config::viper_backend() diff --git a/smt-log-analyzer/src/lib.rs b/smt-log-analyzer/src/lib.rs index 518c3cf3dde..518f7c25c71 100644 --- a/smt-log-analyzer/src/lib.rs +++ b/smt-log-analyzer/src/lib.rs @@ -33,7 +33,9 @@ pub struct Settings { fn process_line(settings: &Settings, state: &mut State, line: &str) -> Result<(), Error> { let mut parser = Parser::from_line(line); - match parser.parse_event_kind()? { + let event_kind = parser.parse_event_kind()?; + state.register_event_kind(event_kind); + match event_kind { EventKind::Pop => { let scopes_to_pop = parser.parse_number()?; let active_scopes_count = parser.parse_number()?; @@ -128,7 +130,13 @@ fn process_line(settings: &Settings, state: &mut State, line: &str) -> Result<() EventKind::Instance => { state.register_instance()?; } - EventKind::Unrecognized => {} + EventKind::DecideAndOr => { + let term_id = parser.parse_id()?; + let undef_child_id = parser.parse_id()?; + // FIXME: This information seems to be useless. + state.register_decide_and_or_term(term_id, undef_child_id); + } + _ => {} } Ok(()) } diff --git a/smt-log-analyzer/src/parser.rs b/smt-log-analyzer/src/parser.rs index 0e25e207ee5..cd8e15b2b1a 100644 --- a/smt-log-analyzer/src/parser.rs +++ b/smt-log-analyzer/src/parser.rs @@ -4,6 +4,7 @@ use crate::{ }; use std::str::CharIndices; +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub(crate) enum EventKind { Pop, Push, @@ -15,6 +16,20 @@ pub(crate) enum EventKind { Unrecognized, AttachMeaning, MkVar, + ToolVersion, + AttachVarNames, + MkProof, + AttachEnode, + EndOfInstance, + MkLambda, + BeginCheck, + Assign, + EqExpl, + DecideAndOr, + ResolveLit, + ResolveProcess, + Conflict, + Eof, } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -137,11 +152,20 @@ impl<'a> Parser<'a> { "inst-discovered" => EventKind::InstDiscovered, "instance" => EventKind::Instance, "attach-meaning" => EventKind::AttachMeaning, - "tool-version" | "attach-var-names" | "mk-proof" | "attach-enode" - | "end-of-instance" | "mk-lambda" | "begin-check" | "assign" | "eq-expl" - | "decide-and-or" | "resolve-lit" | "resolve-process" | "conflict" | "eof" => { - EventKind::Unrecognized - } + "tool-version" => EventKind::ToolVersion, + "attach-var-names" => EventKind::AttachVarNames, + "mk-proof" => EventKind::MkProof, + "attach-enode" => EventKind::AttachEnode, + "end-of-instance" => EventKind::EndOfInstance, + "mk-lambda" => EventKind::MkLambda, + "begin-check" => EventKind::BeginCheck, + "assign" => EventKind::Assign, + "eq-expl" => EventKind::EqExpl, + "decide-and-or" => EventKind::DecideAndOr, + "resolve-lit" => EventKind::ResolveLit, + "resolve-process" => EventKind::ResolveProcess, + "conflict" => EventKind::Conflict, + "eof" => EventKind::Eof, x => unimplemented!("got: {:?}", x), }; self.consume(']')?; diff --git a/smt-log-analyzer/src/state.rs b/smt-log-analyzer/src/state.rs index 1fdacc1fb26..47828bc8111 100644 --- a/smt-log-analyzer/src/state.rs +++ b/smt-log-analyzer/src/state.rs @@ -2,7 +2,7 @@ use csv::Writer; use crate::{ error::Error, - parser::TheoryKind, + parser::{EventKind, TheoryKind}, types::{Level, QuantifierId, TermId, BUILTIN_QUANTIFIER_ID}, }; use std::{ @@ -86,6 +86,8 @@ struct LargestPop { pub(crate) struct State { quantifiers: HashMap, terms: HashMap, + /// Frequencies of each event kind. + event_kind_counters: HashMap, /// The currently matched quantifiers (via [new-match]) at a given level. quantifiers_matched_events: HashMap>, /// The currently discovered quantifiers (via [inst-discovered]) at a given level. @@ -118,9 +120,15 @@ pub(crate) struct State { current_active_scopes_count: Level, traced_quantifier: Option, traced_quantifier_triggers: Option, + decide_and_or_terms: Vec<(TermId, String, TermId, String)>, } impl State { + pub(crate) fn register_event_kind(&mut self, event_kind: EventKind) { + let entry = self.event_kind_counters.entry(event_kind).or_insert(0); + *entry += 1; + } + pub(crate) fn register_label(&mut self, label: String) { self.trace.push(BasicBlockVisitedEvent { level: self.current_active_scopes_count, @@ -300,6 +308,20 @@ impl State { .insert(term_id, Term::AttachMeaning { ident, value }); } + pub(crate) fn register_decide_and_or_term(&mut self, term_id: TermId, undef_child_id: TermId) { + let mut rendered_term = String::new(); + self.render_term(term_id, &mut rendered_term, 30).unwrap(); + let mut rendered_undef_child = String::new(); + self.render_term(undef_child_id, &mut rendered_undef_child, 30) + .unwrap(); + self.decide_and_or_terms.push(( + term_id, + rendered_term, + undef_child_id, + rendered_undef_child, + )); + } + pub(crate) fn active_scopes_count(&self) -> Level { self.current_active_scopes_count } @@ -452,6 +474,16 @@ impl State { } pub(crate) fn write_statistics(&self, input_file: &str) { + { + let mut writer = Writer::from_path(format!("{}.event-kinds.csv", input_file)).unwrap(); + writer.write_record(["Event Kind", "Count"]).unwrap(); + for (event_kind, counter) in &self.event_kind_counters { + writer + .write_record([format!("{:?}", event_kind), counter.to_string()]) + .unwrap(); + } + } + { // [instance] – the number of quantifier instantiations. let mut writer = Writer::from_path(format!("{input_file}.instances.csv")).unwrap(); @@ -614,6 +646,30 @@ impl State { } } + { + // [decide-and-or] – Case splits. + let mut writer = + Writer::from_path(format!("{}.decide-and-or.csv", input_file)).unwrap(); + writer + .write_record([ + "TermId", + "Rendered Term", + "Undef child ID", + "Rendered undef child", + ]) + .unwrap(); + for (term_id, rendered_term, child_id, rendered_child) in &self.decide_and_or_terms { + writer + .write_record(&[ + &term_id.to_string(), + rendered_term, + &child_id.to_string(), + rendered_child, + ]) + .unwrap(); + } + } + { println!( "The largest number of quantifier matches removed in a single \ diff --git a/vir-gen/src/deriver/lower.rs b/vir-gen/src/deriver/lower.rs index 77f18cb8aa0..b06166b3444 100644 --- a/vir-gen/src/deriver/lower.rs +++ b/vir-gen/src/deriver/lower.rs @@ -554,6 +554,20 @@ impl<'a> Deriver<'a> { ty.map(|element| self.#inner_method_name(*element).map(Box::new)).transpose() } } + } else if container_ident_first == "Option" && container_ident_second == "Vec" { + let inner_method_name = self.encode_name(inner_ident); + parse_quote! { + fn #method_name( + #self_parameter, + ty: #container_ident_first < #container_ident_second < #parameter_type > > + ) -> Result< #container_ident_first < #container_ident_second < #return_type > >, Self::Error> { + ty.map(|elements| + elements.into_iter().map(|element| { + self.#inner_method_name(element) + }).collect() + ).transpose() + } + } } else { unimplemented!( "first: {} second: {}", diff --git a/vir/defs/high/ast/expression.rs b/vir/defs/high/ast/expression.rs index fb685c77844..3448445679a 100644 --- a/vir/defs/high/ast/expression.rs +++ b/vir/defs/high/ast/expression.rs @@ -1,6 +1,7 @@ pub(crate) use super::{ field::FieldDecl, position::Position, + predicate::Predicate, ty::{Type, VariantIndex}, variable::VariableDecl, }; @@ -45,6 +46,10 @@ pub enum Expression { /// * field that encodes the variant // FIXME: Is downcast really needed? Isn't variant enough? Downcast(Downcast), + /// An accessibility predicate such as `own`. + AccPredicate(AccPredicate), + /// An unpacking of an accessibility predicate. + Unfolding(Unfolding), } #[display(fmt = "{}", "variable.name")] @@ -233,6 +238,7 @@ pub enum BuiltinFunc { SnapshotEquality, Size, PaddingSize, + Align, Discriminant, LifetimeIncluded, LifetimeIntersect, @@ -249,6 +255,12 @@ pub enum BuiltinFunc { NewInt, Index, Len, + IsNull, + IsValid, // TODO: Delete. + EnsureOwnedPredicate, + // GetSnapshot, + /// Take the inner-most lifetime of a place. + TakeLifetime, } #[display(fmt = "__builtin__{}({})", function, "display::cjoin(arguments)")] @@ -268,3 +280,16 @@ pub struct Downcast { pub field: FieldDecl, pub position: Position, } + +#[display(fmt = "acc({})", predicate)] +pub struct AccPredicate { + pub predicate: Box, + pub position: Position, +} + +#[display(fmt = "unfolding({}, {})", predicate, body)] +pub struct Unfolding { + pub predicate: Box, + pub body: Box, + pub position: Position, +} diff --git a/vir/defs/high/ast/function.rs b/vir/defs/high/ast/function.rs index bd52dccf969..f2a4bbe9d5c 100644 --- a/vir/defs/high/ast/function.rs +++ b/vir/defs/high/ast/function.rs @@ -8,7 +8,7 @@ use crate::common::display; "display::cjoin(parameters)", return_type, "display::foreach!(\" requires {}\n\", pres)", - "display::foreach!(\" ensures {}\n\", pres)", + "display::foreach!(\" ensures {}\n\", posts)", "display::option!(body, \"{{ {} }}\n\", \"\")" )] pub struct FunctionDecl { diff --git a/vir/defs/high/ast/predicate.rs b/vir/defs/high/ast/predicate.rs index a1c1f68fbe7..81b4f2db6eb 100644 --- a/vir/defs/high/ast/predicate.rs +++ b/vir/defs/high/ast/predicate.rs @@ -11,8 +11,11 @@ pub enum Predicate { MemoryBlockStack(MemoryBlockStack), MemoryBlockStackDrop(MemoryBlockStackDrop), MemoryBlockHeap(MemoryBlockHeap), + MemoryBlockHeapRange(MemoryBlockHeapRange), MemoryBlockHeapDrop(MemoryBlockHeapDrop), OwnedNonAliased(OwnedNonAliased), + OwnedRange(OwnedRange), + OwnedSet(OwnedSet), } #[display(fmt = "acc(LifetimeToken({}), {})", lifetime, permission)] @@ -60,6 +63,21 @@ pub struct MemoryBlockHeap { pub position: Position, } +#[display( + fmt = "MemoryBlockHeapRange({}, {}, {}, {})", + address, + size, + start_index, + end_index +)] +pub struct MemoryBlockHeapRange { + pub address: Expression, + pub size: Expression, + pub start_index: Expression, + pub end_index: Expression, + pub position: Position, +} + /// A permission to deallocate a (precisely) matching `MemoryBlockHeap`. #[display(fmt = "MemoryBlockHeapDrop({}, {})", address, size)] pub struct MemoryBlockHeapDrop { @@ -74,3 +92,20 @@ pub struct OwnedNonAliased { pub place: Expression, pub position: Position, } + +/// A range of owned predicates of a specific type. `start_index` is inclusive +/// and `end_index` is exclusive. +#[display(fmt = "OwnedRange({}, {}, {})", address, start_index, end_index)] +pub struct OwnedRange { + pub address: Expression, + pub start_index: Expression, + pub end_index: Expression, + pub position: Position, +} + +/// A set of owned predicates of a specific type. +#[display(fmt = "OwnedSet({})", set)] +pub struct OwnedSet { + pub set: Expression, + pub position: Position, +} diff --git a/vir/defs/high/ast/rvalue.rs b/vir/defs/high/ast/rvalue.rs index 290182bbd9b..0a32fe5b72f 100644 --- a/vir/defs/high/ast/rvalue.rs +++ b/vir/defs/high/ast/rvalue.rs @@ -16,7 +16,7 @@ pub enum Rvalue { // ThreadLocalRef(ThreadLocalRef), AddressOf(AddressOf), Len(Len), - // Cast(Cast), + Cast(Cast), BinaryOp(BinaryOp), CheckedBinaryOp(CheckedBinaryOp), // NullaryOp(NullaryOp), @@ -66,6 +66,13 @@ pub struct Len { pub place: Expression, } +#[display(fmt = "cast({} -> {})", operand, ty)] +pub struct Cast { + // TODO: kind: CastKind, + pub operand: Operand, + pub ty: Type, +} + #[display(fmt = "{}({}, {})", kind, left, right)] pub struct BinaryOp { pub kind: BinaryOpKind, diff --git a/vir/defs/high/ast/statement.rs b/vir/defs/high/ast/statement.rs index e3fc011c1e3..1ea9b059c15 100644 --- a/vir/defs/high/ast/statement.rs +++ b/vir/defs/high/ast/statement.rs @@ -17,13 +17,16 @@ use std::collections::BTreeSet; pub enum Statement { Comment(Comment), OldLabel(OldLabel), - Inhale(Inhale), - Exhale(Exhale), + InhalePredicate(InhalePredicate), + ExhalePredicate(ExhalePredicate), + InhaleExpression(InhaleExpression), + ExhaleExpression(ExhaleExpression), + Assume(Assume), + Assert(Assert), Consume(Consume), Havoc(Havoc), GhostHavoc(GhostHavoc), - Assume(Assume), - Assert(Assert), + HeapHavoc(HeapHavoc), LoopInvariant(LoopInvariant), MovePlace(MovePlace), CopyPlace(CopyPlace), @@ -33,6 +36,16 @@ pub enum Statement { GhostAssign(GhostAssign), LeakAll(LeakAll), SetUnionVariant(SetUnionVariant), + Pack(Pack), + Unpack(Unpack), + Join(Join), + JoinRange(JoinRange), + Split(Split), + SplitRange(SplitRange), + StashRange(StashRange), + StashRangeRestore(StashRangeRestore), + ForgetInitialization(ForgetInitialization), + RestoreRawBorrowed(RestoreRawBorrowed), NewLft(NewLft), EndLft(EndLft), DeadLifetime(DeadLifetime), @@ -59,16 +72,18 @@ pub struct OldLabel { pub position: Position, } -/// Inhale the permission denoted by the place. -#[display(fmt = "inhale {}", predicate)] -pub struct Inhale { +/// Inhale the permission denoted by the place. This operation is automatically +/// managed by fold-unfold. +#[display(fmt = "inhale-pred {}", predicate)] +pub struct InhalePredicate { pub predicate: Predicate, pub position: Position, } -#[display(fmt = "exhale {}", predicate)] -/// Exhale the permission denoted by the place. -pub struct Exhale { +#[display(fmt = "exhale-pred {}", predicate)] +/// Exhale the permission denoted by the place. This operation is automatically +/// managed by fold-unfold. +pub struct ExhalePredicate { pub predicate: Predicate, pub position: Position, } @@ -88,20 +103,41 @@ pub struct Havoc { } #[display(fmt = "ghost-havoc {}", variable)] +/// Havoc the local variable. pub struct GhostHavoc { pub variable: VariableDecl, pub position: Position, } +#[display(fmt = "heap-havoc")] +/// Havoc the heap. +pub struct HeapHavoc { + pub position: Position, +} + +#[display(fmt = "inhale-expr {}", expression)] +/// Inhale the boolean expression. This operation is ignored by fold-unfold. +pub struct InhaleExpression { + pub expression: Expression, + pub position: Position, +} + +#[display(fmt = "exhale-expr {}", expression)] +/// Exhale the boolean expression. This operation is ignored by fold-unfold. +pub struct ExhaleExpression { + pub expression: Expression, + pub position: Position, +} + #[display(fmt = "assume {}", expression)] -/// Assume the boolean expression. +/// Assume the pure boolean expression. pub struct Assume { pub expression: Expression, pub position: Position, } #[display(fmt = "assert {}", expression)] -/// Assert the boolean expression. +/// Assert the pure boolean expression. pub struct Assert { pub expression: Expression, pub position: Position, @@ -252,6 +288,111 @@ pub struct SetUnionVariant { pub position: Position, } +#[derive_helpers] +#[derive(derive_more::From, derive_more::IsVariant, derive_more::Unwrap)] +pub enum PredicateKind { + Owned, + UniqueRef(UniqueRef), + FracRef(FracRef), +} + +pub struct UniqueRef { + pub lifetime: LifetimeConst, +} + +pub struct FracRef { + pub lifetime: LifetimeConst, +} + +#[display(fmt = "pack-{} {}", predicate_kind, place)] +pub struct Pack { + pub place: Expression, + pub predicate_kind: PredicateKind, + pub position: Position, +} + +#[display(fmt = "unpack-{} {}", predicate_kind, place)] +pub struct Unpack { + pub place: Expression, + pub predicate_kind: PredicateKind, + pub position: Position, +} + +#[display(fmt = "join {}", place)] +pub struct Join { + pub place: Expression, + pub position: Position, +} + +#[display(fmt = "join-range {} {} {}", address, start_index, end_index)] +pub struct JoinRange { + pub address: Expression, + pub start_index: Expression, + pub end_index: Expression, + pub position: Position, +} + +#[display(fmt = "split {}", place)] +pub struct Split { + pub place: Expression, + pub position: Position, +} + +#[display(fmt = "split-range {} {} {}", address, start_index, end_index)] +pub struct SplitRange { + pub address: Expression, + pub start_index: Expression, + pub end_index: Expression, + pub position: Position, +} + +#[display( + fmt = "stash-range {} {} {} {}", + address, + start_index, + end_index, + label +)] +pub struct StashRange { + pub address: Expression, + pub start_index: Expression, + pub end_index: Expression, + pub label: String, + pub position: Position, +} + +#[display( + fmt = "stash-range-restore {} {} {} {} → {} {}", + old_address, + old_start_index, + old_end_index, + old_label, + new_address, + new_start_index +)] +pub struct StashRangeRestore { + pub old_address: Expression, + pub old_start_index: Expression, + pub old_end_index: Expression, + pub old_label: String, + pub new_address: Expression, + pub new_start_index: Expression, + pub position: Position, +} + +#[display(fmt = "forget-initialization {}", place)] +pub struct ForgetInitialization { + pub place: Expression, + pub position: Position, +} + +#[display(fmt = "restore {} --* {}", borrowing_place, restored_place)] +pub struct RestoreRawBorrowed { + pub borrowing_place: Expression, + pub restored_place: Expression, + pub position: Position, +} + #[display(fmt = "{} = newlft()", target)] pub struct NewLft { pub target: VariableDecl, diff --git a/vir/defs/high/ast/type_decl.rs b/vir/defs/high/ast/type_decl.rs index f4768418715..49b801d6939 100644 --- a/vir/defs/high/ast/type_decl.rs +++ b/vir/defs/high/ast/type_decl.rs @@ -1,6 +1,7 @@ pub(crate) use super::{ expression::Expression, field::FieldDecl, + position::Position, ty::{GenericType, LifetimeConst, Type, Uniqueness}, variable::VariableDecl, }; @@ -70,7 +71,9 @@ pub struct Struct { pub name: String, pub lifetimes: Vec, pub const_parameters: Vec, + pub structural_invariant: Option>, pub fields: Vec, + pub position: Position, } pub type DiscriminantValue = i128; diff --git a/vir/defs/high/cfg/procedure.rs b/vir/defs/high/cfg/procedure.rs index 757598f4e19..ead1a4f6f84 100644 --- a/vir/defs/high/cfg/procedure.rs +++ b/vir/defs/high/cfg/procedure.rs @@ -1,5 +1,6 @@ use super::super::ast::{ expression::{Expression, Local}, + position::Position, statement::Statement, }; use crate::common::{check_mode::CheckMode, display}; @@ -17,6 +18,7 @@ pub struct ProcedureDecl { pub entry: BasicBlockId, pub exit: BasicBlockId, pub basic_blocks: BTreeMap, + pub position: Position, } #[derive(PartialOrd, Ord, derive_more::Constructor, derive_more::AsRef)] diff --git a/vir/defs/high/mod.rs b/vir/defs/high/mod.rs index 8460b5e512f..c73335c51bd 100644 --- a/vir/defs/high/mod.rs +++ b/vir/defs/high/mod.rs @@ -5,25 +5,28 @@ pub(crate) mod operations_internal; pub use self::{ ast::{ expression::{ - self, visitors, AddrOf, BinaryOp, BinaryOpKind, BuiltinFunc, BuiltinFuncApp, - Conditional, Constant, Constructor, ContainerOp, Deref, Downcast, Expression, Field, - FuncApp, LabelledOld, LetExpr, Local, Quantifier, Seq, Trigger, UnaryOp, UnaryOpKind, - Variant, + self, visitors, AccPredicate, AddrOf, BinaryOp, BinaryOpKind, BuiltinFunc, + BuiltinFuncApp, Conditional, Constant, Constructor, ContainerOp, Deref, Downcast, + Expression, Field, FuncApp, LabelledOld, LetExpr, Local, Quantifier, Seq, Trigger, + UnaryOp, UnaryOpKind, Unfolding, Variant, }, field::FieldDecl, function::FunctionDecl, position::Position, predicate::{ LifetimeToken, MemoryBlockHeap, MemoryBlockHeapDrop, MemoryBlockStack, - MemoryBlockStackDrop, Predicate, + MemoryBlockStackDrop, OwnedNonAliased, Predicate, }, rvalue::{Operand, OperandKind, Rvalue}, statement::{ Assert, Assign, Assume, BorShorten, CloseFracRef, CloseMutRef, Comment, Consume, - CopyPlace, DeadInclusion, DeadLifetime, EndLft, Exhale, GhostAssign, GhostHavoc, Havoc, - Inhale, LeakAll, LifetimeReturn, LifetimeTake, LoopInvariant, MovePlace, NewLft, - ObtainMutRef, OldLabel, OpenFracRef, OpenMutRef, SetUnionVariant, Statement, - WriteAddress, WritePlace, + CopyPlace, DeadInclusion, DeadLifetime, EndLft, ExhaleExpression, ExhalePredicate, + ForgetInitialization, FracRef, GhostAssign, GhostHavoc, Havoc, HeapHavoc, + InhaleExpression, InhalePredicate, Join, JoinRange, LeakAll, LifetimeReturn, + LifetimeTake, LoopInvariant, MovePlace, NewLft, ObtainMutRef, OldLabel, OpenFracRef, + OpenMutRef, Pack, PredicateKind, RestoreRawBorrowed, SetUnionVariant, Split, + SplitRange, StashRange, StashRangeRestore, Statement, UniqueRef, Unpack, WriteAddress, + WritePlace, }, ty::{self, Type}, type_decl::{self, DiscriminantRange, DiscriminantValue, TypeDecl}, diff --git a/vir/defs/high/operations_internal/const_generics/common.rs b/vir/defs/high/operations_internal/const_generics/common.rs index 05d243d6a62..2441234105b 100644 --- a/vir/defs/high/operations_internal/const_generics/common.rs +++ b/vir/defs/high/operations_internal/const_generics/common.rs @@ -41,6 +41,7 @@ impl WithConstArguments for Rvalue { Self::Repeat(value) => value.get_const_arguments(), Self::AddressOf(value) => value.get_const_arguments(), Self::Len(value) => value.get_const_arguments(), + Self::Cast(value) => value.get_const_arguments(), Self::BinaryOp(value) => value.get_const_arguments(), Self::CheckedBinaryOp(value) => value.get_const_arguments(), Self::UnaryOp(value) => value.get_const_arguments(), @@ -82,6 +83,14 @@ impl WithConstArguments for Len { } } +impl WithConstArguments for Cast { + fn get_const_arguments(&self) -> Vec { + let mut arguments = self.operand.get_const_arguments(); + arguments.extend(self.ty.get_const_arguments()); + arguments + } +} + impl WithConstArguments for BinaryOp { fn get_const_arguments(&self) -> Vec { let mut arguments = self.left.get_const_arguments(); diff --git a/vir/defs/high/operations_internal/expression.rs b/vir/defs/high/operations_internal/expression.rs index 72cdd733cdc..8b8bad746de 100644 --- a/vir/defs/high/operations_internal/expression.rs +++ b/vir/defs/high/operations_internal/expression.rs @@ -2,17 +2,18 @@ use super::{ super::ast::{ expression::{ visitors::{ - default_fold_expression, default_fold_quantifier, default_walk_expression, - ExpressionFolder, ExpressionWalker, + default_fold_expression, default_fold_quantifier, default_walk_binary_op, + default_walk_expression, ExpressionFolder, ExpressionWalker, }, *, }, position::Position, + predicate::visitors::{PredicateFolder, PredicateWalker}, ty::{self, visitors::TypeFolder, LifetimeConst, Type}, }, ty::Typed, }; -use crate::common::expression::SyntacticEvaluation; +use crate::common::expression::{ExpressionIterator, SyntacticEvaluation, UnaryOperationHelpers}; use std::collections::BTreeMap; impl From for Expression { @@ -71,6 +72,28 @@ impl Expression { expr => unreachable!("{}", expr), } } + /// Create a new place with the provided parent. + pub fn with_new_parent(&self, new_parent: Self) -> Self { + match self { + Expression::Variant(expression) => Expression::Variant(Variant { + base: box new_parent, + ..expression.clone() + }), + Expression::Field(expression) => Expression::Field(Field { + base: box new_parent, + ..expression.clone() + }), + Expression::Deref(expression) => Expression::Deref(Deref { + base: box new_parent, + ..expression.clone() + }), + Expression::AddrOf(expression) => Expression::AddrOf(AddrOf { + base: box new_parent, + ..expression.clone() + }), + _ => unreachable!("Cannot change parent for {}", self), + } + } /// Only defined for places. pub fn try_into_parent(self) -> Option { debug_assert!(self.is_place()); @@ -180,8 +203,8 @@ impl Expression { } } - /// Check whether the place is a dereference of a reference and if that is - /// the case, return its base. + /// Check whether the place is a dereference if that is the case, return its + /// base. pub fn get_dereference_base(&self) -> Option<&Expression> { assert!(self.is_place()); if let Expression::Deref(Deref { box base, .. }) = self { @@ -193,6 +216,80 @@ impl Expression { } } + /// Check whether the place is a dereference of a reference and if that is + /// the case, return its base. + pub fn get_last_dereferenced_reference(&self) -> Option<&Expression> { + assert!(self.is_place()); + if let Expression::Deref(Deref { box base, .. }) = self { + if let Type::Reference(_) = base.get_type() { + Some(base) + } else { + base.get_last_dereferenced_reference() + } + } else if let Some(parent) = self.get_parent_ref() { + parent.get_last_dereferenced_reference() + } else { + None + } + } + + /// Same as `get_last_dereferenced_reference`, just returns the first + /// reference. + pub fn get_first_dereferenced_reference(&self) -> Option<&Expression> { + assert!(self.is_place()); + if let Expression::Deref(Deref { box base, .. }) = self { + let parent_ref = base.get_first_dereferenced_reference(); + if parent_ref.is_some() { + parent_ref + } else if let Type::Reference(_) = base.get_type() { + Some(base) + } else { + None + } + } else if let Some(parent) = self.get_parent_ref() { + parent.get_first_dereferenced_reference() + } else { + None + } + } + + pub fn is_behind_pointer_dereference(&self) -> bool { + assert!(self.is_place()); + if let Some(parent) = self.get_parent_ref() { + if self.is_deref() && parent.get_type().is_pointer() { + return true; + } + parent.is_behind_pointer_dereference() + } else { + false + } + } + + pub fn get_last_dereferenced_pointer(&self) -> Option<&Expression> { + assert!(self.is_place()); + if let Some(parent) = self.get_parent_ref() { + if self.is_deref() && parent.get_type().is_pointer() { + return Some(parent); + } + parent.get_last_dereferenced_pointer() + } else { + None + } + } + + pub fn get_first_dereferenced_pointer(&self) -> Option<&Expression> { + assert!(self.is_place()); + if let Some(last_pointer) = self.get_last_dereferenced_pointer() { + if let Some(parent) = last_pointer.get_first_dereferenced_pointer() { + Some(parent) + } else { + Some(last_pointer) + } + } else { + None + } + } + #[must_use] pub fn erase_lifetime(self) -> Expression { struct DefaultLifetimeEraser {} @@ -322,12 +419,20 @@ impl Expression { default_fold_expression(self, expression) } } + fn fold_predicate(&mut self, predicate: Predicate) -> Predicate { + PredicateFolder::fold_predicate(self, predicate) + } + } + impl<'a> PredicateFolder for PlaceReplacer<'a> { + fn fold_expression(&mut self, expression: Expression) -> Expression { + ExpressionFolder::fold_expression(self, expression) + } } let mut replacer = PlaceReplacer { target, replacement, }; - replacer.fold_expression(self) + ExpressionFolder::fold_expression(&mut replacer, self) } #[must_use] pub fn replace_multiple_places(self, replacements: &[(Expression, Expression)]) -> Self { @@ -364,8 +469,56 @@ impl Expression { } Expression::Quantifier(default_fold_quantifier(self, quantifier)) } + + fn fold_predicate(&mut self, predicate: Predicate) -> Predicate { + PredicateFolder::fold_predicate(self, predicate) + } + } + impl<'a> PredicateFolder for PlaceReplacer<'a> { + fn fold_expression(&mut self, expression: Expression) -> Expression { + ExpressionFolder::fold_expression(self, expression) + } + } + let mut replacer = PlaceReplacer { replacements }; + ExpressionFolder::fold_expression(&mut replacer, self) + } + #[must_use] + pub fn replace_self(self, replacement: &Expression) -> Self { + struct PlaceReplacer<'a> { + replacement: &'a Expression, + } + impl<'a> ExpressionFolder for PlaceReplacer<'a> { + fn fold_local_enum(&mut self, local: Local) -> Expression { + if local.variable.is_self_variable() { + assert_eq!( + &local.variable.ty, + self.replacement.get_type(), + "{} → {}", + local.variable.ty, + self.replacement + ); + self.replacement.clone() + } else { + Expression::Local(local) + } + } + fn fold_predicate(&mut self, predicate: Predicate) -> Predicate { + PredicateFolder::fold_predicate(self, predicate) + } + } + impl<'a> PredicateFolder for PlaceReplacer<'a> { + fn fold_expression(&mut self, expression: Expression) -> Expression { + ExpressionFolder::fold_expression(self, expression) + } + } + let mut replacer = PlaceReplacer { replacement }; + ExpressionFolder::fold_expression(&mut replacer, self) + } + pub fn peel_unfoldings(&self) -> &Self { + match self { + Expression::Unfolding(unfolding) => unfolding.body.peel_unfoldings(), + _ => self, } - PlaceReplacer { replacements }.fold_expression(self) } #[must_use] pub fn map_old_expression_label(self, substitutor: F) -> Self @@ -545,13 +698,21 @@ impl Expression { default_walk_expression(self, expr) } } + fn walk_predicate(&mut self, predicate: &Predicate) { + PredicateWalker::walk_predicate(self, predicate) + } + } + impl<'a> PredicateWalker for ExprFinder<'a> { + fn walk_expression(&mut self, expr: &Expression) { + ExpressionWalker::walk_expression(self, expr) + } } let mut finder = ExprFinder { sub_target, found: false, }; - finder.walk_expression(self); + ExpressionWalker::walk_expression(&mut finder, self); finder.found } pub fn function_call>( @@ -686,4 +847,187 @@ impl Expression { pub fn full_permission() -> Self { Self::constant_no_pos(ConstantValue::Int(1), Type::MPerm) } + + pub fn is_pure(&self) -> bool { + struct Checker { + is_pure: bool, + } + impl ExpressionWalker for Checker { + fn walk_acc_predicate(&mut self, _: &AccPredicate) { + self.is_pure = false; + } + } + let mut checker = Checker { is_pure: true }; + checker.walk_expression(self); + checker.is_pure + } +} + +/// Methods for collecting places. +impl Expression { + /// Returns place used in `own`. + pub fn collect_owned_places(&self) -> Vec { + struct Collector { + owned_places: Vec, + } + impl<'a> ExpressionWalker for Collector { + fn walk_acc_predicate(&mut self, acc_predicate: &AccPredicate) { + match &*acc_predicate.predicate { + Predicate::LifetimeToken(_) + | Predicate::MemoryBlockStack(_) + | Predicate::MemoryBlockStackDrop(_) + | Predicate::MemoryBlockHeap(_) + | Predicate::MemoryBlockHeapRange(_) + | Predicate::MemoryBlockHeapDrop(_) => {} + Predicate::OwnedNonAliased(predicate) => { + self.owned_places.push(predicate.place.clone()); + } + Predicate::OwnedRange(predicate) => { + unimplemented!("predicate: {}", predicate); + } + Predicate::OwnedSet(predicate) => { + unimplemented!("predicate: {}", predicate); + } + } + } + } + let mut collector = Collector { + owned_places: Vec::new(), + }; + collector.walk_expression(self); + collector.owned_places + } + + /// Returns places used in `own` with path conditions that guard them. + pub fn collect_guarded_owned_places(&self) -> Vec<(Expression, Expression)> { + struct Collector { + path_condition: Vec, + owned_places: Vec<(Expression, Expression)>, + } + impl<'a> ExpressionWalker for Collector { + fn walk_acc_predicate(&mut self, acc_predicate: &AccPredicate) { + match &*acc_predicate.predicate { + Predicate::LifetimeToken(_) + | Predicate::MemoryBlockStack(_) + | Predicate::MemoryBlockStackDrop(_) + | Predicate::MemoryBlockHeap(_) + | Predicate::MemoryBlockHeapRange(_) + | Predicate::MemoryBlockHeapDrop(_) => {} + Predicate::OwnedNonAliased(predicate) => { + self.owned_places.push(( + self.path_condition.iter().cloned().conjoin(), + predicate.place.clone(), + )); + } + Predicate::OwnedRange(predicate) => { + unimplemented!("predicate: {}", predicate); + } + Predicate::OwnedSet(predicate) => { + unimplemented!("predicate: {}", predicate); + } + } + } + fn walk_binary_op(&mut self, binary_op: &BinaryOp) { + if binary_op.op_kind == BinaryOpKind::Implies { + self.path_condition.push((*binary_op.left).clone()); + self.walk_expression(&binary_op.right); + self.path_condition.pop(); + } else { + default_walk_binary_op(self, binary_op); + } + } + fn walk_conditional(&mut self, conditional: &Conditional) { + self.path_condition.push((*conditional.guard).clone()); + self.walk_expression(&conditional.then_expr); + let guard = self.path_condition.pop().unwrap(); + self.path_condition.push((Expression::not(guard))); + self.walk_expression(&conditional.else_expr); + self.path_condition.pop(); + } + } + let mut collector = Collector { + path_condition: Vec::new(), + owned_places: Vec::new(), + }; + collector.walk_expression(self); + collector.owned_places + } + + /// Returns the expression with all pure parts removed and implications + /// converted into conditionals. + /// + /// This method is different from `collect_guarded_owned_places` in that it + /// still returns a single expression preserving most of the original + /// structure. + pub fn convert_into_permission_expression(self) -> Expression { + struct Remover {} + impl<'a> ExpressionFolder for Remover { + fn fold_expression(&mut self, expression: Expression) -> Expression { + if expression.is_pure() { + true.into() + } else { + default_fold_expression(self, expression) + } + } + fn fold_binary_op_enum(&mut self, binary_op: BinaryOp) -> Expression { + if binary_op.op_kind == BinaryOpKind::Implies { + let guard = *binary_op.left; + let then_expr = self.fold_expression(*binary_op.right); + let else_expr = false.into(); + Expression::conditional(guard, then_expr, else_expr, binary_op.position) + } else { + Expression::BinaryOp(self.fold_binary_op(binary_op)) + } + } + } + let mut remover = Remover {}; + remover.fold_expression(self) + } + + /// Returns places that contain dereferences with their path conditions. + pub fn collect_guarded_dereferenced_places(&self) -> Vec<(Expression, Expression)> { + struct Collector { + path_condition: Vec, + deref_places: Vec<(Expression, Expression)>, + } + impl<'a> ExpressionWalker for Collector { + fn walk_expression(&mut self, expression: &Expression) { + if expression.is_place() { + if expression.get_last_dereferenced_pointer().is_some() { + self.deref_places.push(( + self.path_condition.iter().cloned().conjoin(), + expression.clone(), + )); + } + } else { + default_walk_expression(self, expression) + } + } + fn walk_binary_op(&mut self, binary_op: &BinaryOp) { + if binary_op.op_kind == BinaryOpKind::Implies { + self.walk_expression(&binary_op.left); + self.path_condition.push((*binary_op.left).clone()); + self.walk_expression(&binary_op.right); + self.path_condition.pop(); + } else { + default_walk_binary_op(self, binary_op); + } + } + fn walk_conditional(&mut self, conditional: &Conditional) { + self.walk_expression(&conditional.guard); + self.path_condition.push((*conditional.guard).clone()); + self.walk_expression(&conditional.then_expr); + let guard = self.path_condition.pop().unwrap(); + self.path_condition.push((Expression::not(guard))); + self.walk_expression(&conditional.else_expr); + self.path_condition.pop(); + } + } + let mut collector = Collector { + path_condition: Vec::new(), + deref_places: Vec::new(), + }; + collector.walk_expression(self); + collector.deref_places + } } diff --git a/vir/defs/high/operations_internal/helpers.rs b/vir/defs/high/operations_internal/helpers.rs index aeca7d5ae0f..03536c7bcf5 100644 --- a/vir/defs/high/operations_internal/helpers.rs +++ b/vir/defs/high/operations_internal/helpers.rs @@ -117,22 +117,44 @@ impl ConstantHelpers for Expression { impl SyntacticEvaluation for Expression { fn is_true(&self) -> bool { - matches!( - self, + match self { Self::Constant(Constant { value: ConstantValue::Bool(true), .. - }) - ) + }) => true, + Self::UnaryOp(UnaryOp { + op_kind: UnaryOpKind::Not, + argument, + .. + }) => argument.is_false(), + Self::BinaryOp(BinaryOp { + op_kind: BinaryOpKind::Or, + left, + right, + .. + }) => left.is_true() || right.is_true(), + _ => false, + } } fn is_false(&self) -> bool { - matches!( - self, + match self { Self::Constant(Constant { value: ConstantValue::Bool(false), .. - }) - ) + }) => true, + Self::UnaryOp(UnaryOp { + op_kind: UnaryOpKind::Not, + argument, + .. + }) => argument.is_true(), + Self::BinaryOp(BinaryOp { + op_kind: BinaryOpKind::And, + left, + right, + .. + }) => left.is_false() || right.is_false(), + _ => false, + } } fn is_zero(&self) -> bool { matches!( diff --git a/vir/defs/high/operations_internal/identifier/predicate.rs b/vir/defs/high/operations_internal/identifier/predicate.rs index 5f1a5b0427a..99a6be27c01 100644 --- a/vir/defs/high/operations_internal/identifier/predicate.rs +++ b/vir/defs/high/operations_internal/identifier/predicate.rs @@ -14,8 +14,11 @@ impl WithIdentifier for Predicate { Self::MemoryBlockStack(predicate) => predicate.get_identifier(), Self::MemoryBlockStackDrop(predicate) => predicate.get_identifier(), Self::MemoryBlockHeap(predicate) => predicate.get_identifier(), + Self::MemoryBlockHeapRange(predicate) => predicate.get_identifier(), Self::MemoryBlockHeapDrop(predicate) => predicate.get_identifier(), Self::OwnedNonAliased(predicate) => predicate.get_identifier(), + Self::OwnedRange(predicate) => predicate.get_identifier(), + Self::OwnedSet(predicate) => predicate.get_identifier(), } } } @@ -44,6 +47,12 @@ impl WithIdentifier for predicate::MemoryBlockHeap { } } +impl WithIdentifier for predicate::MemoryBlockHeapRange { + fn get_identifier(&self) -> String { + "MemoryBlockHeapRange".to_string() + } +} + impl WithIdentifier for predicate::MemoryBlockHeapDrop { fn get_identifier(&self) -> String { "MemoryBlockHeapDrop".to_string() @@ -55,3 +64,15 @@ impl WithIdentifier for predicate::OwnedNonAliased { format!("OwnedNonAliased${}", self.place.get_type().get_identifier()) } } + +impl WithIdentifier for predicate::OwnedRange { + fn get_identifier(&self) -> String { + format!("OwnedRange${}", self.address.get_type().get_identifier()) + } +} + +impl WithIdentifier for predicate::OwnedSet { + fn get_identifier(&self) -> String { + format!("OwnedSet${}", self.set.get_type().get_identifier()) + } +} diff --git a/vir/defs/high/operations_internal/identifier/rvalue.rs b/vir/defs/high/operations_internal/identifier/rvalue.rs index c9ecdc33215..542a8ad6421 100644 --- a/vir/defs/high/operations_internal/identifier/rvalue.rs +++ b/vir/defs/high/operations_internal/identifier/rvalue.rs @@ -7,6 +7,7 @@ impl WithIdentifier for Rvalue { Self::Repeat(value) => value.get_identifier(), Self::AddressOf(value) => value.get_identifier(), Self::Len(value) => value.get_identifier(), + Self::Cast(value) => value.get_identifier(), Self::BinaryOp(value) => value.get_identifier(), Self::CheckedBinaryOp(value) => value.get_identifier(), Self::UnaryOp(value) => value.get_identifier(), @@ -48,6 +49,16 @@ impl WithIdentifier for Len { } } +impl WithIdentifier for Cast { + fn get_identifier(&self) -> String { + format!( + "Cast${}${}", + self.operand.get_identifier(), + self.ty.get_identifier() + ) + } +} + impl WithIdentifier for UnaryOp { fn get_identifier(&self) -> String { format!("UnaryOp${}${}", self.kind, self.argument.get_identifier()) diff --git a/vir/defs/high/operations_internal/lifetimes/common.rs b/vir/defs/high/operations_internal/lifetimes/common.rs index 7e644d94d9e..b4835a0e102 100644 --- a/vir/defs/high/operations_internal/lifetimes/common.rs +++ b/vir/defs/high/operations_internal/lifetimes/common.rs @@ -51,6 +51,7 @@ impl WithLifetimes for Rvalue { Self::Repeat(value) => value.get_lifetimes(), Self::AddressOf(value) => value.get_lifetimes(), Self::Len(value) => value.get_lifetimes(), + Self::Cast(value) => value.get_lifetimes(), Self::BinaryOp(value) => value.get_lifetimes(), Self::CheckedBinaryOp(value) => value.get_lifetimes(), Self::UnaryOp(value) => value.get_lifetimes(), @@ -96,6 +97,14 @@ impl WithLifetimes for Len { } } +impl WithLifetimes for Cast { + fn get_lifetimes(&self) -> Vec { + let mut lifetimes = self.operand.get_lifetimes(); + lifetimes.extend(self.ty.get_lifetimes()); + lifetimes + } +} + impl WithLifetimes for BinaryOp { fn get_lifetimes(&self) -> Vec { let mut lifetimes = self.left.get_lifetimes(); diff --git a/vir/defs/high/operations_internal/position/expressions.rs b/vir/defs/high/operations_internal/position/expressions.rs index e9e5bd86136..49cb00ca783 100644 --- a/vir/defs/high/operations_internal/position/expressions.rs +++ b/vir/defs/high/operations_internal/position/expressions.rs @@ -22,6 +22,8 @@ impl Positioned for Expression { Self::FuncApp(expression) => expression.position(), Self::Downcast(expression) => expression.position(), Self::BuiltinFuncApp(expression) => expression.position(), + Self::AccPredicate(expression) => expression.position(), + Self::Unfolding(expression) => expression.position(), } } } @@ -133,3 +135,15 @@ impl Positioned for Downcast { self.position } } + +impl Positioned for AccPredicate { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for Unfolding { + fn position(&self) -> Position { + self.position + } +} diff --git a/vir/defs/high/operations_internal/position/mod.rs b/vir/defs/high/operations_internal/position/mod.rs index 3705384ac71..5ae756e1a4c 100644 --- a/vir/defs/high/operations_internal/position/mod.rs +++ b/vir/defs/high/operations_internal/position/mod.rs @@ -1,2 +1,3 @@ mod expressions; mod statement; +mod type_decl; diff --git a/vir/defs/high/operations_internal/position/statement.rs b/vir/defs/high/operations_internal/position/statement.rs index 0b9ae927f4c..a8c51b1dbc7 100644 --- a/vir/defs/high/operations_internal/position/statement.rs +++ b/vir/defs/high/operations_internal/position/statement.rs @@ -6,10 +6,13 @@ impl Positioned for Statement { match self { Self::Comment(statement) => statement.position(), Self::OldLabel(statement) => statement.position(), - Self::Inhale(statement) => statement.position(), - Self::Exhale(statement) => statement.position(), + Self::InhalePredicate(statement) => statement.position(), + Self::ExhalePredicate(statement) => statement.position(), + Self::InhaleExpression(statement) => statement.position(), + Self::ExhaleExpression(statement) => statement.position(), Self::Havoc(statement) => statement.position(), Self::GhostHavoc(statement) => statement.position(), + Self::HeapHavoc(statement) => statement.position(), Self::Assume(statement) => statement.position(), Self::Assert(statement) => statement.position(), Self::LoopInvariant(statement) => statement.position(), @@ -22,6 +25,16 @@ impl Positioned for Statement { Self::Consume(statement) => statement.position(), Self::LeakAll(statement) => statement.position(), Self::SetUnionVariant(statement) => statement.position(), + Self::Pack(statement) => statement.position(), + Self::Unpack(statement) => statement.position(), + Self::Join(statement) => statement.position(), + Self::JoinRange(statement) => statement.position(), + Self::Split(statement) => statement.position(), + Self::SplitRange(statement) => statement.position(), + Self::StashRange(statement) => statement.position(), + Self::StashRangeRestore(statement) => statement.position(), + Self::ForgetInitialization(statement) => statement.position(), + Self::RestoreRawBorrowed(statement) => statement.position(), Self::NewLft(statement) => statement.position(), Self::EndLft(statement) => statement.position(), Self::DeadLifetime(statement) => statement.position(), @@ -50,13 +63,25 @@ impl Positioned for OldLabel { } } -impl Positioned for Inhale { +impl Positioned for InhalePredicate { fn position(&self) -> Position { self.position } } -impl Positioned for Exhale { +impl Positioned for ExhalePredicate { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for InhaleExpression { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for ExhaleExpression { fn position(&self) -> Position { self.position } @@ -74,6 +99,12 @@ impl Positioned for GhostHavoc { } } +impl Positioned for HeapHavoc { + fn position(&self) -> Position { + self.position + } +} + impl Positioned for GhostAssign { fn position(&self) -> Position { self.position @@ -146,6 +177,66 @@ impl Positioned for SetUnionVariant { } } +impl Positioned for Pack { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for Unpack { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for Join { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for JoinRange { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for Split { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for SplitRange { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for StashRange { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for StashRangeRestore { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for ForgetInitialization { + fn position(&self) -> Position { + self.position + } +} + +impl Positioned for RestoreRawBorrowed { + fn position(&self) -> Position { + self.position + } +} + impl Positioned for NewLft { fn position(&self) -> Position { self.position diff --git a/vir/defs/high/operations_internal/position/type_decl.rs b/vir/defs/high/operations_internal/position/type_decl.rs new file mode 100644 index 00000000000..8482ae7b43e --- /dev/null +++ b/vir/defs/high/operations_internal/position/type_decl.rs @@ -0,0 +1,33 @@ +use super::super::super::ast::type_decl::*; +use crate::common::position::Positioned; + +impl Positioned for TypeDecl { + fn position(&self) -> Position { + match self { + Self::Bool => Default::default(), + Self::Int(_) => Default::default(), + Self::Float(_) => Default::default(), + Self::TypeVar(_) => Default::default(), + Self::Tuple(_) => Default::default(), + Self::Struct(decl) => decl.position(), + Self::Sequence(_) => Default::default(), + Self::Map(_) => Default::default(), + Self::Enum(_) => Default::default(), + Self::Union(_) => Default::default(), + Self::Array(_) => Default::default(), + Self::Slice(_) => Default::default(), + Self::Reference(_) => Default::default(), + Self::Pointer(_) => Default::default(), + Self::Never => Default::default(), + Self::Closure(_) => Default::default(), + Self::Unsupported(_) => Default::default(), + Self::Trusted(_) => Default::default(), + } + } +} + +impl Positioned for Struct { + fn position(&self) -> Position { + self.position + } +} diff --git a/vir/defs/high/operations_internal/predicate.rs b/vir/defs/high/operations_internal/predicate.rs index 88f3f66c3ef..3707679039d 100644 --- a/vir/defs/high/operations_internal/predicate.rs +++ b/vir/defs/high/operations_internal/predicate.rs @@ -32,6 +32,14 @@ impl Predicate { predicate.size.get_type().clone(), ] } + Self::MemoryBlockHeapRange(predicate) => { + // FIXME: This is probably wrong: we need to use the type of the + // target. + vec![ + predicate.address.get_type().clone(), + predicate.size.get_type().clone(), + ] + } Self::MemoryBlockHeapDrop(predicate) => { vec![ predicate.address.get_type().clone(), @@ -41,6 +49,16 @@ impl Predicate { Self::OwnedNonAliased(predicate) => { vec![predicate.place.get_type().clone()] } + Self::OwnedRange(predicate) => { + // FIXME: This is probably wrong: we need to use the type of the + // target. + vec![predicate.address.get_type().clone()] + } + Self::OwnedSet(predicate) => { + // FIXME: This is probably wrong: we need to use the type of the + // target of the pointer stored in the set. + vec![predicate.set.get_type().clone()] + } } } pub fn check_no_default_position(&self) { diff --git a/vir/defs/high/operations_internal/special_variables.rs b/vir/defs/high/operations_internal/special_variables.rs index 2a0ceb78d23..7f28606b066 100644 --- a/vir/defs/high/operations_internal/special_variables.rs +++ b/vir/defs/high/operations_internal/special_variables.rs @@ -6,7 +6,27 @@ use super::super::ast::{ variable::VariableDecl, }; +impl VariableDecl { + pub fn self_variable(ty: Type) -> Self { + VariableDecl::new("self$", ty) + } + pub fn is_self_variable(&self) -> bool { + self.name == "self$" + } +} + impl Expression { + pub fn self_variable(ty: Type) -> Self { + let variable = VariableDecl::self_variable(ty); + Expression::local_no_pos(variable) + } + pub fn is_self_variable(&self) -> bool { + if let Expression::Local(Local { variable, .. }) = self { + variable.is_self_variable() + } else { + false + } + } pub fn discriminant() -> Self { let variable = VariableDecl::new("discriminant$", Type::MInt); Expression::local_no_pos(variable) diff --git a/vir/defs/high/operations_internal/ty.rs b/vir/defs/high/operations_internal/ty.rs index 6ca408ac373..36d46391cc0 100644 --- a/vir/defs/high/operations_internal/ty.rs +++ b/vir/defs/high/operations_internal/ty.rs @@ -200,6 +200,13 @@ impl Type { _ => false, } } + pub fn is_unique_reference(&self) -> bool { + if let Type::Reference(Reference { uniqueness, .. }) = self { + uniqueness.is_unique() + } else { + false + } + } } impl AsRef for VariantIndex { @@ -370,6 +377,8 @@ impl Typed for Expression { Expression::FuncApp(expression) => expression.get_type(), Expression::BuiltinFuncApp(expression) => expression.get_type(), Expression::Downcast(expression) => expression.get_type(), + Expression::AccPredicate(expression) => expression.get_type(), + Expression::Unfolding(expression) => expression.get_type(), } } fn set_type(&mut self, new_type: Type) { @@ -392,6 +401,8 @@ impl Typed for Expression { Expression::FuncApp(expression) => expression.set_type(new_type), Expression::BuiltinFuncApp(expression) => expression.set_type(new_type), Expression::Downcast(expression) => expression.set_type(new_type), + Expression::AccPredicate(expression) => expression.set_type(new_type), + Expression::Unfolding(expression) => expression.set_type(new_type), } } } @@ -502,6 +513,22 @@ impl Typed for BinaryOp { } } fn set_type(&mut self, new_type: Type) { + assert!( + !matches!( + self.op_kind, + BinaryOpKind::EqCmp + | BinaryOpKind::NeCmp + | BinaryOpKind::GtCmp + | BinaryOpKind::GeCmp + | BinaryOpKind::LtCmp + | BinaryOpKind::LeCmp + | BinaryOpKind::And + | BinaryOpKind::Or + | BinaryOpKind::Implies + ), + "cannot change the type of {:?}", + self.op_kind + ); self.left.set_type(new_type.clone()); self.right.set_type(new_type); } @@ -582,3 +609,21 @@ impl Typed for Downcast { self.base.set_type(new_type); } } + +impl Typed for AccPredicate { + fn get_type(&self) -> &Type { + &Type::Bool + } + fn set_type(&mut self, _new_type: Type) { + unreachable!(); + } +} + +impl Typed for Unfolding { + fn get_type(&self) -> &Type { + self.body.get_type() + } + fn set_type(&mut self, new_type: Type) { + self.body.set_type(new_type) + } +} diff --git a/vir/defs/high/operations_internal/type_decl.rs b/vir/defs/high/operations_internal/type_decl.rs index 687e2fc60b9..973ac34862a 100644 --- a/vir/defs/high/operations_internal/type_decl.rs +++ b/vir/defs/high/operations_internal/type_decl.rs @@ -4,6 +4,12 @@ use super::super::ast::{ type_decl::{Enum, Struct, Trusted, Tuple, TypeDecl, Union}, }; +impl Struct { + pub fn is_manually_managed_type(&self) -> bool { + self.structural_invariant.is_some() + } +} + impl Enum { pub fn variant(&self, variant_name: &str) -> Option<&Struct> { self.variants diff --git a/vir/defs/low/ast/expression.rs b/vir/defs/low/ast/expression.rs index 90cefcd028c..2a47d833167 100644 --- a/vir/defs/low/ast/expression.rs +++ b/vir/defs/low/ast/expression.rs @@ -3,7 +3,7 @@ use crate::common::display; #[derive_helpers] #[derive_visitors] -#[derive(derive_more::From, derive_more::IsVariant)] +#[derive(derive_more::From, derive_more::IsVariant, derive_more::Unwrap)] pub enum Expression { /// A Viper variable. /// @@ -90,17 +90,9 @@ pub struct FieldAccessPredicate { pub position: Position, } -#[display( - fmt = "(unfolding acc({}({}), {}) in {})", - predicate, - "display::cjoin(arguments)", - permission, - base -)] +#[display(fmt = "(unfolding {} in {})", predicate, base)] pub struct Unfolding { - pub predicate: String, - pub arguments: Vec, - pub permission: Box, + pub predicate: PredicateAccessPredicate, pub base: Box, pub position: Position, } diff --git a/vir/defs/low/ast/function.rs b/vir/defs/low/ast/function.rs index c4784dee55c..c48e33c576f 100644 --- a/vir/defs/low/ast/function.rs +++ b/vir/defs/low/ast/function.rs @@ -4,6 +4,7 @@ use crate::common::display; pub enum FunctionKind { MemoryBlockBytes, CallerFor, + Snap, } #[display( @@ -13,7 +14,7 @@ pub enum FunctionKind { "display::cjoin(parameters)", return_type, "display::foreach!(\" requires {}\n\", pres)", - "display::foreach!(\" ensures {}\n\", pres)", + "display::foreach!(\" ensures {}\n\", posts)", "display::option!(body, \"{{ {} }}\n\", \"\")" )] pub struct FunctionDecl { diff --git a/vir/defs/low/ast/predicate.rs b/vir/defs/low/ast/predicate.rs index dbecbc3ad5d..f9653444478 100644 --- a/vir/defs/low/ast/predicate.rs +++ b/vir/defs/low/ast/predicate.rs @@ -1,14 +1,23 @@ use super::{expression::Expression, variable::VariableDecl}; use crate::common::display; +pub enum PredicateKind { + MemoryBlock, + Owned, + WithoutSnapshotFrac, + WithoutSnapshotWhole, +} + #[display( - fmt = "predicate {}({}){}\n", + fmt = "predicate<{}> {}({}){}\n", + kind, "name", "display::cjoin(parameters)", "display::option!(body, \" {{\n {}\n}}\", \";\")" )] pub struct PredicateDecl { pub name: String, + pub kind: PredicateKind, pub parameters: Vec, pub body: Option, } diff --git a/vir/defs/low/ast/statement.rs b/vir/defs/low/ast/statement.rs index 1c2f9756af7..3973ac1305a 100644 --- a/vir/defs/low/ast/statement.rs +++ b/vir/defs/low/ast/statement.rs @@ -6,6 +6,7 @@ use crate::common::display; #[derive(derive_more::From, derive_more::IsVariant)] pub enum Statement { Comment(Comment), + Label(Label), LogEvent(LogEvent), Assume(Assume), Assert(Assert), @@ -24,10 +25,17 @@ pub struct Comment { pub comment: String, } +#[display(fmt = "label {}", label)] +pub struct Label { + pub label: String, + pub position: Position, +} + #[display(fmt = "log-event {}", expression)] /// Log an event by assuming a (fresh) domain function. pub struct LogEvent { pub expression: Expression, + pub position: Position, } #[display(fmt = "assume {}", expression)] diff --git a/vir/defs/low/cfg/procedure.rs b/vir/defs/low/cfg/procedure.rs index c1c48658d07..dcc5923ecb8 100644 --- a/vir/defs/low/cfg/procedure.rs +++ b/vir/defs/low/cfg/procedure.rs @@ -1,33 +1,39 @@ use crate::{ common::display, - low::ast::{expression::Expression, statement::Statement, variable::VariableDecl}, + low::ast::{ + expression::Expression, position::Position, statement::Statement, variable::VariableDecl, + }, }; +use std::collections::BTreeMap; #[display( fmt = "procedure {} {{\n{}\n{}}}\n", name, "display::foreach!(\" var {};\n\", locals)", - "display::foreach!(\"{}\n\", basic_blocks)" + "display::foreach2!(\" label {}\n{}\", basic_blocks.keys(), basic_blocks.values())" )] pub struct ProcedureDecl { pub name: String, pub locals: Vec, - pub basic_blocks: Vec, + pub custom_labels: Vec