From 7e06a2c870682c9a0e24220e66f7fe6f043829e7 Mon Sep 17 00:00:00 2001 From: Vytautas Astrauskas Date: Mon, 13 Mar 2023 14:32:06 +0100 Subject: [PATCH] Ugly commit 3ei5. --- .../prusti-contracts-proc-macros/src/lib.rs | 12 ++++++++++++ prusti-contracts/prusti-contracts/src/lib.rs | 9 +++++++++ prusti-contracts/prusti-specs/src/lib.rs | 4 ++++ .../high/procedures/inference/semantics.rs | 16 ++++++++++++++++ .../high/procedures/inference/visitor/mod.rs | 3 +++ .../src/encoder/mir/procedures/encoder/mod.rs | 15 +++++++++++++++ .../encoder/mir/procedures/passes/assertions.rs | 1 + .../src/encoder/typed/to_middle/statement.rs | 7 +++++++ vir/defs/high/ast/statement.rs | 8 ++++++++ vir/defs/high/mod.rs | 8 ++++---- .../operations_internal/position/statement.rs | 7 +++++++ vir/defs/typed/mod.rs | 6 +++--- 12 files changed, 89 insertions(+), 7 deletions(-) diff --git a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs index 4e3cee0f408..9bb6715cc7d 100644 --- a/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs +++ b/prusti-contracts/prusti-contracts-proc-macros/src/lib.rs @@ -160,6 +160,12 @@ pub fn unpack(_tokens: TokenStream) -> TokenStream { TokenStream::new() } +#[cfg(not(feature = "prusti"))] +#[proc_macro] +pub fn obtain(_tokens: TokenStream) -> TokenStream { + TokenStream::new() +} + #[cfg(not(feature = "prusti"))] #[proc_macro] pub fn pack_ref(_tokens: TokenStream) -> TokenStream { @@ -453,6 +459,12 @@ pub fn unpack(tokens: TokenStream) -> TokenStream { prusti_specs::unpack(tokens.into()).into() } +#[cfg(feature = "prusti")] +#[proc_macro] +pub fn obtain(tokens: TokenStream) -> TokenStream { + prusti_specs::obtain(tokens.into()).into() +} + #[cfg(feature = "prusti")] #[proc_macro] pub fn pack_ref(tokens: TokenStream) -> TokenStream { diff --git a/prusti-contracts/prusti-contracts/src/lib.rs b/prusti-contracts/prusti-contracts/src/lib.rs index 19d9608aa2c..26b6a54b76e 100644 --- a/prusti-contracts/prusti-contracts/src/lib.rs +++ b/prusti-contracts/prusti-contracts/src/lib.rs @@ -83,6 +83,9 @@ pub use prusti_contracts_proc_macros::pack; /// A macro to manually unpack a place capability. pub use prusti_contracts_proc_macros::unpack; +/// Tell Prusti to obtain the specified capability. +pub use prusti_contracts_proc_macros::obtain; + /// A macro to manually pack a place capability. pub use prusti_contracts_proc_macros::pack_ref; @@ -490,6 +493,12 @@ pub fn prusti_unpack_place(_arg: T) { unreachable!(); } +#[doc(hidden)] +#[trusted] +pub fn prusti_obtain_place(_arg: T) { + unreachable!(); +} + #[doc(hidden)] #[trusted] pub fn prusti_pack_ref_place(_lifetime_name: &'static str, _arg: T) { diff --git a/prusti-contracts/prusti-specs/src/lib.rs b/prusti-contracts/prusti-specs/src/lib.rs index 264ded12601..e4281199d2d 100644 --- a/prusti-contracts/prusti-specs/src/lib.rs +++ b/prusti-contracts/prusti-specs/src/lib.rs @@ -1198,6 +1198,10 @@ pub fn unpack(tokens: TokenStream) -> TokenStream { generate_place_function(tokens, quote! {prusti_unpack_place}) } +pub fn obtain(tokens: TokenStream) -> TokenStream { + generate_place_function(tokens, quote! {prusti_obtain_place}) +} + pub fn pack_ref(tokens: TokenStream) -> TokenStream { // generate_place_function(tokens, quote! {prusti_pack_ref_place}) pack_unpack_ref(tokens, quote! {prusti_pack_ref_place}) diff --git a/prusti-viper/src/encoder/high/procedures/inference/semantics.rs b/prusti-viper/src/encoder/high/procedures/inference/semantics.rs index d3ea5795dc0..146dc56fb8d 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/semantics.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/semantics.rs @@ -129,6 +129,9 @@ impl CollectPermissionChanges for vir_typed::Statement { vir_typed::Statement::Unpack(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } + vir_typed::Statement::Obtain(statement) => { + statement.collect(encoder, consumed_permissions, produced_permissions) + } vir_typed::Statement::Join(statement) => { statement.collect(encoder, consumed_permissions, produced_permissions) } @@ -845,6 +848,19 @@ impl CollectPermissionChanges for vir_typed::Unpack { } } +impl CollectPermissionChanges for vir_typed::Obtain { + fn collect<'v, 'tcx>( + &self, + encoder: &mut Encoder<'v, 'tcx>, + consumed_permissions: &mut Vec, + produced_permissions: &mut Vec, + ) -> SpannedEncodingResult<()> { + consumed_permissions.push(Permission::Owned(self.place.clone())); + produced_permissions.push(Permission::Owned(self.place.clone())); + Ok(()) + } +} + impl CollectPermissionChanges for vir_typed::Join { fn collect<'v, 'tcx>( &self, diff --git a/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs b/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs index 26cfb990a11..36b49fa2337 100644 --- a/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs +++ b/prusti-viper/src/encoder/high/procedures/inference/visitor/mod.rs @@ -378,6 +378,9 @@ impl<'p, 'v, 'tcx> Visitor<'p, 'v, 'tcx> { }; self.current_statements.push(encoded_statement); } + vir_typed::Statement::Obtain(_) => { + // Nothing to do because the fold-unfold already handled it. + } vir_typed::Statement::Join(join_statement) => { let position = join_statement.position(); let place = join_statement diff --git a/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs b/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs index 496e57d432c..06f8606bb63 100644 --- a/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs +++ b/prusti-viper/src/encoder/mir/procedures/encoder/mod.rs @@ -3257,6 +3257,21 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { encoded_statements.push(statement); Ok(true) } + "prusti_contracts::prusti_obtain_place" => { + let encoded_place = extract_place(self.mir, args, block, self)?; + let statement = self.encoder.set_statement_error_ctxt( + vir_high::Statement::obtain_no_pos( + encoded_place, + vir_high::PredicateKind::Owned, + ), + span, + ErrorCtxt::Unpack, + self.def_id, + )?; + statement.check_no_default_position(); + encoded_statements.push(statement); + Ok(true) + } "prusti_contracts::prusti_pack_ref_place" => { assert_eq!(args.len(), 2); let mut encoded_args = extract_args(self.mir, args, block, self)?; diff --git a/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs b/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs index 02460605fa0..e608932e955 100644 --- a/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs +++ b/prusti-viper/src/encoder/mir/procedures/passes/assertions.rs @@ -64,6 +64,7 @@ pub(in super::super) fn propagate_assertions_back<'v, 'tcx: 'v>( | vir_high::Statement::BorShorten(_) => true, vir_high::Statement::Pack(_) | vir_high::Statement::Unpack(_) + | vir_high::Statement::Obtain(_) | vir_high::Statement::Join(_) | vir_high::Statement::JoinRange(_) | vir_high::Statement::Split(_) diff --git a/prusti-viper/src/encoder/typed/to_middle/statement.rs b/prusti-viper/src/encoder/typed/to_middle/statement.rs index d47474150df..416458ab20b 100644 --- a/prusti-viper/src/encoder/typed/to_middle/statement.rs +++ b/prusti-viper/src/encoder/typed/to_middle/statement.rs @@ -125,6 +125,13 @@ impl<'v, 'tcx> TypedToMiddleStatementLowerer for crate::encoder::Encoder<'v, 'tc unreachable!("Pack statement cannot be lowered"); } + fn typed_to_middle_statement_statement_obtain( + &self, + _: vir_typed::Obtain, + ) -> Result { + unreachable!("Obtain statement cannot be lowered"); + } + fn typed_to_middle_statement_statement_forget_initialization( &self, _statement: vir_typed::ForgetInitialization, diff --git a/vir/defs/high/ast/statement.rs b/vir/defs/high/ast/statement.rs index 25a469b4f2d..209c43e9dd9 100644 --- a/vir/defs/high/ast/statement.rs +++ b/vir/defs/high/ast/statement.rs @@ -38,6 +38,7 @@ pub enum Statement { SetUnionVariant(SetUnionVariant), Pack(Pack), Unpack(Unpack), + Obtain(Obtain), Join(Join), JoinRange(JoinRange), Split(Split), @@ -316,6 +317,13 @@ pub struct Unpack { pub position: Position, } +#[display(fmt = "obtain-{} {}", predicate_kind, place)] +pub struct Obtain { + pub place: Expression, + pub predicate_kind: PredicateKind, + pub position: Position, +} + #[display(fmt = "join {}", place)] pub struct Join { pub place: Expression, diff --git a/vir/defs/high/mod.rs b/vir/defs/high/mod.rs index 996d5562f3f..1257851a709 100644 --- a/vir/defs/high/mod.rs +++ b/vir/defs/high/mod.rs @@ -23,10 +23,10 @@ pub use self::{ CopyPlace, DeadInclusion, DeadLifetime, DeadReference, EndLft, ExhaleExpression, ExhalePredicate, ForgetInitialization, FracRef, GhostAssign, GhostHavoc, Havoc, HeapHavoc, InhaleExpression, InhalePredicate, Join, JoinRange, LeakAll, LifetimeReturn, - LifetimeTake, LoopInvariant, MovePlace, NewLft, ObtainMutRef, OldLabel, OpenFracRef, - OpenMutRef, Pack, PredicateKind, RestoreRawBorrowed, SetUnionVariant, Split, - SplitRange, StashRange, StashRangeRestore, Statement, UniqueRef, Unpack, WriteAddress, - WritePlace, + LifetimeTake, LoopInvariant, MovePlace, NewLft, Obtain, ObtainMutRef, OldLabel, + OpenFracRef, OpenMutRef, Pack, PredicateKind, RestoreRawBorrowed, SetUnionVariant, + Split, SplitRange, StashRange, StashRangeRestore, Statement, UniqueRef, Unpack, + WriteAddress, WritePlace, }, ty::{self, Type}, type_decl::{self, DiscriminantRange, DiscriminantValue, TypeDecl}, diff --git a/vir/defs/high/operations_internal/position/statement.rs b/vir/defs/high/operations_internal/position/statement.rs index 35557c593ed..8749160e600 100644 --- a/vir/defs/high/operations_internal/position/statement.rs +++ b/vir/defs/high/operations_internal/position/statement.rs @@ -27,6 +27,7 @@ impl Positioned for Statement { Self::SetUnionVariant(statement) => statement.position(), Self::Pack(statement) => statement.position(), Self::Unpack(statement) => statement.position(), + Self::Obtain(statement) => statement.position(), Self::Join(statement) => statement.position(), Self::JoinRange(statement) => statement.position(), Self::Split(statement) => statement.position(), @@ -190,6 +191,12 @@ impl Positioned for Unpack { } } +impl Positioned for Obtain { + fn position(&self) -> Position { + self.position + } +} + impl Positioned for Join { fn position(&self) -> Position { self.position diff --git a/vir/defs/typed/mod.rs b/vir/defs/typed/mod.rs index cb3f5a19c6a..7302c8d7cf6 100644 --- a/vir/defs/typed/mod.rs +++ b/vir/defs/typed/mod.rs @@ -24,9 +24,9 @@ pub use self::{ CopyPlace, DeadInclusion, DeadLifetime, DeadReference, EndLft, ExhaleExpression, ExhalePredicate, ForgetInitialization, GhostAssign, GhostHavoc, Havoc, HeapHavoc, InhaleExpression, InhalePredicate, Join, JoinRange, LeakAll, LifetimeReturn, - LifetimeTake, LoopInvariant, MovePlace, NewLft, ObtainMutRef, OldLabel, OpenFracRef, - OpenMutRef, Pack, RestoreRawBorrowed, SetUnionVariant, Split, SplitRange, StashRange, - StashRangeRestore, Statement, Unpack, WriteAddress, WritePlace, + LifetimeTake, LoopInvariant, MovePlace, NewLft, Obtain, ObtainMutRef, OldLabel, + OpenFracRef, OpenMutRef, Pack, RestoreRawBorrowed, SetUnionVariant, Split, SplitRange, + StashRange, StashRangeRestore, Statement, Unpack, WriteAddress, WritePlace, }, ty::{self, Type}, type_decl::{self, DiscriminantRange, DiscriminantValue, TypeDecl},