diff --git a/Cargo.lock b/Cargo.lock index c05b071396d..100df83d4ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2558,6 +2558,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/mir-ssa-analysis/src/lib.rs b/mir-ssa-analysis/src/lib.rs index 0671abbe7e3..0214618ba9b 100644 --- a/mir-ssa-analysis/src/lib.rs +++ b/mir-ssa-analysis/src/lib.rs @@ -1,9 +1,6 @@ #![feature(rustc_private)] -use prusti_rustc_interface::{ - index::IndexVec, - middle::mir, -}; +use prusti_rustc_interface::{index::IndexVec, middle::mir}; use std::collections::HashMap; pub type SsaVersion = usize; @@ -34,9 +31,15 @@ impl SsaAnalysis { let local_count = body.local_decls.len(); let block_count = body.basic_blocks.reverse_postorder().len(); - let block_started = std::iter::repeat(false).take(block_count).collect::>(); - let initial_version_in_block = std::iter::repeat(None).take(block_count).collect::>(); - let final_version_in_block = std::iter::repeat(None).take(block_count).collect::>(); + let block_started = std::iter::repeat(false) + .take(block_count) + .collect::>(); + let initial_version_in_block = std::iter::repeat(None) + .take(block_count) + .collect::>(); + let final_version_in_block = std::iter::repeat(None) + .take(block_count) + .collect::>(); let mut ssa_visitor = SsaVisitor { version_counter: IndexVec::from_raw(vec![0; local_count]), @@ -69,11 +72,7 @@ struct SsaVisitor { } impl<'tcx> SsaVisitor { - fn walk_block<'a>( - &mut self, - body: &'a mir::Body<'tcx>, - block: mir::BasicBlock, - ) { + fn walk_block<'a>(&mut self, body: &'a mir::Body<'tcx>, block: mir::BasicBlock) { if self.final_version_in_block[block.as_usize()].is_some() { return; } @@ -95,7 +94,9 @@ impl<'tcx> SsaVisitor { // TODO: cfg cycles prev_versions.push(( *pred, - self.final_version_in_block[pred.as_usize()].as_ref().unwrap()[local.into()], + self.final_version_in_block[pred.as_usize()] + .as_ref() + .unwrap()[local.into()], )); } if prev_versions.is_empty() { @@ -117,7 +118,9 @@ impl<'tcx> SsaVisitor { assert!(self.analysis.phi.insert(block, phis).is_none()); } - assert!(self.initial_version_in_block[block.as_usize()].replace(initial_versions.clone()).is_none()); + assert!(self.initial_version_in_block[block.as_usize()] + .replace(initial_versions.clone()) + .is_none()); use mir::visit::Visitor; self.last_version = initial_versions; @@ -127,12 +130,14 @@ impl<'tcx> SsaVisitor { .map(|local| self.last_version[local.into()]) .collect::>(); for local in 0..self.local_count { - self.analysis.version.insert(( - body.terminator_loc(block), - local.into(), - ), final_versions[local.into()]); + self.analysis.version.insert( + (body.terminator_loc(block), local.into()), + final_versions[local.into()], + ); } - assert!(self.final_version_in_block[block.as_usize()].replace(final_versions).is_none()); + assert!(self.final_version_in_block[block.as_usize()] + .replace(final_versions) + .is_none()); use prusti_rustc_interface::data_structures::graph::WithSuccessors; for succ in body.basic_blocks.successors(block) { @@ -152,7 +157,9 @@ impl<'tcx> mir::visit::Visitor<'tcx> for SsaVisitor { ) { let local = place.local; - assert!(self.analysis.version + assert!(self + .analysis + .version .insert((location, local), self.last_version[local]) .is_none()); @@ -161,12 +168,17 @@ impl<'tcx> mir::visit::Visitor<'tcx> for SsaVisitor { let new_version = self.version_counter[local] + 1; self.version_counter[local] = new_version; self.last_version[local] = new_version; - assert!(self.analysis.updates - .insert(location, SsaUpdate { - local, - old_version, - new_version, - }) + assert!(self + .analysis + .updates + .insert( + location, + SsaUpdate { + local, + old_version, + new_version, + } + ) .is_none()); } } diff --git a/mir-state-analysis/src/free_pcs/impl/local.rs b/mir-state-analysis/src/free_pcs/impl/local.rs index 9067adb669c..e0a09437d86 100644 --- a/mir-state-analysis/src/free_pcs/impl/local.rs +++ b/mir-state-analysis/src/free_pcs/impl/local.rs @@ -42,7 +42,9 @@ impl Default for CapabilityLocal<'_> { impl<'tcx> CapabilityLocal<'tcx> { pub fn get_allocated_mut(&mut self) -> &mut CapabilityProjections<'tcx> { - let Self::Allocated(cps) = self else { panic!("Expected allocated local") }; + let Self::Allocated(cps) = self else { + panic!("Expected allocated local") + }; cps } pub fn new(local: Local, perm: CapabilityKind) -> Self { @@ -198,8 +200,7 @@ impl<'tcx> CapabilityProjections<'tcx> { } let mut ops = Vec::new(); for (to, from, _) in collapsed { - let removed_perms: Vec<_> = - old_caps.extract_if(|old, _| to.is_prefix(*old)).collect(); + let removed_perms: Vec<_> = old_caps.extract_if(|old, _| to.is_prefix(*old)).collect(); let perm = removed_perms .iter() .fold(CapabilityKind::Exclusive, |acc, (_, p)| { diff --git a/mir-state-analysis/src/free_pcs/impl/triple.rs b/mir-state-analysis/src/free_pcs/impl/triple.rs index 91370d382be..da6ba700a6c 100644 --- a/mir-state-analysis/src/free_pcs/impl/triple.rs +++ b/mir-state-analysis/src/free_pcs/impl/triple.rs @@ -58,7 +58,8 @@ impl<'tcx> Visitor<'tcx> for Fpcs<'_, 'tcx> { self.ensures_unalloc(local); } &Retag(_, box place) => self.requires_exclusive(place), - AscribeUserType(..) | PlaceMention(..) | Coverage(..) | Intrinsic(..) | ConstEvalCounter | Nop => (), + AscribeUserType(..) | PlaceMention(..) | Coverage(..) | Intrinsic(..) + | ConstEvalCounter | Nop => (), }; } @@ -88,11 +89,19 @@ impl<'tcx> Visitor<'tcx> for Fpcs<'_, 'tcx> { } } } - &Drop { place, replace: false, .. } => { + &Drop { + place, + replace: false, + .. + } => { self.requires_write(place); self.ensures_write(place); } - &Drop { place, replace: true, .. } => { + &Drop { + place, + replace: true, + .. + } => { self.requires_write(place); self.ensures_exclusive(place); } diff --git a/mir-state-analysis/src/utils/repacker.rs b/mir-state-analysis/src/utils/repacker.rs index 21a88b281d8..0a2cd4c317d 100644 --- a/mir-state-analysis/src/utils/repacker.rs +++ b/mir-state-analysis/src/utils/repacker.rs @@ -11,8 +11,7 @@ use prusti_rustc_interface::{ index::bit_set::BitSet, middle::{ mir::{ - tcx::PlaceTy, Body, HasLocalDecls, Local, Mutability, Place as MirPlace, - ProjectionElem, + tcx::PlaceTy, Body, HasLocalDecls, Local, Mutability, Place as MirPlace, ProjectionElem, }, ty::{TyCtxt, TyKind}, }, diff --git a/prusti-contracts/prusti-contracts/src/lib.rs b/prusti-contracts/prusti-contracts/src/lib.rs index e36319181df..eced6374b07 100644 --- a/prusti-contracts/prusti-contracts/src/lib.rs +++ b/prusti-contracts/prusti-contracts/src/lib.rs @@ -342,7 +342,7 @@ pub fn old(arg: T) -> T { /// Universal quantifier. /// /// This is a Prusti-internal representation of the `forall` syntax. -#[prusti::builtin="forall"] +#[prusti::builtin = "forall"] pub fn forall(_trigger_set: T, _closure: F) -> bool { true } diff --git a/prusti-contracts/prusti-specs/src/lib.rs b/prusti-contracts/prusti-specs/src/lib.rs index 3403260c345..a16b21cd4dc 100644 --- a/prusti-contracts/prusti-specs/src/lib.rs +++ b/prusti-contracts/prusti-specs/src/lib.rs @@ -77,7 +77,9 @@ fn extract_prusti_attributes( // tokens identical to the ones passed by the native procedural // macro call. let mut iter = attr.tokens.into_iter(); - let TokenTree::Group(group) = iter.next().unwrap() else { unreachable!() }; + let TokenTree::Group(group) = iter.next().unwrap() else { + unreachable!() + }; assert!(iter.next().is_none(), "Unexpected shape of an attribute."); group.stream() } @@ -596,10 +598,14 @@ pub fn refine_trait_spec(_attr: TokenStream, tokens: TokenStream) -> TokenStream let parsed_predicate = handle_result!(predicate::parse_predicate_in_impl(makro.mac.tokens.clone())); - let ParsedPredicate::Impl(predicate) = parsed_predicate else { unreachable!() }; + let ParsedPredicate::Impl(predicate) = parsed_predicate else { + unreachable!() + }; // Patch spec function: Rewrite self with _self: - let syn::Item::Fn(spec_function) = predicate.spec_function else { unreachable!() }; + let syn::Item::Fn(spec_function) = predicate.spec_function else { + unreachable!() + }; generated_spec_items.push(spec_function); // Add patched predicate function to new items @@ -872,7 +878,9 @@ fn extract_prusti_attributes_for_types( // tokens identical to the ones passed by the native procedural // macro call. let mut iter = attr.tokens.into_iter(); - let TokenTree::Group(group) = iter.next().unwrap() else { unreachable!() }; + let TokenTree::Group(group) = iter.next().unwrap() else { + unreachable!() + }; group.stream() } }; diff --git a/prusti-encoder/src/encoders/generic.rs b/prusti-encoder/src/encoders/generic.rs index dff23c44cb2..bfa8aac6502 100644 --- a/prusti-encoder/src/encoders/generic.rs +++ b/prusti-encoder/src/encoders/generic.rs @@ -1,7 +1,4 @@ -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct GenericEncoder; @@ -39,7 +36,8 @@ impl TaskEncoder for GenericEncoder { type EncodingError = GenericEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, GenericEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, GenericEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -57,38 +55,49 @@ impl TaskEncoder for GenericEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(*task_key, GenericEncoderOutputRef { - snapshot_param_name: "s_Param", - predicate_param_name: "p_Param", - domain_type_name: "s_Type", - }); - vir::with_vcx(|vcx| Ok((GenericEncoderOutput { - snapshot_param: vir::vir_domain! { vcx; domain s_Param {} }, - predicate_param: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref/*, self_s: s_Param*/) }, - domain_type: vir::vir_domain! { vcx; domain s_Type { - // TODO: only emit these when the types are actually used? - // emit instead from type encoder, to be consistent with the ADT case? - unique function s_Type_Bool(): s_Type; - unique function s_Type_Int_isize(): s_Type; - unique function s_Type_Int_i8(): s_Type; - unique function s_Type_Int_i16(): s_Type; - unique function s_Type_Int_i32(): s_Type; - unique function s_Type_Int_i64(): s_Type; - unique function s_Type_Int_i128(): s_Type; - unique function s_Type_Uint_usize(): s_Type; - unique function s_Type_Uint_u8(): s_Type; - unique function s_Type_Uint_u16(): s_Type; - unique function s_Type_Uint_u32(): s_Type; - unique function s_Type_Uint_u64(): s_Type; - unique function s_Type_Uint_u128(): s_Type; - } }, - }, ()))) + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { + deps.emit_output_ref::( + *task_key, + GenericEncoderOutputRef { + snapshot_param_name: "s_Param", + predicate_param_name: "p_Param", + domain_type_name: "s_Type", + }, + ); + vir::with_vcx(|vcx| { + Ok(( + GenericEncoderOutput { + snapshot_param: vir::vir_domain! { vcx; domain s_Param {} }, + predicate_param: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref/*, self_s: s_Param*/) }, + domain_type: vir::vir_domain! { vcx; domain s_Type { + // TODO: only emit these when the types are actually used? + // emit instead from type encoder, to be consistent with the ADT case? + unique function s_Type_Bool(): s_Type; + unique function s_Type_Int_isize(): s_Type; + unique function s_Type_Int_i8(): s_Type; + unique function s_Type_Int_i16(): s_Type; + unique function s_Type_Int_i32(): s_Type; + unique function s_Type_Int_i64(): s_Type; + unique function s_Type_Int_i128(): s_Type; + unique function s_Type_Uint_usize(): s_Type; + unique function s_Type_Uint_u8(): s_Type; + unique function s_Type_Uint_u16(): s_Type; + unique function s_Type_Uint_u32(): s_Type; + unique function s_Type_Uint_u64(): s_Type; + unique function s_Type_Uint_u128(): s_Type; + } }, + }, + (), + )) + }) } } diff --git a/prusti-encoder/src/encoders/local_def.rs b/prusti-encoder/src/encoders/local_def.rs index d6720a3d0c9..3b385fb8278 100644 --- a/prusti-encoder/src/encoders/local_def.rs +++ b/prusti-encoder/src/encoders/local_def.rs @@ -1,11 +1,7 @@ -use prusti_rustc_interface::{ - index::IndexVec, - middle::mir, - span::def_id::DefId -}; +use prusti_rustc_interface::{index::IndexVec, middle::mir, span::def_id::DefId}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirLocalDefEncoder; #[derive(Clone, Copy)] @@ -70,32 +66,37 @@ impl TaskEncoder for MirLocalDefEncoder { deps.emit_output_ref::(def_id, ()); vir::with_vcx(|vcx| { - let local_def_id = def_id.expect_local(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); - let locals = IndexVec::from_fn_n(|arg: mir::Local| { - let local = vcx.mk_local(vir::vir_format!(vcx, "_{}p", arg.index())); - let ty = deps.require_ref::( - body.local_decls[arg].ty, - ).unwrap(); - let snapshot = ty.snapshot; - let local_ex = vcx.mk_local_ex_local(local); - let impure_snap = vcx.mk_func_app( - ty.function_snap, - &[local_ex], - ); - let impure_pred = vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { - target: ty.predicate_name, - args: vcx.alloc_slice(&[local_ex]), - }))); - LocalDef { - local, - snapshot, - local_ex, - impure_snap, - impure_pred, - ty: vcx.alloc(ty), - } - }, body.local_decls.len()); + let local_def_id = def_id.expect_local(); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); + let locals = IndexVec::from_fn_n( + |arg: mir::Local| { + let local = vcx.mk_local(vir::vir_format!(vcx, "_{}p", arg.index())); + let ty = deps + .require_ref::(body.local_decls[arg].ty) + .unwrap(); + let snapshot = ty.snapshot; + let local_ex = vcx.mk_local_ex_local(local); + let impure_snap = vcx.mk_func_app(ty.function_snap, &[local_ex]); + let impure_pred = vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc( + vir::PredicateAppData { + target: ty.predicate_name, + args: vcx.alloc_slice(&[local_ex]), + }, + ))); + LocalDef { + local, + snapshot, + local_ex, + impure_snap, + impure_pred, + ty: vcx.alloc(ty), + } + }, + body.local_decls.len(), + ); let data = MirLocalDefEncoderOutput { locals: vcx.alloc(locals), arg_count: body.arg_count, diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index 032d69fdd45..a3654207246 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -1,11 +1,5 @@ -use prusti_rustc_interface::{ - middle::ty, - middle::mir, -}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use prusti_rustc_interface::middle::{mir, ty}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirBuiltinEncoder; @@ -47,7 +41,8 @@ impl TaskEncoder for MirBuiltinEncoder { type EncodingError = MirBuiltinEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirBuiltinEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirBuiltinEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -65,13 +60,16 @@ impl TaskEncoder for MirBuiltinEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { // TODO: this function is also useful for the type encoder, extract? fn int_name<'tcx>(ty: ty::Ty<'tcx>) -> &'static str { match ty.kind() { @@ -87,164 +85,202 @@ impl TaskEncoder for MirBuiltinEncoder { assert_eq!(*ty.kind(), ty::TyKind::Bool); let name = "mir_unop_not"; - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); /* function mir_unop_not(arg: s_Bool): s_Bool { s_Bool$cons(!s_Bool$val(val)) } */ - let ty_s = deps.require_ref::( - *ty, - ).unwrap().snapshot; - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_s)]), - ret: ty_s, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - "s_Bool_cons", - &[vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { - kind: vir::UnOpKind::Not, - expr: vcx.mk_func_app( - "s_Bool_val", - &[vcx.mk_local_ex("arg")], - ), - })))], - )), - }), - }, ())) + let ty_s = deps + .require_ref::(*ty) + .unwrap() + .snapshot; + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_s)]), + ret: ty_s, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + "s_Bool_cons", + &[vcx.alloc(vir::ExprData::UnOp( + vcx.alloc(vir::UnOpData { + kind: vir::UnOpKind::Not, + expr: vcx.mk_func_app( + "s_Bool_val", + &[vcx.mk_local_ex("arg")], + ), + }), + ))], + )), + }), + }, + (), + )) } MirBuiltinEncoderTask::UnOp(mir::UnOp::Neg, ty) => { let name = vir::vir_format!(vcx, "mir_unop_neg_{}", int_name(*ty)); - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); /* function mir_unop_neg(arg: s_I32): s_I32 { cons(-val(arg)) } */ - let ty_out = deps.require_ref::( - *ty, - ).unwrap(); - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_out.snapshot)]), - ret: ty_out.snapshot, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - ty_out.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { - kind: vir::UnOpKind::Neg, - expr: vcx.mk_func_app( - ty_out.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg")], - ), - })))], - )), - }), - }, ())) + let ty_out = deps + .require_ref::(*ty) + .unwrap(); + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_out.snapshot)]), + ret: ty_out.snapshot, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + ty_out.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { + kind: vir::UnOpKind::Neg, + expr: vcx.mk_func_app( + ty_out.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg")], + ), + })))], + )), + }), + }, + (), + )) } // TODO: should these be two different functions? precondition? MirBuiltinEncoderTask::BinOp(mir::BinOp::Add | mir::BinOp::AddUnchecked, ty) => { let name = vir::vir_format!(vcx, "mir_binop_add_{}", int_name(*ty)); - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); - - let ty_out = deps.require_ref::( - *ty, - ).unwrap(); - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("arg1", ty_out.snapshot), - vcx.mk_local_decl("arg2", ty_out.snapshot), - ]), - ret: ty_out.snapshot, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - ty_out.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::Add, - lhs: vcx.mk_func_app( - ty_out.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg1")], - ), - rhs: vcx.mk_func_app( - ty_out.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg2")], - ), - })))], - )), - }), - }, ())) + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); + + let ty_out = deps + .require_ref::(*ty) + .unwrap(); + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[ + vcx.mk_local_decl("arg1", ty_out.snapshot), + vcx.mk_local_decl("arg2", ty_out.snapshot), + ]), + ret: ty_out.snapshot, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + ty_out.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::Add, + lhs: vcx.mk_func_app( + ty_out.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg1")], + ), + rhs: vcx.mk_func_app( + ty_out.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg2")], + ), + })))], + )), + }), + }, + (), + )) } - MirBuiltinEncoderTask::CheckedBinOp(mir::BinOp::Add | mir::BinOp::AddUnchecked, ty) => { - let name = vir::vir_format!(vcx, "mir_checkedbinop_add_{}", int_name(*ty)); - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); - - let ty_in = deps.require_ref::( - *ty, - ).unwrap(); - let ty_out = deps.require_ref::( - vcx.tcx.mk_ty_from_kind(ty::TyKind::Tuple(vcx.tcx.mk_type_list(&[ - *ty, - vcx.tcx.mk_ty_from_kind(ty::TyKind::Bool), - ]))), - ).unwrap(); - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("arg1", ty_in.snapshot), - vcx.mk_local_decl("arg2", ty_in.snapshot), - ]), - ret: ty_out.snapshot, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - vir::vir_format!(vcx, "{}_cons", ty_out.snapshot_name), - &[ - vcx.mk_func_app( - ty_in.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::Add, - lhs: vcx.mk_func_app( - ty_in.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg1")], - ), - rhs: vcx.mk_func_app( - ty_in.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg2")], - ), - })))], - ), - // TODO: overflow condition! - vcx.mk_func_app( - "s_Bool_cons", - &[&vir::ExprData::Const(&vir::ConstData::Bool(false))], - ), - ], - )), - }), - }, ())) + MirBuiltinEncoderTask::CheckedBinOp( + op @ (mir::BinOp::Add + | mir::BinOp::AddUnchecked + | mir::BinOp::Sub + | mir::BinOp::SubUnchecked), + ty, + ) => { + let (op_name, vir_op) = match op { + mir::BinOp::Add | mir::BinOp::AddUnchecked => ("add", vir::BinOpKind::Add), + mir::BinOp::Sub | mir::BinOp::SubUnchecked => ("sub", vir::BinOpKind::Sub), + + _ => unreachable!(), + }; + let name = + vir::vir_format!(vcx, "mir_checkedbinop_{}_{}", op_name, int_name(*ty)); + + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); + + let ty_in = deps + .require_ref::(*ty) + .unwrap(); + let ty_out = deps + .require_ref::(vcx.tcx.mk_ty_from_kind( + ty::TyKind::Tuple( + vcx.tcx.mk_type_list(&[ + *ty, + vcx.tcx.mk_ty_from_kind(ty::TyKind::Bool), + ]), + ), + )) + .unwrap(); + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[ + vcx.mk_local_decl("arg1", ty_in.snapshot), + vcx.mk_local_decl("arg2", ty_in.snapshot), + ]), + ret: ty_out.snapshot, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + vir::vir_format!(vcx, "{}_cons", ty_out.snapshot_name), + &[ + vcx.mk_func_app( + ty_in.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc( + vir::BinOpData { + kind: vir_op, + lhs: vcx.mk_func_app( + ty_in.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg1")], + ), + rhs: vcx.mk_func_app( + ty_in.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg2")], + ), + }, + )))], + ), + // TODO: overflow condition! + vcx.mk_func_app( + "s_Bool_cons", + &[&vir::ExprData::Const(&vir::ConstData::Bool(false))], + ), + ], + )), + }), + }, + (), + )) } - _ => todo!(), + other => todo!("{other:?}"), } - }) } } diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index de64582a988..91bffdf6765 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -1,17 +1,15 @@ +use mir_state_analysis::{ + free_pcs::{CapabilityKind, FreePcsAnalysis, FreePcsBasicBlock, RepackOp}, + utils::Place, +}; use prusti_rustc_interface::{ middle::{mir, ty}, span::def_id::DefId, }; -use mir_state_analysis::{ - free_pcs::{FreePcsAnalysis, FreePcsBasicBlock, RepackOp, CapabilityKind}, utils::Place, -}; //use mir_ssa_analysis::{ // SsaAnalysis, //}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirImpureEncoder; @@ -48,7 +46,8 @@ impl TaskEncoder for MirImpureEncoder { type EncodingError = MirImpureEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirImpureEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirImpureEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -66,31 +65,35 @@ impl TaskEncoder for MirImpureEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { use mir::visit::Visitor; vir::with_vcx(|vcx| { let def_id = *task_key; let method_name = vir::vir_format!(vcx, "m_{}", vcx.tcx.item_name(def_id)); - deps.emit_output_ref::(def_id, MirImpureEncoderOutputRef { - method_name, - }); - - let local_defs = deps.require_local::( - def_id, - ).unwrap(); - let spec = deps.require_local::( - (def_id, false) - ).unwrap(); + deps.emit_output_ref::(def_id, MirImpureEncoderOutputRef { method_name }); + + let local_defs = deps + .require_local::(def_id) + .unwrap(); + let spec = deps + .require_local::((def_id, false)) + .unwrap(); let (spec_pres, spec_posts) = (spec.pres, spec.posts); - let local_def_id = def_id.expect_local(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); + let local_def_id = def_id.expect_local(); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); // let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); //let ssa_analysis = SsaAnalysis::analyse(&body); @@ -129,14 +132,13 @@ impl TaskEncoder for MirImpureEncoder { ))) } if ENCODE_REACH_BB { - start_stmts.extend((0..block_count) - .map(|block| { - let name = vir::vir_format!(vcx, "_reach_bb{block}"); - vcx.alloc(vir::StmtData::LocalDecl( - vir::vir_local_decl! { vcx; [name] : Bool }, - Some(vcx.alloc(vir::ExprData::Todo("false"))), - )) - })); + start_stmts.extend((0..block_count).map(|block| { + let name = vir::vir_format!(vcx, "_reach_bb{block}"); + vcx.alloc(vir::StmtData::LocalDecl( + vir::vir_local_decl! { vcx; [name] : Bool }, + Some(vcx.alloc(vir::ExprData::Todo("false"))), + )) + })); } encoded_blocks.push(vcx.alloc(vir::CfgBlockData { label: vcx.alloc(vir::CfgBlockLabelData::Start), @@ -161,46 +163,68 @@ impl TaskEncoder for MirImpureEncoder { posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred); posts.extend(spec_posts); - let mut visitor = EncoderVisitor { - vcx, - deps, - local_decls: &body.local_decls, - //ssa_analysis, - fpcs_analysis, - local_defs, - - tmp_ctr: 0, + // TODO: dedup with mir_pure + let attrs = vcx.tcx.get_attrs_unchecked(def_id); + let is_trusted = attrs + .iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()) + .any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "trusted" + }); + + let blocks = if is_trusted { + None + } else { + let mut visitor = EncoderVisitor { + vcx, + deps, + local_decls: &body.local_decls, + //ssa_analysis, + fpcs_analysis, + local_defs, + + tmp_ctr: 0, + + current_fpcs: None, + + current_stmts: None, + current_terminator: None, + encoded_blocks, + }; + visitor.visit_body(&body); - current_fpcs: None, + visitor.encoded_blocks.push(vcx.alloc(vir::CfgBlockData { + label: vcx.alloc(vir::CfgBlockLabelData::End), + stmts: &[], + terminator: vcx.alloc(vir::TerminatorStmtData::Exit), + })); - current_stmts: None, - current_terminator: None, - encoded_blocks, + Some(vcx.alloc_slice(&visitor.encoded_blocks)) }; - visitor.visit_body(&body); - - visitor.encoded_blocks.push(vcx.alloc(vir::CfgBlockData { - label: vcx.alloc(vir::CfgBlockLabelData::End), - stmts: &[], - terminator: vcx.alloc(vir::TerminatorStmtData::Exit), - })); - Ok((MirImpureEncoderOutput { - method: vcx.alloc(vir::MethodData { - name: method_name, - args: vcx.alloc_slice(&args), - rets: &[], - pres: vcx.alloc_slice(&pres), - posts: vcx.alloc_slice(&posts), - blocks: Some(vcx.alloc_slice(&visitor.encoded_blocks)), - }), - }, ())) + Ok(( + MirImpureEncoderOutput { + method: vcx.alloc(vir::MethodData { + name: method_name, + args: vcx.alloc_slice(&args), + rets: &[], + pres: vcx.alloc_slice(&pres), + posts: vcx.alloc_slice(&posts), + blocks, + }), + }, + (), + )) }) } } struct EncoderVisitor<'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { vcx: &'vir vir::VirCtxt<'vir>, deps: &'enc mut TaskEncoderDependencies<'vir>, @@ -237,13 +261,15 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { match projection { mir::ProjectionElem::Field(f, ty) => { let ty_out_struct = ty_out.expect_structlike(); - let field_ty_out = self.deps.require_ref::( - ty, - ).unwrap(); - (self.vcx.mk_func_app( - ty_out_struct.field_projection_p[f.as_usize()], - &[base], - ), field_ty_out) + let field_ty_out = self + .deps + .require_ref::(ty) + .unwrap(); + ( + self.vcx + .mk_func_app(ty_out_struct.field_projection_p[f.as_usize()], &[base]), + field_ty_out, + ) } _ => panic!("unsupported projection"), } @@ -255,8 +281,11 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { ty_out: crate::encoders::TypeEncoderOutputRef<'vir>, projection: &'vir [mir::PlaceElem<'vir>], ) -> (vir::Expr<'vir>, crate::encoders::TypeEncoderOutputRef<'vir>) { - projection.iter() - .fold((base, ty_out), |(base, ty_out), proj| self.project_one(base, ty_out, *proj)) + projection + .iter() + .fold((base, ty_out), |(base, ty_out), proj| { + self.project_one(base, ty_out, *proj) + }) } /* @@ -301,11 +330,10 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { } */ - fn fpcs_location( - &mut self, - location: mir::Location, - ) { - let repacks = self.current_fpcs.as_ref().unwrap().statements[location.statement_index].repacks.clone(); + fn fpcs_location(&mut self, location: mir::Location) { + let repacks = self.current_fpcs.as_ref().unwrap().statements[location.statement_index] + .repacks + .clone(); for repack_op in repacks { match repack_op { RepackOp::Expand(place, _target, capability_kind) @@ -319,24 +347,26 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { let place_ty = place.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); - let place_ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let place_ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let ref_p = self.encode_place(place); - if matches!(repack_op, mir_state_analysis::free_pcs::RepackOp::Expand(..)) { - self.stmt(vir::StmtData::Unfold(self.vcx.alloc(vir::PredicateAppData { - target: place_ty_out.predicate_name, - args: self.vcx.alloc_slice(&[ - ref_p, - ]), - }))); + if matches!( + repack_op, + mir_state_analysis::free_pcs::RepackOp::Expand(..) + ) { + self.stmt(vir::StmtData::Unfold(self.vcx.alloc( + vir::PredicateAppData { + target: place_ty_out.predicate_name, + args: self.vcx.alloc_slice(&[ref_p]), + }, + ))); } else { self.stmt(vir::StmtData::Fold(self.vcx.alloc(vir::PredicateAppData { target: place_ty_out.predicate_name, - args: self.vcx.alloc_slice(&[ - ref_p, - ]), + args: self.vcx.alloc_slice(&[ref_p]), }))); } } @@ -344,133 +374,148 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { let place_ty = place.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); - let place_ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let place_ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let ref_p = self.encode_place(place); - self.stmt(vir::StmtData::Exhale(self.vcx.alloc(vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { - target: place_ty_out.predicate_name, - args: self.vcx.alloc_slice(&[ - ref_p, - ]), - }))))); + self.stmt(vir::StmtData::Exhale(self.vcx.alloc( + vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { + target: place_ty_out.predicate_name, + args: self.vcx.alloc_slice(&[ref_p]), + })), + ))); } unsupported_op => panic!("unsupported repack op: {unsupported_op:?}"), } } } - fn encode_operand_snap( - &mut self, - operand: &mir::Operand<'vir>, - ) -> vir::Expr<'vir> { + fn encode_operand_snap(&mut self, operand: &mir::Operand<'vir>) -> vir::Expr<'vir> { match operand { &mir::Operand::Move(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let place_exp = self.encode_place(Place::from(source)); let snap_val = self.vcx.mk_func_app(ty_out.function_snap, &[place_exp]); let tmp_exp = self.new_tmp(ty_out.snapshot).1; - self.stmt(vir::StmtData::PureAssign(self.vcx.alloc(vir::PureAssignData { - lhs: tmp_exp, - rhs: snap_val, - }))); - self.stmt(vir::StmtData::Exhale(self.vcx.alloc(vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { - target: ty_out.predicate_name, - args: self.vcx.alloc_slice(&[place_exp]), - }))))); + self.stmt(vir::StmtData::PureAssign(self.vcx.alloc( + vir::PureAssignData { + lhs: tmp_exp, + rhs: snap_val, + }, + ))); + self.stmt(vir::StmtData::Exhale(self.vcx.alloc( + vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { + target: ty_out.predicate_name, + args: self.vcx.alloc_slice(&[place_exp]), + })), + ))); tmp_exp } &mir::Operand::Copy(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); - self.vcx.mk_func_app(ty_out.function_snap, &[self.encode_place(Place::from(source))]) + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + self.vcx.mk_func_app( + ty_out.function_snap, + &[self.encode_place(Place::from(source))], + ) } mir::Operand::Constant(box constant) => self.encode_constant(constant), } } - fn encode_operand( - &mut self, - operand: &mir::Operand<'vir>, - ) -> vir::Expr<'vir> { + fn encode_operand(&mut self, operand: &mir::Operand<'vir>) -> vir::Expr<'vir> { let (snap_val, ty_out) = match operand { &mir::Operand::Move(source) => return self.encode_place(Place::from(source)), &mir::Operand::Copy(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); - (self.vcx.mk_func_app(ty_out.function_snap, &[self.encode_place(Place::from(source))]), ty_out) + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + ( + self.vcx.mk_func_app( + ty_out.function_snap, + &[self.encode_place(Place::from(source))], + ), + ty_out, + ) } mir::Operand::Constant(box constant) => { - let ty_out = self.deps.require_ref::( - constant.ty(), - ).unwrap(); + let ty_out = self + .deps + .require_ref::(constant.ty()) + .unwrap(); (self.encode_constant(constant), ty_out) } }; let tmp_exp = self.new_tmp(&vir::TypeData::Ref).1; - self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { - targets: &[], - method: ty_out.method_assign, - args: self.vcx.alloc_slice(&[tmp_exp, snap_val]), - }))); + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc( + vir::MethodCallData { + targets: &[], + method: ty_out.method_assign, + args: self.vcx.alloc_slice(&[tmp_exp, snap_val]), + }, + ))); tmp_exp } - fn encode_place( - &mut self, - place: Place<'vir>, - ) -> vir::Expr<'vir> { + fn encode_place(&mut self, place: Place<'vir>) -> vir::Expr<'vir> { //assert!(place.projection.is_empty()); //self.vcx.mk_local_ex(vir::vir_format!(self.vcx, "_{}p", place.local.index())) self.project( self.local_defs.locals[place.local].local_ex, self.local_defs.locals[place.local].ty.clone(), place.projection, - ).0 + ) + .0 } // TODO: this will not work for unevaluated constants (which needs const // MIR evaluation, more like pure fn body encoding) - fn encode_constant( - &self, - constant: &mir::Constant<'vir>, - ) -> vir::Expr<'vir> { + fn encode_constant(&self, constant: &mir::Constant<'vir>) -> vir::Expr<'vir> { match constant.literal { - mir::ConstantKind::Val(const_val, const_ty) => { - match const_ty.kind() { - ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), - ty::TyKind::Int(int_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Int_{}_cons({})", int_ty.name_str(), scalar_val.try_to_int(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Uint(uint_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Uint_{}_cons({})", uint_ty.name_str(), scalar_val.try_to_uint(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Bool => self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Bool_cons({})", const_val.try_to_bool().unwrap()), - )), - unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + mir::ConstantKind::Val(const_val, const_ty) => match const_ty.kind() { + ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(vir::ExprData::Todo( + vir::vir_format!(self.vcx, "s_Tuple0_cons()"), + )), + ty::TyKind::Int(int_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Int_{}_cons({})", + int_ty.name_str(), + scalar_val.try_to_int(scalar_val.size()).unwrap() + ))) } - } + ty::TyKind::Uint(uint_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Uint_{}_cons({})", + uint_ty.name_str(), + scalar_val.try_to_uint(scalar_val.size()).unwrap() + ))) + } + ty::TyKind::Bool => self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Bool_cons({})", + const_val.try_to_bool().unwrap() + ))), + unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + }, unsupported_literal => todo!("unsupported constant literal: {unsupported_literal:?}"), } } @@ -491,21 +536,23 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { // fn visit_body(&mut self, body: &mir::Body<'tcx>) { // println!("visiting body!"); // } - fn visit_basic_block_data( - &mut self, - block: mir::BasicBlock, - data: &mir::BasicBlockData<'vir>, - ) { + fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'vir>) { self.current_fpcs = Some(self.fpcs_analysis.get_all_for_bb(block)); self.current_stmts = Some(Vec::with_capacity( data.statements.len(), // TODO: not exact? )); if ENCODE_REACH_BB { - self.stmt(vir::StmtData::PureAssign(self.vcx.alloc(vir::PureAssignData { - lhs: self.vcx.mk_local_ex(vir::vir_format!(self.vcx, "_reach_bb{}", block.as_usize())), - rhs: self.vcx.mk_true(), - }))); + self.stmt(vir::StmtData::PureAssign(self.vcx.alloc( + vir::PureAssignData { + lhs: self.vcx.mk_local_ex(vir::vir_format!( + self.vcx, + "_reach_bb{}", + block.as_usize() + )), + rhs: self.vcx.mk_true(), + }, + ))); } /* @@ -539,26 +586,25 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { self.super_basic_block_data(block, data); let stmts = self.current_stmts.take().unwrap(); let terminator = self.current_terminator.take().unwrap(); - self.encoded_blocks.push(self.vcx.alloc(vir::CfgBlockData { - label: self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), - stmts: self.vcx.alloc_slice(&stmts), - terminator, - })); + self.encoded_blocks.push( + self.vcx.alloc(vir::CfgBlockData { + label: self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), + stmts: self.vcx.alloc_slice(&stmts), + terminator, + }), + ); } - fn visit_statement( - &mut self, - statement: &mir::Statement<'vir>, - location: mir::Location, - ) { + fn visit_statement(&mut self, statement: &mir::Statement<'vir>, location: mir::Location) { // TODO: proper flag // This clears up the noise a bit, making sure StorageLive and other // kinds do not show up in the comments. let IGNORE_NOP_STMTS = true; if IGNORE_NOP_STMTS { match &statement.kind { - mir::StatementKind::StorageLive(..) - | mir::StatementKind::StorageDead(..) => { + mir::StatementKind::StorageLive(..) | mir::StatementKind::StorageDead(..) => { return; } _ => {} @@ -613,7 +659,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { rhs: self.encode_operand_snap(r), })))], )), - mir::Rvalue::BinaryOp(mir::BinOp::Lt, box (l, r)) => { + mir::Rvalue::BinaryOp(op, box (l, r)) => { let ty_l = self.deps.require_ref::( l.ty(self.local_decls, self.vcx.tcx), ).unwrap().to_primitive.unwrap(); @@ -624,7 +670,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { Some(self.vcx.mk_func_app( "s_Bool_cons", // TODO: go through type encoder &[self.vcx.alloc(vir::ExprData::BinOp(self.vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpLt, + kind: vir::BinOpKind::from(op), lhs: self.vcx.mk_func_app( ty_l, &[self.encode_operand_snap(l)], @@ -686,7 +732,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { } mir::Rvalue::Aggregate( - box mir::AggregateKind::Adt(..), + box mir::AggregateKind::Adt(..) | box mir::AggregateKind::Tuple, fields, ) => { let dest_ty_struct = dest_ty_out.expect_structlike(); @@ -700,10 +746,10 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { method: dest_ty_out.method_assign, args: self.vcx.alloc_slice(&[proj_ref, cons]), }))); - + for field in fields { if let mir::Operand::Move(source) = field { - + } } None @@ -712,7 +758,8 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { //mir::Rvalue::Discriminant(Place<'tcx>) => {} //mir::Rvalue::ShallowInitBox(Operand<'tcx>, Ty<'tcx>) => {} //mir::Rvalue::CopyForDeref(Place<'tcx>) => {} - _ => { + other => { + tracing::error!("unsupported rvalue {other:?}"); Some(self.vcx.alloc(vir::ExprData::Todo( vir::vir_format!(self.vcx, "rvalue {rvalue:?}"), ))) @@ -735,7 +782,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { // no-ops mir::StatementKind::FakeRead(_) | mir::StatementKind::Retag(..) - //| mir::StatementKind::PlaceMention(_) + | mir::StatementKind::PlaceMention(_) | mir::StatementKind::AscribeUserType(..) | mir::StatementKind::Coverage(_) //| mir::StatementKind::ConstEvalCounter @@ -745,42 +792,53 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { } } - fn visit_terminator( - &mut self, - terminator: &mir::Terminator<'vir>, - location: mir::Location, - ) { + fn visit_terminator(&mut self, terminator: &mir::Terminator<'vir>, location: mir::Location) { self.fpcs_location(location); let terminator = match &terminator.kind { mir::TerminatorKind::Goto { target } - | mir::TerminatorKind::FalseUnwind { real_target: target, .. } => - self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )), + | mir::TerminatorKind::FalseUnwind { + real_target: target, + .. + } => self.vcx.alloc(vir::TerminatorStmtData::Goto( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + )), mir::TerminatorKind::SwitchInt { discr, targets } => { //let discr_version = self.ssa_analysis.version.get(&(location, discr.local)).unwrap(); //let discr_name = vir::vir_format!(self.vcx, "_{}s_{}", discr.local.index(), discr_version); - let ty_out = self.deps.require_ref::( - discr.ty(self.local_decls, self.vcx.tcx), - ).unwrap(); - - let goto_targets = self.vcx.alloc_slice(&targets.iter() - .map(|(value, target)| ( - ty_out.expr_from_u128(value), - //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )) - .collect::>()); + let ty_out = self + .deps + .require_ref::( + discr.ty(self.local_decls, self.vcx.tcx), + ) + .unwrap(); + + let goto_targets = self.vcx.alloc_slice( + &targets + .iter() + .map(|(value, target)| { + ( + ty_out.expr_from_u128(value), + //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + ) + }) + .collect::>(), + ); let goto_otherwise = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock( targets.otherwise().as_usize(), )); let discr_ex = self.encode_operand_snap(discr); - self.vcx.alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc(vir::GotoIfData { - value: discr_ex, // self.vcx.mk_local_ex(discr_name), - targets: goto_targets, - otherwise: goto_otherwise, - }))) + self.vcx + .alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc( + vir::GotoIfData { + value: discr_ex, // self.vcx.mk_local_ex(discr_name), + targets: goto_targets, + otherwise: goto_otherwise, + }, + ))) } mir::TerminatorKind::Return => self.vcx.alloc(vir::TerminatorStmtData::Goto( self.vcx.alloc(vir::CfgBlockLabelData::End), @@ -802,25 +860,118 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { _ => todo!(), }; - let func_out = self.deps.require_ref::( - *func_def_id, - ).unwrap(); + // TODO: dedup with mir_pure + let attrs = self.vcx.tcx.get_attrs_unchecked(*func_def_id); + let is_pure = attrs + .iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()) + .any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + }); + + let dest = self.encode_place(Place::from(*destination)); + let call_args = args.iter().map(|op| { + if is_pure { + self.encode_operand_snap(op) + } else { + self.encode_operand(op) + } + }); + + if is_pure { + let func_args = call_args.collect::>(); + let pure_func_name = self + .deps + .require_ref::(*func_def_id) + .unwrap() + .function_name; + + let pure_func_app = self.vcx.mk_func_app(pure_func_name, &func_args); + + let method_assign = { + //TODO: Can we get `method_assign` in a better way? Maybe from the MirFunctionEncoder? + let body = self + .vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(func_def_id.expect_local()); + let return_type = self + .deps + .require_ref::(body.return_ty()) + .unwrap(); + return_type.method_assign + }; + + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc( + vir::MethodCallData { + targets: &[], + method: method_assign, + args: self.vcx.alloc_slice(&[dest, pure_func_app]), + }, + ))); + } else { + let meth_args = std::iter::once(dest).chain(call_args).collect::>(); + let func_out = self + .deps + .require_ref::(*func_def_id) + .unwrap(); + + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc( + vir::MethodCallData { + targets: &[], + method: func_out.method_name, + args: self.vcx.alloc_slice(&meth_args), + }, + ))); + } - let destination = self.encode_place(Place::from(*destination)); - let args = args.iter().map(|op| self.encode_operand(op)); - let args: Vec<_> = std::iter::once(destination).chain(args).collect(); - self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { - targets: &[], - method: func_out.method_name, - args: self.vcx.alloc_slice(&args), - }))); - self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize())), + self.vcx.alloc(vir::TerminatorStmtData::Goto(self.vcx.alloc( + vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize()), + ))) + } + mir::TerminatorKind::Assert { + cond, + expected, + msg, + target, + unwind, + } => { + let otherwise = match unwind { + mir::UnwindAction::Cleanup(bb) => bb, + _ => todo!(), + }; + + let enc = self.encode_operand_snap(cond); + let enc = self.vcx.mk_func_app("s_Bool_val", &[enc]); + + let target_bb = self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())); + + self.vcx.alloc(vir::TerminatorStmtData::GotoIf( + self.vcx.alloc(vir::GotoIfData { + value: enc, + targets: self.vcx.alloc_slice(&[( + self.vcx.alloc(vir::ExprData::Const( + self.vcx.alloc(vir::ConstData::Bool(*expected)), + )), + &target_bb, + )]), + otherwise: self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(otherwise.as_usize())), + }), )) } - unsupported_kind => self.vcx.alloc(vir::TerminatorStmtData::Dummy( - vir::vir_format!(self.vcx, "terminator {unsupported_kind:?}"), - )), + unsupported_kind => self + .vcx + .alloc(vir::TerminatorStmtData::Dummy(vir::vir_format!( + self.vcx, + "terminator {unsupported_kind:?}" + ))), }; assert!(self.current_terminator.replace(terminator).is_none()); } diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 0d312e49adc..f26926d107d 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -1,14 +1,12 @@ +use crate::encoders::{TypeEncoder, ViperTupleEncoder}; use prusti_rustc_interface::{ data_structures::graph::dominators::Dominators, middle::{mir, ty}, span::def_id::DefId, type_ir::sty::TyKind, }; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; use std::collections::HashMap; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirPureEncoder; @@ -29,6 +27,142 @@ pub struct MirPureEncoderOutput<'vir> { pub expr: ExprRet<'vir>, } +/// Optimize a vir expresison +/// +/// This is a temporary fix for the issue where variables that are quantified over and are then stored in a let binding +/// cause issues with triggering somehow. +/// +/// This was also intended to make debugging easier by making the resulting viper code a bit more readable +/// +/// This should be replaced with a proper solution +fn opt<'vir, Cur, Next>( + expr: vir::ExprGen<'vir, Cur, Next>, + rename: &mut HashMap, +) -> vir::ExprGen<'vir, Cur, Next> { + match expr { + vir::ExprGenData::Local(d) => { + let nam = rename + .get(d.name) + .map(|e| e.as_str()) + .unwrap_or(d.name) + .to_owned(); + vir::with_vcx(move |vcx| vcx.mk_local_ex(vcx.alloc_str(&nam))) + } + + vir::ExprGenData::Let(vir::LetGenData { name, val, expr }) => { + let val = opt(val, rename); + + match val { + // let name = loc.name + vir::ExprGenData::Local(loc) => { + let t = rename + .get(loc.name) + .map(|e| e.to_owned()) + .unwrap_or(loc.name.to_string()); + assert!(rename.insert(name.to_string(), t).is_none()); + return opt(expr, rename); + } + _ => {} + } + + let expr = opt(expr, rename); + + match expr { + vir::ExprGenData::Local(inner_local) => { + if &inner_local.name == name { + return val; + } + } + _ => {} + } + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::Let(vcx.alloc(vir::LetGenData { + name, + val, + expr, + }))) + }) + } + vir::ExprGenData::FuncApp(vir::FuncAppGenData { target, args }) => { + let n_args = args.iter().map(|arg| opt(arg, rename)).collect::>(); + vir::with_vcx(move |vcx| vcx.mk_func_app(target, &n_args)) + } + + vir::ExprGenData::PredicateApp(vir::PredicateAppGenData { target, args }) => { + let n_args = args.iter().map(|arg| opt(arg, rename)).collect::>(); + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::PredicateApp(vcx.alloc( + vir::PredicateAppGenData { + target, + args: vcx.alloc_slice(&n_args), + }, + ))) + }) + } + + vir::ExprGenData::Forall(vir::ForallGenData { + qvars, + triggers, + body, + }) => { + let body = opt(body, rename); + + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::Forall(vcx.alloc(vir::ForallGenData { + qvars, + triggers, + body, + }))) + }) + } + + vir::ExprGenData::Ternary(vir::TernaryGenData { cond, then, else_ }) => { + let cond = opt(cond, rename); + let then = opt(then, rename); + let else_ = opt(else_, rename); + + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::Ternary(vcx.alloc(vir::TernaryGenData { + cond, + then, + else_, + }))) + }) + } + + vir::ExprGenData::BinOp(vir::BinOpGenData { kind, lhs, rhs }) => { + let lhs = opt(lhs, rename); + let rhs = opt(rhs, rename); + + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::BinOp(vcx.alloc(vir::BinOpGenData { + kind: kind.clone(), + lhs, + rhs, + }))) + }) + } + + vir::ExprGenData::UnOp(vir::UnOpGenData { kind, expr }) => { + let expr = opt(expr, rename); + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::UnOp(vcx.alloc(vir::UnOpGenData { + kind: kind.clone(), + expr, + }))) + }) + } + + todo @ (vir::ExprGenData::Unfolding(_) + | vir::ExprGenData::Field(_, _) + | vir::ExprGenData::Old(_) + | vir::ExprGenData::AccField(_) + | vir::ExprGenData::Lazy(_, _)) => todo, + + other @ (vir::ExprGenData::Const(_) | vir::ExprGenData::Todo(_)) => other, + } +} + use std::cell::RefCell; thread_local! { static CACHE: task_encoder::CacheStaticRef = RefCell::new(Default::default()); @@ -39,9 +173,9 @@ pub struct MirPureEncoderTask<'tcx> { // TODO: depth of encoding should be in the lazy context rather than here; // can we integrate the lazy context into the identifier system? pub encoding_depth: usize, - pub parent_def_id: DefId, // ID of the function - pub promoted: Option, // ID of a constant within the function - pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site + pub parent_def_id: DefId, // ID of the function + pub promoted: Option, // ID of a constant within the function + pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site pub substs: ty::GenericArgsRef<'tcx>, // type substitutions at the usage site } @@ -49,8 +183,8 @@ impl TaskEncoder for MirPureEncoder { type TaskDescription<'vir> = MirPureEncoderTask<'vir>; type TaskKey<'vir> = ( - usize, // encoding depth - DefId, // ID of the function + usize, // encoding depth + DefId, // ID of the function Option, // ID of a constant within the function, or `None` if encoding the function itself ty::GenericArgsRef<'vir>, // ? this should be the "signature", after applying the env/substs ); @@ -60,7 +194,8 @@ impl TaskEncoder for MirPureEncoder { type EncodingError = MirPureEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirPureEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirPureEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -84,13 +219,16 @@ impl TaskEncoder for MirPureEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { deps.emit_output_ref::(*task_key, ()); let def_id = task_key.1; //.parent_def_id; @@ -99,7 +237,10 @@ impl TaskEncoder for MirPureEncoder { tracing::debug!("encoding {def_id:?}"); let expr = vir::with_vcx(move |vcx| { //let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); let expr_inner = Encoder::new(vcx, task_key.0, &body, deps).encode_body(); @@ -117,7 +258,19 @@ impl TaskEncoder for MirPureEncoder { assert_eq!(lctx.1.len(), body.arg_count); use vir::Reify; - expr_inner.reify(vcx, lctx) + let expr_inner = expr_inner.reify(vcx, lctx); + + let expr_inner = if true { + tracing::warn!("before opt {expr_inner:?}"); + let mut rename = HashMap::new(); + let opted = opt(expr_inner, &mut rename); + tracing::warn!("after opt {opted:?}"); + opted + } else { + expr_inner + }; + + expr_inner }), )) }); @@ -146,8 +299,16 @@ impl<'vir> Update<'vir> { fn merge(self, newer: Self) -> Self { Self { - binds: self.binds.into_iter().chain(newer.binds.into_iter()).collect(), - versions: self.versions.into_iter().chain(newer.versions.into_iter()).collect(), + binds: self + .binds + .into_iter() + .chain(newer.binds.into_iter()) + .collect(), + versions: self + .versions + .into_iter() + .chain(newer.versions.into_iter()) + .collect(), } } @@ -159,7 +320,8 @@ impl<'vir> Update<'vir> { } struct Encoder<'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { vcx: &'vir vir::VirCtxt<'vir>, encoding_depth: usize, @@ -171,7 +333,8 @@ struct Encoder<'vir, 'enc> } impl<'vir, 'enc> Encoder<'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { fn new( vcx: &'vir vir::VirCtxt<'vir>, @@ -179,7 +342,10 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> body: &'enc mir::Body<'vir>, deps: &'enc mut TaskEncoderDependencies<'vir>, ) -> Self { - assert!(!body.basic_blocks.is_cfg_cyclic(), "MIR pure encoding does not support loops"); + assert!( + !body.basic_blocks.is_cfg_cyclic(), + "MIR pure encoding does not support loops" + ); Self { vcx, @@ -187,31 +353,28 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> body, deps, visited: Default::default(), - version_ctr: (0..body.local_decls.len()).map(|local| (local.into(), 0)).collect(), + version_ctr: (0..body.local_decls.len()) + .map(|local| (local.into(), 0)) + .collect(), phi_ctr: 0, } } - fn mk_local( - &self, - local: mir::Local, - version: usize, - ) -> &'vir str { - vir::vir_format!(self.vcx, "_{}_{}s_{}", self.encoding_depth, local.as_usize(), version) + fn mk_local(&self, local: mir::Local, version: usize) -> &'vir str { + vir::vir_format!( + self.vcx, + "_{}_{}s_{}", + self.encoding_depth, + local.as_usize(), + version + ) } - fn mk_local_ex( - &self, - local: mir::Local, - version: usize, - ) -> ExprRet<'vir> { + fn mk_local_ex(&self, local: mir::Local, version: usize) -> ExprRet<'vir> { self.vcx.mk_local_ex(self.mk_local(local, version)) } - fn mk_phi( - &self, - idx: usize, - ) -> &'vir str { + fn mk_phi(&self, idx: usize) -> &'vir str { vir::vir_format!(self.vcx, "_{}_phi_{}", self.encoding_depth, idx) } @@ -224,36 +387,34 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> tuple_ref.mk_elem(self.vcx, self.vcx.mk_local_ex(self.mk_phi(idx)), elem_idx) } - fn bump_version( - &mut self, - update: &mut Update<'vir>, - local: mir::Local, - expr: ExprRet<'vir>, - ) { + fn bump_version(&mut self, update: &mut Update<'vir>, local: mir::Local, expr: ExprRet<'vir>) { let new_version = self.version_ctr.get(&local).copied().unwrap_or(0usize); self.version_ctr.insert(local, new_version + 1); - update.binds.push(UpdateBind::Local(local, new_version, expr)); + update + .binds + .push(UpdateBind::Local(local, new_version, expr)); update.versions.insert(local, new_version); } - fn reify_binds( - &self, - update: Update<'vir>, - expr: ExprRet<'vir>, - ) -> ExprRet<'vir> { - update.binds.iter() - .rfold(expr, |expr, bind| match bind { - UpdateBind::Local(local, ver, val) => self.vcx.alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { - name: self.mk_local(*local, *ver), - val, - expr, - }))), - UpdateBind::Phi(idx, val) => self.vcx.alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { - name: self.mk_phi(*idx), - val, - expr, - }))), - }) + fn reify_binds(&self, update: Update<'vir>, expr: ExprRet<'vir>) -> ExprRet<'vir> { + update.binds.iter().rfold(expr, |expr, bind| match bind { + UpdateBind::Local(local, ver, val) => { + self.vcx + .alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { + name: self.mk_local(*local, *ver), + val, + expr, + }))) + } + UpdateBind::Phi(idx, val) => { + self.vcx + .alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { + name: self.mk_phi(*idx), + val, + expr, + }))) + } + }) } fn reify_branch( @@ -263,30 +424,40 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> curr_ver: &HashMap, update: Update<'vir>, ) -> ExprRet<'vir> { - let tuple_args = mod_locals.iter().map(|local| self.mk_local_ex( - *local, - update.versions.get(local).copied().unwrap_or_else(|| { - // TODO: remove (debug) - if !curr_ver.contains_key(&local) { - println!("unknown version of local! {}", local.as_usize()); - return 0xff - } - curr_ver[local] - }), - )).collect::>(); - self.reify_binds( - update, - tuple_ref.mk_cons(self.vcx, &tuple_args), - ) + let tuple_args = mod_locals + .iter() + .map(|local| { + self.mk_local_ex( + *local, + update.versions.get(local).copied().unwrap_or_else(|| { + // TODO: remove (debug) + if !curr_ver.contains_key(&local) { + tracing::error!("unknown version of local! {}", local.as_usize()); + return 0xff; + } + curr_ver[local] + }), + ) + }) + .collect::>(); + self.reify_binds(update, tuple_ref.mk_cons(self.vcx, &tuple_args)) } fn encode_body(&mut self) -> ExprRet<'vir> { - let end_blocks = self.body.basic_blocks.reverse_postorder() + let end_blocks = self + .body + .basic_blocks + .reverse_postorder() .iter() - .filter(|bb| matches!( - self.body[**bb].terminator, - Some(mir::Terminator { kind: mir::TerminatorKind::Return, .. }), - )) + .filter(|bb| { + matches!( + self.body[**bb].terminator, + Some(mir::Terminator { + kind: mir::TerminatorKind::Return, + .. + }), + ) + }) .collect::>(); assert!(end_blocks.len() > 0, "no Return block found"); assert!(end_blocks.len() < 2, "multiple Return blocks found"); @@ -299,19 +470,16 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> vir::vir_format!(self.vcx, "pure in _{local}"), Box::new(move |_vcx, lctx: ExprInput<'vir>| lctx.1[local - 1]), )); - init.binds.push(UpdateBind::Local(local.into(), 0, local_ex)); + init.binds + .push(UpdateBind::Local(local.into(), 0, local_ex)); init.versions.insert(local.into(), 0); } - let update = self.encode_cfg( - &init.versions, - mir::START_BLOCK, - *end_block, - ); + let update = self.encode_cfg(&init.versions, mir::START_BLOCK, *end_block); let res = init.merge(update); let ret_version = res.versions.get(&mir::RETURN_PLACE).copied().unwrap_or(0); - + self.reify_binds(res, self.mk_local_ex(mir::RETURN_PLACE, ret_version)) } @@ -321,7 +489,8 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> start: mir::BasicBlock, end: mir::BasicBlock, ) -> mir::BasicBlock { - dominators.dominators(end) + dominators + .dominators(end) .take_while(|bb| *bb != start) .last() .unwrap() @@ -337,7 +506,9 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> // walk block statements first let mut new_curr_ver = curr_ver.clone(); - let stmt_update = self.body[start].statements.iter() + let stmt_update = self.body[start] + .statements + .iter() .fold(Update::new(), |update, stmt| { let newer = self.encode_stmt(&new_curr_ver, stmt); newer.add_to_map(&mut new_curr_ver); @@ -360,22 +531,26 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> mir::TerminatorKind::SwitchInt { discr, targets } => { // encode the discriminant operand let discr_expr = self.encode_operand(&new_curr_ver, discr); - let discr_ty_out = self.deps.require_ref::( - discr.ty(self.body, self.vcx.tcx), - ).unwrap(); + let discr_ty_out = self + .deps + .require_ref::(discr.ty(self.body, self.vcx.tcx)) + .unwrap(); // find earliest join point `join` let join = self.find_join_point(dominators, start, end); // walk `start` -> `targets[i]` -> `join` for each target // TODO: indexvec? - let mut updates = targets.all_targets().iter() + let mut updates = targets + .all_targets() + .iter() .map(|target| self.encode_cfg(&new_curr_ver, *target, join)) .collect::>(); // find locals updated in any of the results, which were also // defined before the branch - let mut mod_locals = updates.iter() + let mut mod_locals = updates + .iter() .map(|update| update.versions.keys()) .flatten() .filter(|local| new_curr_ver.contains_key(&local)) @@ -385,24 +560,33 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> mod_locals.dedup(); // for each branch, create a Viper tuple of the updated locals - let tuple_ref = self.deps.require_ref::( - mod_locals.len(), - ).unwrap(); + let tuple_ref = self + .deps + .require_ref::(mod_locals.len()) + .unwrap(); let otherwise_update = updates.pop().unwrap(); - let phi_expr = targets.iter() - .zip(updates.into_iter()) - .fold( - self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, otherwise_update), - |expr, ((cond_val, target), branch_update)| self.vcx.alloc(ExprRetData::Ternary(self.vcx.alloc(vir::TernaryGenData { - cond: self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: vir::BinOpKind::CmpEq, - lhs: discr_expr, - rhs: discr_ty_out.expr_from_u128(cond_val).lift(), - }))), - then: self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, branch_update), - else_: expr, - }))), - ); + let phi_expr = targets.iter().zip(updates.into_iter()).fold( + self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, otherwise_update), + |expr, ((cond_val, target), branch_update)| { + self.vcx + .alloc(ExprRetData::Ternary(self.vcx.alloc(vir::TernaryGenData { + cond: self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc( + vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs: discr_expr, + rhs: discr_ty_out.expr_from_u128(cond_val).lift(), + }, + ))), + then: self.reify_branch( + &tuple_ref, + &mod_locals, + &new_curr_ver, + branch_update, + ), + else_: expr, + }))) + }, + ); // assign tuple into a `phi` variable let phi_idx = self.phi_ctr; @@ -416,8 +600,7 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> for (elem_idx, local) in mod_locals.iter().enumerate() { let expr = self.mk_phi_acc(tuple_ref.clone(), phi_idx, elem_idx); self.bump_version(&mut phi_update, *local, expr); - // TODO: add to curr_ver here ? - //new_curr_ver.insert(*local, ); + new_curr_ver.insert(*local, phi_update.versions[local]); } // walk `join` -> `end` @@ -449,25 +632,103 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> TyKind::FnDef(def_id, arg_tys) => { // TODO: this attribute extraction should be done elsewhere? let attrs = self.vcx.tcx.get_attrs_unchecked(*def_id); - attrs.iter() + let normal_attrs = attrs + .iter() .filter(|attr| !attr.is_doc_comment()) .map(|attr| attr.get_normal_item()) - .filter(|item| item.path.segments.len() == 2 - && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "builtin") + .collect::>(); + normal_attrs + .iter() + .filter(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "builtin" + }) .for_each(|attr| match &attr.args { prusti_rustc_interface::ast::AttrArgs::Eq( _, prusti_rustc_interface::ast::AttrArgsEq::Hir(lit), ) => { assert!(builtin.is_none(), "multiple prusti::builtin"); - builtin = Some((match lit.symbol.as_str() { - "forall" => PrustiBuiltin::Forall, - _ => panic!("illegal prusti::builtin"), - }, arg_tys)); + builtin = Some(( + match lit.symbol.as_str() { + "forall" => PrustiBuiltin::Forall, + _ => panic!("illegal prusti::builtin"), + }, + arg_tys, + )); } _ => panic!("illegal prusti::builtin"), }); + + let is_pure = normal_attrs.iter().any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + }); + + // TODO: detect snapshot_equality properly + let is_snapshot_eq = self + .vcx + .tcx + .opt_item_name(*def_id) + .map(|e| e.as_str() == "snapshot_equality") + == Some(true) + && self.vcx.tcx.crate_name(def_id.krate).as_str() == "prusti_contracts"; + + let func_call = if is_pure { + assert!(builtin.is_none(), "Function is pure and builtin?"); + let pure_func = self + .deps + .require_ref::(*def_id) + .unwrap() + .function_name; + + let encoded_args = args + .iter() + .map(|oper| self.encode_operand(&new_curr_ver, oper)) + .collect::>(); + + let func_call = self.vcx.mk_func_app(pure_func, &encoded_args); + + Some(func_call) + } else if is_snapshot_eq { + assert!( + builtin.is_none(), + "Function is snapshot_equality and builtin?" + ); + let encoded_args = args + .iter() + .map(|oper| self.encode_operand(&new_curr_ver, oper)) + .collect::>(); + + assert_eq!(encoded_args.len(), 2); + + let eq_expr = self.vcx.alloc(vir::ExprGenData::BinOp(self.vcx.alloc( + vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs: encoded_args[0], + rhs: encoded_args[1], + }, + ))); + + // TODO: type encoder + Some(self.vcx.mk_func_app("s_Bool_cons", &[eq_expr])) + } else { + None + }; + + if let Some(func_call) = func_call { + let mut term_update = Update::new(); + assert!(destination.projection.is_empty()); + self.bump_version(&mut term_update, destination.local, func_call); + term_update.add_to_map(&mut new_curr_ver); + + // walk rest of CFG + let end_update = self.encode_cfg(&new_curr_ver, target.unwrap(), end); + + return stmt_update.merge(term_update).merge(end_update); + } } _ => todo!(), } @@ -487,18 +748,26 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> _ => panic!("illegal prusti::forall"), }; - let qvars = self.vcx.alloc_slice(&qvar_tys.iter() - .enumerate() - .map(|(idx, qvar_ty)| { - let ty_out = self.deps.require_ref::( - qvar_ty, - ).unwrap(); - self.vcx.mk_local_decl( - vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), - ty_out.snapshot, - ) - }) - .collect::>()); + let qvars = self.vcx.alloc_slice( + &qvar_tys + .iter() + .enumerate() + .map(|(idx, qvar_ty)| { + let ty_out = self + .deps + .require_ref::(qvar_ty) + .unwrap(); + self.vcx.mk_local_decl( + vir::vir_format!( + self.vcx, + "qvar_{}_{idx}", + self.encoding_depth + ), + ty_out.snapshot, + ) + }) + .collect::>(), + ); //let qvar_tuple_ref = self.deps.require_ref::( // qvars.len(), //).unwrap(); @@ -518,42 +787,49 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> reify_args.push(unsafe { std::mem::transmute(self.encode_operand(&new_curr_ver, &args[1])) }); - reify_args.extend((0..qvars.len()) - .map(|idx| self.vcx.mk_local_ex( - vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), - ))); + reify_args.extend((0..qvars.len()).map(|idx| { + self.vcx.mk_local_ex(vir::vir_format!( + self.vcx, + "qvar_{}_{idx}", + self.encoding_depth + )) + })); // TODO: recursively invoke MirPure encoder to encode // the body of the closure; pass the closure as the // variable to use, then closure access = tuple access // (then hope to optimise this away later ...?) use vir::Reify; - let body = self.deps.require_local::( - MirPureEncoderTask { + let body = self + .deps + .require_local::(MirPureEncoderTask { encoding_depth: self.encoding_depth + 1, parent_def_id: cl_def_id, promoted: None, param_env: self.vcx.tcx.param_env(cl_def_id), substs: ty::List::identity_for_item(self.vcx.tcx, cl_def_id), - } - ).unwrap().expr - // arguments to the closure are - // - the closure itself - // - the qvars - .reify(self.vcx, ( - cl_def_id, - self.vcx.alloc_slice(&reify_args), - )) + }) + .unwrap() + .expr + // arguments to the closure are + // - the closure itself + // - the qvars + .reify(self.vcx, (cl_def_id, self.vcx.alloc_slice(&reify_args))) .lift(); + // TODO: use type encoder + let body = self.vcx.mk_func_app("s_Bool_val", &[body]); + // TODO: use type encoder let forall = self.vcx.mk_func_app( "s_Bool_cons", - &[self.vcx.alloc(ExprRetData::Forall(self.vcx.alloc(vir::ForallGenData { - qvars, - triggers: &[], // TODO - body, - })))], + &[self.vcx.alloc(ExprRetData::Forall(self.vcx.alloc( + vir::ForallGenData { + qvars, + triggers: &[], // TODO + body, + }, + )))], ); let mut term_update = Update::new(); @@ -566,7 +842,9 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> stmt_update.merge(term_update).merge(end_update) } - None => todo!(), + None => { + todo!("call not supported {func:?}"); + } } } @@ -581,10 +859,15 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> ) -> Update<'vir> { let mut update = Update::new(); match &stmt.kind { - mir::StatementKind::StorageLive(..) - | mir::StatementKind::StorageDead(..) + mir::StatementKind::StorageLive(local) => { + let new_version = self.version_ctr.get(local).copied().unwrap_or(0usize); + self.version_ctr.insert(*local, new_version + 1); + update.versions.insert(*local, new_version); + } + mir::StatementKind::StorageDead(..) | mir::StatementKind::FakeRead(..) - | mir::StatementKind::AscribeUserType(..) => {}, // nop + | mir::StatementKind::AscribeUserType(..) + | mir::StatementKind::PlaceMention(..) => {} // nop mir::StatementKind::Assign(box (dest, rvalue)) => { assert!(dest.projection.is_empty()); let expr = self.encode_rvalue(curr_ver, rvalue); @@ -614,69 +897,112 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> // Len // Cast mir::Rvalue::BinaryOp(op, box (l, r)) => { - let ty_l = self.deps.require_ref::( - l.ty(self.body, self.vcx.tcx), - ).unwrap().to_primitive.unwrap(); - let ty_r = self.deps.require_ref::( - r.ty(self.body, self.vcx.tcx), - ).unwrap().to_primitive.unwrap(); - let ty_rvalue = self.deps.require_ref::( - rvalue.ty(self.body, self.vcx.tcx), - ).unwrap().from_primitive.unwrap(); + let ty_l = self + .deps + .require_ref::(l.ty(self.body, self.vcx.tcx)) + .unwrap() + .to_primitive + .unwrap(); + let ty_r = self + .deps + .require_ref::(r.ty(self.body, self.vcx.tcx)) + .unwrap() + .to_primitive + .unwrap(); + let ty_rvalue = self + .deps + .require_ref::(rvalue.ty(self.body, self.vcx.tcx)) + .unwrap() + .from_primitive + .unwrap(); self.vcx.mk_func_app( ty_rvalue, - &[self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: op.into(), - lhs: self.vcx.mk_func_app( - ty_l, - &[self.encode_operand(curr_ver, l)], - ), - rhs: self.vcx.mk_func_app( - ty_r, - &[self.encode_operand(curr_ver, r)], - ), - })))], + &[self.vcx.alloc(ExprRetData::BinOp( + self.vcx.alloc(vir::BinOpGenData { + kind: op.into(), + lhs: self + .vcx + .mk_func_app(ty_l, &[self.encode_operand(curr_ver, l)]), + rhs: self + .vcx + .mk_func_app(ty_r, &[self.encode_operand(curr_ver, r)]), + }), + ))], ) } // CheckedBinaryOp // NullaryOp mir::Rvalue::UnaryOp(op, expr) => { - let ty_expr = self.deps.require_ref::( - expr.ty(self.body, self.vcx.tcx), - ).unwrap().to_primitive.unwrap(); - let ty_rvalue = self.deps.require_ref::( - rvalue.ty(self.body, self.vcx.tcx), - ).unwrap().from_primitive.unwrap(); + let ty_expr = self + .deps + .require_ref::(expr.ty(self.body, self.vcx.tcx)) + .unwrap() + .to_primitive + .unwrap(); + let ty_rvalue = self + .deps + .require_ref::(rvalue.ty(self.body, self.vcx.tcx)) + .unwrap() + .from_primitive + .unwrap(); self.vcx.mk_func_app( ty_rvalue, - &[self.vcx.alloc(ExprRetData::UnOp(self.vcx.alloc(vir::UnOpGenData { - kind: op.into(), - expr: self.vcx.mk_func_app( - ty_expr, - &[self.encode_operand(curr_ver, expr)], - ), - })))], + &[self.vcx.alloc(ExprRetData::UnOp( + self.vcx.alloc(vir::UnOpGenData { + kind: op.into(), + expr: self + .vcx + .mk_func_app(ty_expr, &[self.encode_operand(curr_ver, expr)]), + }), + ))], ) } // Discriminant mir::Rvalue::Aggregate(box kind, fields) => match kind { mir::AggregateKind::Tuple if fields.len() == 0 => - // TODO: why is this not a constant? - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), + // TODO: why is this not a constant? + { + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Tuple0_cons()" + ))) + } mir::AggregateKind::Closure(..) => { // TODO: only when this is a spec closure? - let tuple_ref = self.deps.require_ref::( - fields.len(), - ).unwrap(); - tuple_ref.mk_cons(self.vcx, &fields.iter() - .map(|field| self.encode_operand(curr_ver, field)) - .collect::>()) + let tuple_ref = self + .deps + .require_ref::(fields.len()) + .unwrap(); + tuple_ref.mk_cons( + self.vcx, + &fields + .iter() + .map(|field| self.encode_operand(curr_ver, field)) + .collect::>(), + ) } - _ => todo!(), + _ => todo!("Unsupported Rvalue::AggregateKind: {kind:?}"), + }, + mir::Rvalue::CheckedBinaryOp(binop, box (l, r)) => { + let binop_function = self + .deps + .require_ref::( + crate::encoders::MirBuiltinEncoderTask::CheckedBinOp( + *binop, + l.ty(self.body, self.vcx.tcx), // TODO: ? + ), + ) + .unwrap() + .name; + self.vcx.mk_func_app( + binop_function, + &[ + self.encode_operand(curr_ver, l), + self.encode_operand(curr_ver, r), + ], + ) } // ShallowInitBox // CopyForDeref @@ -693,35 +1019,46 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> operand: &mir::Operand<'vir>, ) -> ExprRet<'vir> { match operand { - mir::Operand::Copy(place) - | mir::Operand::Move(place) => self.encode_place(curr_ver, place), + mir::Operand::Copy(place) | mir::Operand::Move(place) => { + self.encode_place(curr_ver, place) + } mir::Operand::Constant(box constant) => { // TODO: duplicated from mir_impure! match constant.literal { - mir::ConstantKind::Val(const_val, const_ty) => { - match const_ty.kind() { - ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), - ty::TyKind::Int(int_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Int_{}_cons({})", int_ty.name_str(), scalar_val.try_to_int(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Uint(uint_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Uint_{}_cons({})", uint_ty.name_str(), scalar_val.try_to_uint(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Bool => self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Bool_cons({})", const_val.try_to_bool().unwrap()), - )), - unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + mir::ConstantKind::Val(const_val, const_ty) => match const_ty.kind() { + ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc( + ExprRetData::Todo(vir::vir_format!(self.vcx, "s_Tuple0_cons()")), + ), + ty::TyKind::Int(int_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Int_{}_cons({})", + int_ty.name_str(), + scalar_val.try_to_int(scalar_val.size()).unwrap() + ))) + } + ty::TyKind::Uint(uint_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Uint_{}_cons({})", + uint_ty.name_str(), + scalar_val.try_to_uint(scalar_val.size()).unwrap() + ))) } + ty::TyKind::Bool => self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Bool_cons({})", + const_val.try_to_bool().unwrap() + ))), + unsupported_ty => { + todo!("unsupported constant literal type: {unsupported_ty:?}") + } + }, + unsupported_literal => { + todo!("unsupported constant literal: {unsupported_literal:?}") } - unsupported_literal => todo!("unsupported constant literal: {unsupported_literal:?}"), } } } @@ -734,31 +1071,58 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> ) -> ExprRet<'vir> { // TODO: remove (debug) if !curr_ver.contains_key(&place.local) { - println!("unknown version of local! {}", place.local.as_usize()); - return self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "unknown_version_{}", place.local.as_usize()), - )); + tracing::error!("unknown version of local! {}", place.local.as_usize()); + return self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "unknown_version_{}", + place.local.as_usize() + ))); } let local = self.mk_local_ex(place.local, curr_ver[&place.local]); - if !place.projection.is_empty() { - // TODO: for now, assume this is a closure argument - assert_eq!(place.projection[0], mir::ProjectionElem::Deref); - assert!(matches!(place.projection[1], mir::ProjectionElem::Field(..))); - assert_eq!(place.projection[2], mir::ProjectionElem::Deref); - assert_eq!(place.projection.len(), 3); - let upvars = match self.body.local_decls[place.local].ty.peel_refs().kind() { - TyKind::Closure(_def_id, args) => args.as_closure().upvar_tys().collect::>().len(), - _ => unreachable!(), - }; - let tuple_ref = self.deps.require_ref::( - upvars, - ).unwrap(); - return match place.projection[1] { - mir::ProjectionElem::Field(idx, _) => tuple_ref.mk_elem(self.vcx, local, idx.as_usize()), - _ => todo!(), - }; + let mut partent_ty = self.body.local_decls[place.local].ty; + let mut expr = local; + + for elem in place.projection { + (partent_ty, expr) = self.encode_place_element(partent_ty, elem, expr); + } + + expr + } + + fn encode_place_element( + &mut self, + parent_ty: ty::Ty<'vir>, + elem: mir::PlaceElem<'vir>, + expr: ExprRet<'vir>, + ) -> (ty::Ty<'vir>, ExprRet<'vir>) { + let parent_ty = parent_ty.peel_refs(); + + match elem { + mir::ProjectionElem::Deref => (parent_ty, expr), + mir::ProjectionElem::Field(field_idx, field_ty) => { + let field_idx = field_idx.as_usize(); + match parent_ty.kind() { + TyKind::Closure(_def_id, args) => { + let upvars = args.as_closure().upvar_tys().collect::>().len(); + let tuple_ref = self.deps.require_ref::(upvars).unwrap(); + let tup = tuple_ref.mk_elem(self.vcx, expr, field_idx); + + (field_ty, tup) + } + _ => { + let local_encoded_ty = + self.deps.require_ref::(parent_ty).unwrap(); + let struct_like = local_encoded_ty.expect_structlike(); + let proj = struct_like.field_read[field_idx]; + + let app = self.vcx.mk_func_app(proj, self.vcx.alloc_slice(&[expr])); + + (field_ty, app) + } + } + } + _ => todo!("Unsupported ProjectionElem {:?}", elem), } - local } } diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index 60cab5ac4a9..e73c72894d7 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -1,8 +1,11 @@ -use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; +use prusti_rustc_interface::{ + middle::{mir, ty}, + span::def_id::DefId, +}; +use std::cell::RefCell; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use vir::Reify; -use std::cell::RefCell; use crate::encoders::{ MirPureEncoder, MirPureEncoderTask, SpecEncoder, SpecEncoderTask, TypeEncoder, @@ -16,13 +19,13 @@ pub enum MirFunctionEncoderError { #[derive(Clone, Debug)] pub struct MirFunctionEncoderOutputRef<'vir> { - pub method_name: &'vir str, + pub function_name: &'vir str, } impl<'vir> task_encoder::OutputRefAny<'vir> for MirFunctionEncoderOutputRef<'vir> {} #[derive(Clone, Debug)] pub struct MirFunctionEncoderOutput<'vir> { - pub method: vir::Function<'vir>, + pub function: vir::Function<'vir>, } thread_local! { @@ -72,45 +75,81 @@ impl TaskEncoder for MirFunctionEncoder { tracing::debug!("encoding {def_id:?}"); - let method_name = vir::vir_format!(vcx, "f_{}", vcx.tcx.item_name(def_id)); - deps.emit_output_ref::(def_id, MirFunctionEncoderOutputRef { method_name }); - - let local_defs = deps.require_local::( - def_id, - ).unwrap(); - let spec = deps.require_local::( - (def_id, true) - ).unwrap(); - - let func_args: Vec<_> = (1..=local_defs.arg_count).map(mir::Local::from).map(|arg| vcx.alloc(vir::LocalDeclData { - name: local_defs.locals[arg].local.name, - ty: local_defs.locals[arg].snapshot, - })).collect(); - - // Encode the body of the function - let expr = deps - .require_local::(MirPureEncoderTask { - encoding_depth: 0, - parent_def_id: def_id, - promoted: None, - param_env: vcx.tcx.param_env(def_id), - substs: ty::List::identity_for_item(vcx.tcx, def_id), - }) + let function_name = vir::vir_format!(vcx, "f_{}", vcx.tcx.item_name(def_id)); + deps.emit_output_ref::(*task_key, MirFunctionEncoderOutputRef { function_name }); + + let local_def_id = def_id.expect_local(); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); + + let spec = deps + .require_local::((def_id, true)) + .unwrap(); + + let mut func_args = Vec::with_capacity(body.arg_count); + + for (arg_idx0, arg_local) in body.args_iter().enumerate() { + let arg_idx = arg_idx0 + 1; // enumerate is 0 based but we want to start at 1 + + let arg_decl = body.local_decls.get(arg_local).unwrap(); + let arg_type_ref = deps.require_ref::(arg_decl.ty).unwrap(); + + let name_p = vir::vir_format!(vcx, "_{arg_idx}p"); + func_args.push(vcx.alloc(vir::LocalDeclData { + name: name_p, + ty: arg_type_ref.snapshot, + })); + } + + // TODO: dedup with mir_pure + let attrs = vcx.tcx.get_attrs_unchecked(def_id); + let is_trusted = attrs + .iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()) + .any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "trusted" + }); + + let expr = if is_trusted { + None + } else { + // Encode the body of the function + let expr = deps + .require_local::(MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: def_id, + promoted: None, + param_env: vcx.tcx.param_env(def_id), + substs: ty::List::identity_for_item(vcx.tcx, def_id), + }) + .unwrap() + .expr; + + Some(expr.reify(vcx, (def_id, &spec.pre_args.split_last().unwrap().1))) + }; + + // Snapshot type of the return type + let ret = deps + .require_ref::(body.return_ty()) .unwrap() - .expr; - let expr = expr.reify(vcx, (def_id, &spec.pre_args[1..])); + .snapshot; tracing::debug!("finished {def_id:?}"); Ok(( MirFunctionEncoderOutput { - method: vcx.alloc(vir::FunctionData { - name: method_name, + function: vcx.alloc(vir::FunctionData { + name: function_name, args: vcx.alloc_slice(&func_args), - ret: local_defs.locals[mir::RETURN_PLACE].snapshot, + ret, pres: vcx.alloc_slice(&spec.pres), posts: vcx.alloc_slice(&spec.posts), - expr: Some(expr), + expr, }), }, (), diff --git a/prusti-encoder/src/encoders/mod.rs b/prusti-encoder/src/encoders/mod.rs index 0acdcb605aa..6afd9a4fc81 100644 --- a/prusti-encoder/src/encoders/mod.rs +++ b/prusti-encoder/src/encoders/mod.rs @@ -9,33 +9,13 @@ mod mir_pure_function; pub mod pure; pub mod local_def; -pub use generic::{ - GenericEncoder, -}; -pub use mir_builtin::{ - MirBuiltinEncoder, - MirBuiltinEncoderTask, -}; +pub use generic::GenericEncoder; +pub use mir_builtin::{MirBuiltinEncoder, MirBuiltinEncoderTask}; pub use mir_impure::MirImpureEncoder; -pub use mir_pure::{ - MirPureEncoder, - MirPureEncoderTask, -}; -pub use spec::{ - SpecEncoder, - SpecEncoderOutput, - SpecEncoderTask, -}; +pub use mir_pure::{MirPureEncoder, MirPureEncoderTask}; pub(super) use spec::{init_def_spec, with_def_spec}; -pub use typ::{ - TypeEncoder, - TypeEncoderOutputRef, - TypeEncoderOutput, -}; -pub use viper_tuple::{ - ViperTupleEncoder, - ViperTupleEncoderOutputRef, - ViperTupleEncoderOutput, -}; +pub use spec::{SpecEncoder, SpecEncoderOutput, SpecEncoderTask}; +pub use typ::{TypeEncoder, TypeEncoderOutput, TypeEncoderOutputRef}; +pub use viper_tuple::{ViperTupleEncoder, ViperTupleEncoderOutput, ViperTupleEncoderOutputRef}; pub use mir_pure_function::MirFunctionEncoder; diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 462783f8f74..e59500b51cf 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -1,8 +1,11 @@ -use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; +use prusti_rustc_interface::{ + middle::{mir, ty}, + span::def_id::DefId, +}; +use std::cell::RefCell; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use vir::Reify; -use std::cell::RefCell; use crate::encoders::MirPureEncoder; pub struct MirSpecEncoder; @@ -20,8 +23,8 @@ thread_local! { impl TaskEncoder for MirSpecEncoder { type TaskDescription<'vir> = ( - DefId, // The function annotated with specs - bool, // If to encode as pure or not + DefId, // The function annotated with specs + bool, // If to encode as pure or not ); type OutputFullLocal<'vir> = MirSpecEncoderOutput<'vir>; @@ -61,81 +64,98 @@ impl TaskEncoder for MirSpecEncoder { let (def_id, pure) = *task_key; deps.emit_output_ref::(*task_key, ()); - let local_defs = deps.require_local::( - def_id, - ).unwrap(); - let specs = deps.require_local::( - crate::encoders::SpecEncoderTask { + let local_defs = deps + .require_local::(def_id) + .unwrap(); + let specs = deps + .require_local::(crate::encoders::SpecEncoderTask { def_id, - } - ).unwrap(); + }) + .unwrap(); vir::with_vcx(|vcx| { - let pre_args: Vec<_> = (0..=local_defs.arg_count) + let mut pre_args: Vec<_> = (1..=local_defs.arg_count) .map(mir::Local::from) .map(|local| { if pure { - if local.index() == 0 { - vcx.mk_local_ex(vir::vir_format!(vcx, "result")) - } else { - local_defs.locals[local].local_ex - } + local_defs.locals[local].local_ex } else { local_defs.locals[local].impure_snap } }) .collect(); + + pre_args.push(if pure { + vcx.mk_local_ex(vir::vir_format!(vcx, "result")) + } else { + local_defs.locals[mir::Local::from(0u32)].impure_snap + }); + let pre_args = vcx.alloc_slice(&pre_args); - let to_bool = deps.require_ref::( - vcx.tcx.types.bool, - ).unwrap().to_primitive.unwrap(); - - let pres = specs.pres.iter().map(|spec_def_id| { - let expr = deps.require_local::( - crate::encoders::MirPureEncoderTask { - encoding_depth: 0, - parent_def_id: *spec_def_id, - promoted: None, - param_env: vcx.tcx.param_env(spec_def_id), - substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), - } - ).unwrap().expr; - let expr = vcx.mk_func_app( - to_bool, - &[expr], - ); - expr.reify(vcx, (*spec_def_id, &pre_args[1..])) - }).collect::>>(); + let to_bool = deps + .require_ref::(vcx.tcx.types.bool) + .unwrap() + .to_primitive + .unwrap(); + + let pres = specs + .pres + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: *spec_def_id, + promoted: None, + param_env: vcx.tcx.param_env(spec_def_id), + substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), + }, + ) + .unwrap() + .expr; + let expr = vcx.mk_func_app(to_bool, &[expr]); + expr.reify(vcx, (*spec_def_id, &pre_args.split_last().unwrap().1)) + }) + .collect::>>(); let post_args = if pure { pre_args } else { - let post_args: Vec<_> = pre_args.iter().enumerate().map(|(idx, arg)| { - if idx == 0 { - arg - } else { - vcx.alloc(vir::ExprData::Old(arg)) - } - }).collect(); + let post_args: Vec<_> = pre_args + .iter() + .enumerate() + .map(|(idx, arg)| { + if idx == pre_args.len() - 1 { + arg + } else { + vcx.alloc(vir::ExprData::Old(arg)) + } + }) + .collect(); vcx.alloc_slice(&post_args) }; - let posts = specs.posts.iter().map(|spec_def_id| { - let expr = deps.require_local::( - crate::encoders::MirPureEncoderTask { - encoding_depth: 0, - parent_def_id: *spec_def_id, - promoted: None, - param_env: vcx.tcx.param_env(spec_def_id), - substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), - } - ).unwrap().expr; - let expr = vcx.mk_func_app( - to_bool, - &[expr], - ); - expr.reify(vcx, (*spec_def_id, post_args)) - }).collect::>>(); + let posts = specs + .posts + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: *spec_def_id, + promoted: None, + param_env: vcx.tcx.param_env(spec_def_id), + substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), + }, + ) + .unwrap() + .expr; + let expr = vcx.mk_func_app(to_bool, &[expr]); + expr.reify(vcx, (*spec_def_id, post_args)) + }) + .collect::>>(); let data = MirSpecEncoderOutput { pres, posts, diff --git a/prusti-encoder/src/encoders/spec.rs b/prusti-encoder/src/encoders/spec.rs index c6da14c0db3..a36f8237284 100644 --- a/prusti-encoder/src/encoders/spec.rs +++ b/prusti-encoder/src/encoders/spec.rs @@ -1,12 +1,6 @@ -use prusti_rustc_interface::{ - //middle::{mir, ty}, - span::def_id::DefId, -}; use prusti_interface::specs::typed::DefSpecificationMap; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use prusti_rustc_interface::span::def_id::DefId; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct SpecEncoder; @@ -42,7 +36,7 @@ pub fn init_def_spec(def_spec: DefSpecificationMap) { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct SpecEncoderTask { pub def_id: DefId, // ID of the function - // TODO: substs here? + // TODO: substs here? } impl TaskEncoder for SpecEncoder { @@ -57,7 +51,8 @@ impl TaskEncoder for SpecEncoder { type EncodingError = SpecEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, SpecEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, SpecEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -78,13 +73,16 @@ impl TaskEncoder for SpecEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { deps.emit_output_ref::(task_key.clone(), ()); vir::with_vcx(|vcx| { with_def_spec(|def_spec| { @@ -98,7 +96,7 @@ impl TaskEncoder for SpecEncoder { .and_then(|specs| specs.base_spec.posts.expect_empty_or_inherent()) .map(|specs| vcx.alloc_slice(specs)) .unwrap_or_default(); - Ok((SpecEncoderOutput { pres, posts, }, () )) + Ok((SpecEncoderOutput { pres, posts }, ())) }) }) } diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index cc240f5a7dd..99423544bf4 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -1,9 +1,6 @@ use prusti_rustc_interface::middle::ty; use rustc_type_ir::sty::TyKind; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct TypeEncoder; @@ -54,18 +51,22 @@ impl<'vir> TypeEncoderOutputRef<'vir> { // TODO: not great: store the TyKind as well? // or should this be a different task for TypeEncoder? match self.snapshot_name { - "s_Bool" => vir::with_vcx(|vcx| vcx.mk_func_app( + "s_Bool" => vir::with_vcx(|vcx| { + vcx.mk_func_app( self.from_primitive.unwrap(), &[vcx.alloc(vir::ExprData::Const( vcx.alloc(vir::ConstData::Bool(val != 0)), ))], - )), - name if name.starts_with("s_Int_") || name.starts_with("s_Uint_") => vir::with_vcx(|vcx| vcx.mk_func_app( - self.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::Const( - vcx.alloc(vir::ConstData::Int(val)), - ))], - )), + ) + }), + name if name.starts_with("s_Int_") || name.starts_with("s_Uint_") => { + vir::with_vcx(|vcx| { + vcx.mk_func_app( + self.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::Const(vcx.alloc(vir::ConstData::Int(val))))], + ) + }) + } k => todo!("unsupported type in expr_from_u128 {k:?}"), } } @@ -99,7 +100,8 @@ impl TaskEncoder for TypeEncoder { type EncodingError = TypeEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, TypeEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, TypeEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -146,13 +148,16 @@ impl TaskEncoder for TypeEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { fn mk_unreachable<'vir>( vcx: &'vir vir::VirCtxt, snapshot_name: &'vir str, @@ -173,8 +178,8 @@ impl TaskEncoder for TypeEncoder { field_name: &'vir str, ) -> vir::Predicate<'vir> { let predicate_body = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldData { - recv: vcx.mk_local_ex("self_p"), - field: field_name, + recv: vcx.mk_local_ex("self_p"), + field: field_name, }))); vir::vir_predicate! { vcx; predicate [predicate_name](self_p: Ref) { [predicate_body] } } } @@ -269,27 +274,20 @@ impl TaskEncoder for TypeEncoder { ) -> vir::Function<'vir> { let pred_app = vcx.alloc(vir::PredicateAppData { target: predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("self"), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_ex("self")]), }); vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{predicate_name}_snap"), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), ret: snapshot_ty, - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(pred_app)), - ]), + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: field_name.map(|field_name| vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.alloc(vir::ExprData::Field( - vcx.mk_local_ex("self"), - field_name, - )), - })))), + expr: field_name.map(|field_name| { + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.alloc(vir::ExprData::Field(vcx.mk_local_ex("self"), field_name)), + }))) + }), }) } fn mk_structlike<'vir>( @@ -299,10 +297,13 @@ impl TaskEncoder for TypeEncoder { name_s: &'vir str, name_p: &'vir str, field_ty_out: Vec>, - ) -> Result<::OutputFullLocal<'vir>, ( - ::EncodingError, - Option<::OutputFullDependency<'vir>>, - )> { + ) -> Result< + ::OutputFullLocal<'vir>, + ( + ::EncodingError, + Option<::OutputFullDependency<'vir>>, + ), + > { let mut field_read_names = Vec::new(); let mut field_write_names = Vec::new(); let mut field_projection_p_names = Vec::new(); @@ -315,50 +316,58 @@ impl TaskEncoder for TypeEncoder { let field_write_names = vcx.alloc_slice(&field_write_names); let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: name_s, - predicate_name: name_p, - from_primitive: None, - to_primitive: None, - snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), - function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), - function_snap: vir::vir_format!(vcx, "{name_p}_snap"), - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { - field_read: field_read_names, - field_write: field_write_names, - field_projection_p: field_projection_p_names, - }), - method_assign: vir::vir_format!(vcx, "assign_{name_p}"), - }); + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + from_primitive: None, + to_primitive: None, + snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { + field_read: field_read_names, + field_write: field_write_names, + field_projection_p: field_projection_p_names, + }), + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }, + ); let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); let mut funcs: Vec> = vec![]; let mut axioms: Vec> = vec![]; let cons_name = vir::vir_format!(vcx, "{name_s}_cons"); - funcs.push(vcx.alloc(vir::DomainFunctionData { - unique: false, - name: cons_name, - args: vcx.alloc_slice(&field_ty_out.iter() - .map(|field_ty_out| field_ty_out.snapshot) - .collect::>()), - ret: ty_s, - })); + funcs.push( + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: cons_name, + args: vcx.alloc_slice( + &field_ty_out + .iter() + .map(|field_ty_out| field_ty_out.snapshot) + .collect::>(), + ), + ret: ty_s, + }), + ); let mut field_projection_p = Vec::new(); for (idx, ty_out) in field_ty_out.iter().enumerate() { let name_r = vir::vir_format!(vcx, "{name_s}_read_{idx}"); - funcs.push(vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }); + funcs.push( + vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }, + ); let name_w = vir::vir_format!(vcx, "{name_s}_write_{idx}"); funcs.push(vir::vir_domain_func! { vcx; function [name_w]([ty_s], [ty_out.snapshot]): [ty_s] }); field_projection_p.push(vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_field_{idx}"), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), ret: &vir::TypeData::Ref, pres: &[], posts: &[], @@ -370,35 +379,30 @@ impl TaskEncoder for TypeEncoder { for (write_idx, write_ty_out) in field_ty_out.iter().enumerate() { for (read_idx, _read_ty_out) in field_ty_out.iter().enumerate() { axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!(vcx, "ax_{name_s}_write_{write_idx}_read_{read_idx}"), + name: vir::vir_format!( + vcx, + "ax_{name_s}_write_{write_idx}_read_{read_idx}" + ), expr: if read_idx == write_idx { vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: vcx.alloc_slice(&[ vcx.mk_local_decl("self", ty_s), vcx.mk_local_decl("val", write_ty_out.snapshot), ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], - )], - ), - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), &[vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], )], ), rhs: vcx.mk_local_ex("val"), @@ -410,28 +414,20 @@ impl TaskEncoder for TypeEncoder { vcx.mk_local_decl("self", ty_s), vcx.mk_local_decl("val", write_ty_out.snapshot), ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], - )], - ), - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), &[vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], )], ), rhs: vcx.mk_func_app( @@ -448,18 +444,21 @@ impl TaskEncoder for TypeEncoder { // constructor { let cons_qvars = vcx.alloc_slice( - &field_ty_out.iter() + &field_ty_out + .iter() .enumerate() - .map(|(idx, field_ty_out)| vcx.mk_local_decl( - vir::vir_format!(vcx, "f{idx}"), - field_ty_out.snapshot, - )) - .collect::>()); - let cons_args = field_ty_out.iter() + .map(|(idx, field_ty_out)| { + vcx.mk_local_decl( + vir::vir_format!(vcx, "f{idx}"), + field_ty_out.snapshot, + ) + }) + .collect::>(), + ); + let cons_args = field_ty_out + .iter() .enumerate() - .map(|(idx, _field_ty_out)| vcx.mk_local_ex( - vir::vir_format!(vcx, "f{idx}"), - )) + .map(|(idx, _field_ty_out)| vcx.mk_local_ex(vir::vir_format!(vcx, "f{idx}"))) .collect::>(); let cons_call = vcx.mk_func_app(cons_name, &cons_args); @@ -468,12 +467,7 @@ impl TaskEncoder for TypeEncoder { name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: cons_qvars.clone(), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[cons_call], - ), - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: vcx.mk_func_app( @@ -491,21 +485,19 @@ impl TaskEncoder for TypeEncoder { &field_ty_out .iter() .enumerate() - .map(|(idx, _field_ty_out)| vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{idx}"), - &[vcx.mk_local_ex("self")], - )) + .map(|(idx, _field_ty_out)| { + vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{idx}"), + &[vcx.mk_local_ex("self")], + ) + }) .collect::>(), ); axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - cons_call_with_reads, - ])]), + qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: cons_call_with_reads, @@ -517,28 +509,31 @@ impl TaskEncoder for TypeEncoder { // predicate let predicate = { - let expr = field_ty_out.iter() + let expr = field_ty_out + .iter() .enumerate() - .map(|(idx, field_ty_out)| vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { - target: field_ty_out.predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_p}_field_{idx}"), - &[vcx.mk_local_ex("self_p")], - ), - ]), - })))) - .reduce(|base, field_expr| vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::And, - lhs: base, - rhs: field_expr, - })))) + .map(|(idx, field_ty_out)| { + vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc( + vir::PredicateAppData { + target: field_ty_out.predicate_name, + args: vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_p}_field_{idx}"), + &[vcx.mk_local_ex("self_p")], + )]), + }, + ))) + }) + .reduce(|base, field_expr| { + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::And, + lhs: base, + rhs: field_expr, + }))) + }) .unwrap_or_else(|| vcx.mk_true()); vcx.alloc(vir::PredicateData { name: name_p, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self_p", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), expr: Some(expr), }) }; @@ -554,40 +549,42 @@ impl TaskEncoder for TypeEncoder { function_snap: { let pred_app = vcx.alloc(vir::PredicateAppData { target: name_p, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("self_p"), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), }); vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_snap"), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self_p", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), ret: ty_s, - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(pred_app)), - ]), + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: Some(vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.mk_func_app( - cons_name, - vcx.alloc_slice(&field_ty_out - .iter() - .enumerate() - .map(|(idx, field_ty_out)| vcx.mk_func_app( - field_ty_out.function_snap, - &[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_p}_field_{idx}"), - &[vcx.mk_local_ex("self_p")], - ), - ], - )) - .collect::>(), - ), - ), - })))), + expr: Some( + vcx.alloc(vir::ExprData::Unfolding( + vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.mk_func_app( + cons_name, + vcx.alloc_slice( + &field_ty_out + .iter() + .enumerate() + .map(|(idx, field_ty_out)| { + vcx.mk_func_app( + field_ty_out.function_snap, + &[vcx.mk_func_app( + vir::vir_format!( + vcx, + "{name_p}_field_{idx}" + ), + &[vcx.mk_local_ex("self_p")], + )], + ) + }) + .collect::>(), + ), + ), + }), + )), + ), }) }, //method_refold: mk_refold(vcx, name_p, ty_s), @@ -599,38 +596,44 @@ impl TaskEncoder for TypeEncoder { vir::with_vcx(|vcx| match task_key.kind() { TyKind::Bool => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Bool")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: "s_Bool", - predicate_name: "p_Bool", - to_primitive: Some("s_Bool_val"), - from_primitive: Some("s_Bool_cons"), - snapshot: ty_s, - function_unreachable: "s_Bool_unreachable", - function_snap: "p_Bool_snap", - //method_refold: "refold_p_Bool", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: "assign_p_Bool", - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: "f_Bool", - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Bool { - function s_Bool_cons(Bool): s_Bool; - function s_Bool_val(s_Bool): Bool; - axiom_inverse(s_Bool_val, s_Bool_cons, Bool); - } }, - predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), - function_unreachable: mk_unreachable(vcx, "s_Bool", ty_s), - function_snap: mk_snap(vcx, "p_Bool", "s_Bool", Some("f_Bool"), ty_s), - //method_refold: mk_refold(vcx, "p_Bool", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Bool", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: "s_Bool", + predicate_name: "p_Bool", + to_primitive: Some("s_Bool_val"), + from_primitive: Some("s_Bool_cons"), + snapshot: ty_s, + function_unreachable: "s_Bool_unreachable", + function_snap: "p_Bool_snap", + //method_refold: "refold_p_Bool", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: "assign_p_Bool", + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: "f_Bool", + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Bool { + function s_Bool_cons(Bool): s_Bool; + function s_Bool_val(s_Bool): Bool; + axiom_inverse(s_Bool_val, s_Bool_cons, Bool); + axiom_inverse(s_Bool_cons, s_Bool_val, s_Bool); + } }, + predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), + function_unreachable: mk_unreachable(vcx, "s_Bool", ty_s), + function_snap: mk_snap(vcx, "p_Bool", "s_Bool", Some("f_Bool"), ty_s), + //method_refold: mk_refold(vcx, "p_Bool", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Bool", ty_s), + }, + (), + )) } - TyKind::Int(_) | - TyKind::Uint(_) => { + TyKind::Int(_) | TyKind::Uint(_) => { let (sign, name_str) = match task_key.kind() { TyKind::Int(kind) => ("Int", kind.name_str()), TyKind::Uint(kind) => ("Uint", kind.name_str()), @@ -642,83 +645,110 @@ impl TaskEncoder for TypeEncoder { let name_val = vir::vir_format!(vcx, "{name_s}_val"); let name_field = vir::vir_format!(vcx, "f_{name_s}"); let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: name_s, - predicate_name: name_p, - to_primitive: Some(name_val), - from_primitive: Some(name_cons), - snapshot: ty_s, - function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), - function_snap: vir::vir_format!(vcx, "{name_p}_snap"), - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_{name_p}"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: name_field, - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain [name_s] { - function [name_cons](Int): [ty_s]; - function [name_val]([ty_s]): Int; - axiom_inverse([name_val], [name_cons], Int); - } }, - predicate: mk_simple_predicate(vcx, name_p, name_field), - function_unreachable: mk_unreachable(vcx, name_s, ty_s), - function_snap: mk_snap(vcx, name_p, name_s, Some(name_field), ty_s), - //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, name_p, ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + to_primitive: Some(name_val), + from_primitive: Some(name_cons), + snapshot: ty_s, + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: name_field, + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain [name_s] { + function [name_cons](Int): [ty_s]; + function [name_val]([ty_s]): Int; + axiom_inverse([name_val], [name_cons], Int); + axiom_inverse([name_cons], [name_val], [ty_s]); + } }, + predicate: mk_simple_predicate(vcx, name_p, name_field), + function_unreachable: mk_unreachable(vcx, name_s, ty_s), + function_snap: mk_snap(vcx, name_p, name_s, Some(name_field), ty_s), + //method_refold: mk_refold(vcx, name_p, ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, name_p, ty_s), + }, + (), + )) } TyKind::Tuple(tys) if tys.len() == 0 => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Tuple0")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: "s_Tuple0", - predicate_name: "p_Tuple0", - to_primitive: None, - from_primitive: None, - snapshot: ty_s, - function_unreachable: "s_Tuple0_unreachable", - function_snap: "p_Tuple0_snap", - //method_refold: "refold_p_Tuple0", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Tuple0"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: vir::vir_format!(vcx, "f_Tuple0"), - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Tuple0 { - function s_Tuple0_cons(): [ty_s]; - } }, - predicate: vir::vir_predicate! { vcx; predicate p_Tuple0(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "s_Tuple0", ty_s), - function_snap: mk_snap(vcx, "p_Tuple0", "s_Tuple0", None, ty_s), - //method_refold: mk_refold(vcx, "p_Tuple0", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Tuple0", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: "s_Tuple0", + predicate_name: "p_Tuple0", + to_primitive: None, + from_primitive: None, + snapshot: ty_s, + function_unreachable: "s_Tuple0_unreachable", + function_snap: "p_Tuple0_snap", + //method_refold: "refold_p_Tuple0", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_p_Tuple0"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: vir::vir_format!(vcx, "f_Tuple0"), + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Tuple0 { + function s_Tuple0_cons(): [ty_s]; + } }, + predicate: vir::vir_predicate! { vcx; predicate p_Tuple0(self_p: Ref) }, + function_unreachable: mk_unreachable(vcx, "s_Tuple0", ty_s), + function_snap: mk_snap(vcx, "p_Tuple0", "s_Tuple0", None, ty_s), + //method_refold: mk_refold(vcx, "p_Tuple0", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Tuple0", ty_s), + }, + (), + )) } TyKind::Tuple(tys) => { let field_ty_out = tys .iter() - .map(|ty| deps.require_ref::(ty).unwrap()) + .map(|ty| { + deps.require_ref::(ty) + .unwrap() + }) .collect::>(); - // TODO: name the tuple according to its types, or make generic? - Ok((mk_structlike( - vcx, - deps, - task_key, - vir::vir_format!(vcx, "s_Tuple{}", tys.len()), - vir::vir_format!(vcx, "p_Tuple{}", tys.len()), - field_ty_out, - )?, ())) + // TODO: Properly name the tuple according to its types, or make generic? + + // Temporary fix to make it possisble to have multiple tuples of the same size with different element types + let tmp_ty_name = field_ty_out + .iter() + .map(|e| e.snapshot_name) + .collect::>() + .join("_"); + + Ok(( + mk_structlike( + vcx, + deps, + task_key, + vir::vir_format!(vcx, "s_Tuple{}_{}", tys.len(), tmp_ty_name), + vir::vir_format!(vcx, "p_Tuple{}_{}", tys.len(), tmp_ty_name), + field_ty_out, + )?, + (), + )) /* let ty_len = tys.len(); @@ -755,76 +785,97 @@ impl TaskEncoder for TypeEncoder { } TyKind::Param(_param) => { - let param_out = deps.require_ref::(()).unwrap(); + let param_out = deps + .require_ref::(()) + .unwrap(); let ty_s = vcx.alloc(vir::TypeData::Domain("s_Param")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: param_out.snapshot_param_name, - predicate_name: param_out.predicate_param_name, - to_primitive: None, - from_primitive: None, - snapshot: ty_s, - function_unreachable: "s_Param_unreachable", - function_snap: "p_Param_snap", - //method_refold: "refold_p_Param", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Bool"), - }); - Ok((TypeEncoderOutput { - fields: &[], - snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors - } }, - predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "p_Param", ty_s), - function_snap: mk_snap(vcx, "p_Param", "s_Param", None, ty_s), - //method_refold: mk_refold(vcx, "p_Param", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Param", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: param_out.snapshot_param_name, + predicate_name: param_out.predicate_param_name, + to_primitive: None, + from_primitive: None, + snapshot: ty_s, + function_unreachable: "s_Param_unreachable", + function_snap: "p_Param_snap", + //method_refold: "refold_p_Param", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_p_Bool"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: &[], + snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors + } }, + predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, + function_unreachable: mk_unreachable(vcx, "p_Param", ty_s), + function_snap: mk_snap(vcx, "p_Param", "s_Param", None, ty_s), + //method_refold: mk_refold(vcx, "p_Param", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Param", ty_s), + }, + (), + )) } TyKind::Adt(adt_def, substs) if adt_def.is_struct() => { - println!("encoding ADT {adt_def:?} with substs {substs:?}"); + tracing::debug!("encoding ADT {adt_def:?} with substs {substs:?}"); let substs = ty::List::identity_for_item(vcx.tcx, adt_def.did()); - let field_ty_out = adt_def.all_fields() - .map(|field| deps.require_ref::(field.ty(vcx.tcx, substs)).unwrap()) + let field_ty_out = adt_def + .all_fields() + .map(|field| { + deps.require_ref::(field.ty(vcx.tcx, substs)) + .unwrap() + }) .collect::>(); let did_name = vcx.tcx.item_name(adt_def.did()).to_ident_string(); - Ok((mk_structlike( - vcx, - deps, - task_key, - vir::vir_format!(vcx, "s_Adt_{did_name}"), - vir::vir_format!(vcx, "p_Adt_{did_name}"), - field_ty_out, - )?, ())) + Ok(( + mk_structlike( + vcx, + deps, + task_key, + vir::vir_format!(vcx, "s_Adt_{did_name}"), + vir::vir_format!(vcx, "p_Adt_{did_name}"), + field_ty_out, + )?, + (), + )) } TyKind::Never => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Never")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: "s_Never", - predicate_name: "p_Never", - to_primitive: None, - from_primitive: None, - snapshot: ty_s, - function_unreachable: "s_Never_unreachable", - function_snap: "p_Never_snap", - //method_refold: "refold_p_Never", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Never"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: vir::vir_format!(vcx, "f_Never"), - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Never {} }, - predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "s_Never", ty_s), - function_snap: mk_snap(vcx, "p_Never", "s_Never", None, ty_s), - //method_refold: mk_refold(vcx, "p_Never", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Never", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: "s_Never", + predicate_name: "p_Never", + to_primitive: None, + from_primitive: None, + snapshot: ty_s, + function_unreachable: "s_Never_unreachable", + function_snap: "p_Never_snap", + //method_refold: "refold_p_Never", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_p_Never"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: vir::vir_format!(vcx, "f_Never"), + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Never {} }, + predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, + function_unreachable: mk_unreachable(vcx, "s_Never", ty_s), + function_snap: mk_snap(vcx, "p_Never", "s_Never", None, ty_s), + //method_refold: mk_refold(vcx, "p_Never", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Never", ty_s), + }, + (), + )) } //_ => Err((TypeEncoderError::UnsupportedType, None)), unsupported_type => todo!("type not supported: {unsupported_type:?}"), diff --git a/prusti-encoder/src/encoders/viper_tuple.rs b/prusti-encoder/src/encoders/viper_tuple.rs index 87bee1488cc..174c07a7ade 100644 --- a/prusti-encoder/src/encoders/viper_tuple.rs +++ b/prusti-encoder/src/encoders/viper_tuple.rs @@ -1,8 +1,5 @@ -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct ViperTupleEncoder; @@ -19,7 +16,7 @@ impl<'vir> ViperTupleEncoderOutputRef<'vir> { pub fn mk_cons( &self, vcx: &'vir vir::VirCtxt<'vir>, - elems: &[vir::ExprGen<'vir, Curr, Next>] + elems: &[vir::ExprGen<'vir, Curr, Next>], ) -> vir::ExprGen<'vir, Curr, Next> { if self.elem_count == 1 { return elems[0]; @@ -58,7 +55,8 @@ impl TaskEncoder for ViperTupleEncoder { type EncodingError = (); fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, ViperTupleEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, ViperTupleEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -76,69 +74,89 @@ impl TaskEncoder for ViperTupleEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { vir::with_vcx(|vcx| { let domain_name = vir::vir_format!(vcx, "Tuple_{task_key}"); let cons_name = vir::vir_format!(vcx, "Tuple_{task_key}_cons"); let elem_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "Tuple_{task_key}_elem_{idx}")) .collect::>(); - deps.emit_output_ref::(*task_key, ViperTupleEncoderOutputRef { - elem_count: *task_key, - domain_name, - cons_name, - elem_names: vcx.alloc_slice(&elem_names), - }); + deps.emit_output_ref::( + *task_key, + ViperTupleEncoderOutputRef { + elem_count: *task_key, + domain_name, + cons_name, + elem_names: vcx.alloc_slice(&elem_names), + }, + ); let typaram_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "T{idx}")) .collect::>(); - let typaram_tys = vcx.alloc_slice(&typaram_names.iter() - .map(|name| vcx.alloc(vir::TypeData::Domain(name))) - .collect::>()); + let typaram_tys = vcx.alloc_slice( + &typaram_names + .iter() + .map(|name| vcx.alloc(vir::TypeData::Domain(name))) + .collect::>(), + ); let domain_ty = vcx.alloc(vir::TypeData::DomainParams(domain_name, typaram_tys)); let qvars_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "elem{idx}")) .collect::>(); - let qvars_decl = vcx.alloc_slice(&(0..*task_key) - .map(|idx| vcx.mk_local_decl(qvars_names[idx], typaram_tys[idx])) - .collect::>()); + let qvars_decl = vcx.alloc_slice( + &(0..*task_key) + .map(|idx| vcx.mk_local_decl(qvars_names[idx], typaram_tys[idx])) + .collect::>(), + ); let qvars_ex = (0..*task_key) .map(|idx| vcx.mk_local_ex(qvars_names[idx])) .collect::>(); let cons_call = vcx.mk_func_app( cons_name, - &qvars_names.iter() + &qvars_names + .iter() .map(|qvar| vcx.mk_local_ex(qvar)) .collect::>(), ); let axiom = vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_Tuple_{task_key}_elem"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: qvars_decl, - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), - body: vcx.mk_conj(&(0..*task_key) - .map(|idx| vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app(elem_names[idx], &[cons_call]), - rhs: qvars_ex[idx], - })))) - .collect::>()), - }))), + expr: vcx.alloc(vir::ExprData::Forall( + vcx.alloc(vir::ForallData { + qvars: qvars_decl, + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + body: vcx.mk_conj( + &(0..*task_key) + .map(|idx| { + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: vcx.mk_func_app(elem_names[idx], &[cons_call]), + rhs: qvars_ex[idx], + }))) + }) + .collect::>(), + ), + }), + )), }); let elem_args = vcx.alloc_slice(&[domain_ty]); let mut functions = (0..*task_key) - .map(|idx| vcx.alloc(vir::DomainFunctionData { - unique: false, - name: elem_names[idx], - args: elem_args, - ret: typaram_tys[idx], - })) + .map(|idx| { + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: elem_names[idx], + args: elem_args, + ret: typaram_tys[idx], + }) + }) .collect::>(); functions.push(vcx.alloc(vir::DomainFunctionData { unique: false, @@ -146,14 +164,21 @@ impl TaskEncoder for ViperTupleEncoder { args: typaram_tys, ret: domain_ty, })); - Ok((ViperTupleEncoderOutput { - domain: Some(vcx.alloc(vir::DomainData { - name: domain_name, - typarams: vcx.alloc_slice(&typaram_names), - axioms: vcx.alloc_slice(&[axiom]), - functions: vcx.alloc_slice(&functions), - })), - }, ())) + Ok(( + ViperTupleEncoderOutput { + domain: Some(vcx.alloc(vir::DomainData { + name: domain_name, + typarams: vcx.alloc_slice(&typaram_names), + axioms: if task_key == &0 { + &[] + } else { + vcx.alloc_slice(&[axiom]) + }, + functions: vcx.alloc_slice(&functions), + })), + }, + (), + )) }) } } diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index cdf1c061312..a91c3cf160c 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -9,11 +9,8 @@ extern crate rustc_type_ir; mod encoders; -use prusti_interface::environment::EnvBody; -use prusti_rustc_interface::{ - middle::ty, - hir, -}; +use prusti_interface::{environment::EnvBody, specs::typed::SpecificationItem}; +use prusti_rustc_interface::{hir, middle::ty}; /* struct MirBodyPureEncoder; @@ -100,7 +97,7 @@ impl<'vir, 'tcx> TaskEncoder<'vir, 'tcx> for MirBodyImpureEncoder<'vir, 'tcx> { ); // TaskKey, OutputRef same as above type OutputFull = vir::Method<'vir>; -} +} struct MirTyEncoder<'vir, 'tcx>(PhantomData<&'vir ()>, PhantomData<&'tcx ()>); impl<'vir, 'tcx> TaskEncoder<'vir, 'tcx> for MirTyEncoder<'vir, 'tcx> { @@ -127,7 +124,7 @@ pub fn test_entrypoint<'tcx>( // TODO: this should be a "crate" encoder, which will deps.require all the methods in the crate for def_id in tcx.hir_crate_items(()).definitions() { - //println!("item: {def_id:?}"); + tracing::debug!("test_entrypoint item: {def_id:?}"); let kind = tcx.def_kind(def_id); //println!(" kind: {:?}", kind); /*if !format!("{def_id:?}").contains("foo") { @@ -142,19 +139,34 @@ pub fn test_entrypoint<'tcx>( let res = crate::encoders::MirImpureEncoder::encode(def_id.to_def_id()); assert!(res.is_ok()); - let kind = crate::encoders::with_def_spec(|def_spec| - def_spec + let (is_pure, is_trusted) = crate::encoders::with_def_spec(|def_spec| { + let base_spec = def_spec .get_proc_spec(&def_id.to_def_id()) - .map(|e| e.base_spec.kind) - ); + .map(|e| &e.base_spec); + + let is_pure = base_spec + .and_then(|kind| kind.kind.is_pure().ok()) + .unwrap_or_default(); + let is_trusted = matches!( + base_spec.map(|spec| spec.trusted), + Some(SpecificationItem::Inherent(true,)) + ); + (is_pure, is_trusted) + }); + + if !(is_trusted && is_pure) { + let res = crate::encoders::MirImpureEncoder::encode(def_id.to_def_id()); + assert!(res.is_ok()); + } - if kind.and_then(|kind| kind.is_pure().ok()).unwrap_or_default() { - tracing::debug!("Encoding {def_id:?} as a pure function because it is labeled as pure"); + if is_pure { + tracing::debug!( + "Encoding {def_id:?} as a pure function because it is labeled as pure" + ); let res = crate::encoders::MirFunctionEncoder::encode(def_id.to_def_id()); assert!(res.is_ok()); } - /* match res { Ok(res) => println!("ok: {:?}", res), @@ -162,7 +174,7 @@ pub fn test_entrypoint<'tcx>( }*/ } unsupported_item_kind => { - println!("another item: {unsupported_item_kind:?}"); + tracing::debug!("unsupported item: {unsupported_item_kind:?}"); } } } @@ -182,7 +194,7 @@ pub fn test_entrypoint<'tcx>( header(&mut viper_code, "functions"); for output in crate::encoders::MirFunctionEncoder::all_outputs() { - viper_code.push_str(&format!("{:?}\n", output.method)); + viper_code.push_str(&format!("{:?}\n", output.function)); } header(&mut viper_code, "MIR builtins"); @@ -222,20 +234,20 @@ pub fn test_entrypoint<'tcx>( std::fs::write("local-testing/simple.vpr", viper_code).unwrap(); - vir::with_vcx(|vcx| vcx.alloc(vir::ProgramData { - fields: &[], - domains: &[], - predicates: &[], - functions: vcx.alloc_slice(&[ - vcx.alloc(vir::FunctionData { + vir::with_vcx(|vcx| { + vcx.alloc(vir::ProgramData { + fields: &[], + domains: &[], + predicates: &[], + functions: vcx.alloc_slice(&[vcx.alloc(vir::FunctionData { name: "test_function", args: &[], ret: &vir::TypeData::Bool, pres: &[], posts: &[], expr: None, - }), - ]), - methods: &[], - })) + })]), + methods: &[], + }) + }) } diff --git a/prusti-interface/src/environment/body.rs b/prusti-interface/src/environment/body.rs index 90593363e48..84873731aae 100644 --- a/prusti-interface/src/environment/body.rs +++ b/prusti-interface/src/environment/body.rs @@ -136,7 +136,7 @@ impl<'tcx> EnvBody<'tcx> { BodyWithBorrowckFacts { body: MirBody(Rc::new(body_with_facts.body)), - // borrowck_facts: Rc::new(facts), + // borrowck_facts: Rc::new(facts), } } @@ -174,12 +174,18 @@ impl<'tcx> EnvBody<'tcx> { { let monomorphised = if let Some(caller_def_id) = caller_def_id { let param_env = self.tcx.param_env(caller_def_id); - self.tcx - .subst_and_normalize_erasing_regions(substs, param_env, ty::EarlyBinder::bind(body.0)) + self.tcx.subst_and_normalize_erasing_regions( + substs, + param_env, + ty::EarlyBinder::bind(body.0), + ) } else { let param_env = self.tcx.param_env(def_id); - self.tcx - .subst_and_normalize_erasing_regions(substs, param_env, ty::EarlyBinder::bind(body.0)) + self.tcx.subst_and_normalize_erasing_regions( + substs, + param_env, + ty::EarlyBinder::bind(body.0), + ) }; v.insert(MirBody(monomorphised)).clone() } else { @@ -201,7 +207,11 @@ impl<'tcx> EnvBody<'tcx> { /// with the given type substitutions. /// /// FIXME: This function is called only in pure contexts??? - pub fn get_impure_fn_body(&self, def_id: LocalDefId, substs: GenericArgsRef<'tcx>) -> MirBody<'tcx> { + pub fn get_impure_fn_body( + &self, + def_id: LocalDefId, + substs: GenericArgsRef<'tcx>, + ) -> MirBody<'tcx> { if let Some(body) = self.get_monomorphised(def_id.to_def_id(), substs, None) { return body; } @@ -328,7 +338,7 @@ impl<'tcx> EnvBody<'tcx> { pub(crate) fn load_pure_fn_body(&mut self, def_id: LocalDefId) { assert!(!self.pure_fns.local.contains_key(&def_id)); - let body = Self::load_local_mir( self.tcx, def_id); + let body = Self::load_local_mir(self.tcx, def_id); self.pure_fns.local.insert(def_id, body); let bwbf = Self::load_local_mir_with_facts(self.tcx, def_id); // Also add to `impure_fns` since we'll also be encoding this as impure diff --git a/prusti-interface/src/environment/query.rs b/prusti-interface/src/environment/query.rs index 77c0319a7f2..174eb9843a9 100644 --- a/prusti-interface/src/environment/query.rs +++ b/prusti-interface/src/environment/query.rs @@ -7,7 +7,10 @@ use prusti_rustc_interface::{ hir::hir_id::HirId, middle::{ hir::map::Map, - ty::{self, Binder, BoundConstness, GenericArgsRef, ImplPolarity, ParamEnv, TraitPredicate, TyCtxt}, + ty::{ + self, Binder, BoundConstness, GenericArgsRef, ImplPolarity, ParamEnv, TraitPredicate, + TyCtxt, + }, }, span::{ def_id::{DefId, LocalDefId}, @@ -206,66 +209,66 @@ impl<'tcx> EnvQuery<'tcx> { impl_method_substs: GenericArgsRef<'tcx>, // what are the substs on the call? ) -> Option<(ProcedureDefId, GenericArgsRef<'tcx>)> { todo!() /* - let impl_method_def_id = impl_method_def_id.into_param(); - let impl_def_id = self.tcx.impl_of_method(impl_method_def_id)?; - let trait_ref = self.tcx.impl_trait_ref(impl_def_id)?.skip_binder(); - - // At this point, we know that the given method: - // - belongs to an impl block and - // - the impl block implements a trait. - // For the `get_assoc_item` call, we therefore `unwrap`, as not finding - // the associated item would be a (compiler) internal error. - let trait_def_id = trait_ref.def_id; - let trait_method_def_id = self - .get_assoc_item(trait_def_id, impl_method_def_id) - .unwrap() - .def_id; - - // sanity check: have we been given the correct number of substs? - let identity_impl_method = self.identity_substs(impl_method_def_id); - assert_eq!(identity_impl_method.len(), impl_method_substs.len()); - - // Given: - // ``` - // trait Trait { - // fn f(); - // } - // struct Struct { ... } - // impl Trait for Struct { - // fn f() { ... } - // } - // ``` - // - // The various substs look like this: - // - identity for Trait: `[Self, Tp]` - // - identity for Trait::f: `[Self, Tp, Tx, Ty, Tz]` - // - substs of the impl trait ref: `[Struct, A]` - // - identity for the impl: `[A, B, C]` - // - identity for Struct::f: `[A, B, C, X, Y, Z]` - // - // What we need is a substs suitable for a call to Trait::f, which is in - // this case `[Struct, A, X, Y, Z]`. More generally, it is the - // concatenation of the trait ref substs with the identity of the impl - // method after skipping the identity of the impl. - // - // We also need to subst the prefix (`[Struct, A]` in the example - // above) with call substs, so that we get the trait's type parameters - // more precisely. We can do this directly with `impl_method_substs` - // because they contain the substs for the `impl` block as a prefix. - let call_trait_substs = - ty::EarlyBinder(trait_ref.substs).subst(self.tcx, impl_method_substs); - let impl_substs = self.identity_substs(impl_def_id); - let trait_method_substs = self.tcx.mk_substs_from_iter( - call_trait_substs - .iter() - .chain(impl_method_substs.iter().skip(impl_substs.len())), - ); + let impl_method_def_id = impl_method_def_id.into_param(); + let impl_def_id = self.tcx.impl_of_method(impl_method_def_id)?; + let trait_ref = self.tcx.impl_trait_ref(impl_def_id)?.skip_binder(); + + // At this point, we know that the given method: + // - belongs to an impl block and + // - the impl block implements a trait. + // For the `get_assoc_item` call, we therefore `unwrap`, as not finding + // the associated item would be a (compiler) internal error. + let trait_def_id = trait_ref.def_id; + let trait_method_def_id = self + .get_assoc_item(trait_def_id, impl_method_def_id) + .unwrap() + .def_id; + + // sanity check: have we been given the correct number of substs? + let identity_impl_method = self.identity_substs(impl_method_def_id); + assert_eq!(identity_impl_method.len(), impl_method_substs.len()); + + // Given: + // ``` + // trait Trait { + // fn f(); + // } + // struct Struct { ... } + // impl Trait for Struct { + // fn f() { ... } + // } + // ``` + // + // The various substs look like this: + // - identity for Trait: `[Self, Tp]` + // - identity for Trait::f: `[Self, Tp, Tx, Ty, Tz]` + // - substs of the impl trait ref: `[Struct, A]` + // - identity for the impl: `[A, B, C]` + // - identity for Struct::f: `[A, B, C, X, Y, Z]` + // + // What we need is a substs suitable for a call to Trait::f, which is in + // this case `[Struct, A, X, Y, Z]`. More generally, it is the + // concatenation of the trait ref substs with the identity of the impl + // method after skipping the identity of the impl. + // + // We also need to subst the prefix (`[Struct, A]` in the example + // above) with call substs, so that we get the trait's type parameters + // more precisely. We can do this directly with `impl_method_substs` + // because they contain the substs for the `impl` block as a prefix. + let call_trait_substs = + ty::EarlyBinder(trait_ref.substs).subst(self.tcx, impl_method_substs); + let impl_substs = self.identity_substs(impl_def_id); + let trait_method_substs = self.tcx.mk_substs_from_iter( + call_trait_substs + .iter() + .chain(impl_method_substs.iter().skip(impl_substs.len())), + ); - // sanity check: do we now have the correct number of substs? - let identity_trait_method = self.identity_substs(trait_method_def_id); - assert_eq!(trait_method_substs.len(), identity_trait_method.len()); + // sanity check: do we now have the correct number of substs? + let identity_trait_method = self.identity_substs(trait_method_def_id); + assert_eq!(trait_method_substs.len(), identity_trait_method.len()); - Some((trait_method_def_id, trait_method_substs))*/ + Some((trait_method_def_id, trait_method_substs))*/ } /// Given some procedure `proc_def_id` which is called, this method returns the actual method which will be executed when `proc_def_id` is defined on a trait. diff --git a/prusti-interface/src/specs/encoder.rs b/prusti-interface/src/specs/encoder.rs index 21b7b6f4531..ae42d305b6c 100644 --- a/prusti-interface/src/specs/encoder.rs +++ b/prusti-interface/src/specs/encoder.rs @@ -18,10 +18,7 @@ pub struct DefSpecsEncoder<'tcx> { } impl<'tcx> DefSpecsEncoder<'tcx> { - pub fn new( - tcx: TyCtxt<'tcx>, - path: &std::path::PathBuf, - ) -> std::io::Result { + pub fn new(tcx: TyCtxt<'tcx>, path: &std::path::PathBuf) -> std::io::Result { Ok(DefSpecsEncoder { tcx, opaque: opaque::FileEncoder::new(path)?, diff --git a/prusti-interface/src/specs/external.rs b/prusti-interface/src/specs/external.rs index 608a5cd8d62..4cfc2f60c73 100644 --- a/prusti-interface/src/specs/external.rs +++ b/prusti-interface/src/specs/external.rs @@ -4,10 +4,7 @@ use prusti_rustc_interface::{ def_id::{DefId, LocalDefId}, intravisit::{self, Visitor}, }, - middle::{ - hir::map::Map, - ty::GenericArgsRef, - }, + middle::{hir::map::Map, ty::GenericArgsRef}, span::Span, }; diff --git a/prusti-interface/src/utils.rs b/prusti-interface/src/utils.rs index c9de49e76ed..c5e8d7f80db 100644 --- a/prusti-interface/src/utils.rs +++ b/prusti-interface/src/utils.rs @@ -10,7 +10,10 @@ use prusti_rustc_interface::{ abi::FieldIdx, ast::ast, data_structures::fx::FxHashSet, - middle::{mir, ty::{self, TyCtxt}}, + middle::{ + mir, + ty::{self, TyCtxt}, + }, }; use std::borrow::Borrow; @@ -95,8 +98,7 @@ pub fn expand_struct_place<'tcx>( for (index, field_def) in variant.fields.iter().enumerate() { if Some(index) != without_field { let field = FieldIdx::from_usize(index); - let field_place = - tcx.mk_place_field(place, field, field_def.ty(tcx, substs)); + let field_place = tcx.mk_place_field(place, field, field_def.ty(tcx, substs)); places.push(field_place); } } @@ -133,7 +135,6 @@ pub fn expand_struct_place<'tcx>( places } - /// Pop the last projection from the place and return the new place with the popped element. pub fn try_pop_one_level<'tcx>( tcx: TyCtxt<'tcx>, diff --git a/prusti/src/callbacks.rs b/prusti/src/callbacks.rs index beb48f290f8..102da95c4fb 100644 --- a/prusti/src/callbacks.rs +++ b/prusti/src/callbacks.rs @@ -9,16 +9,12 @@ use prusti_rustc_interface::{ borrowck::consumers, data_structures::steal::Steal, driver::Compilation, + hir::{def::DefKind, def_id::LocalDefId}, index::IndexVec, interface::{interface::Compiler, Config, Queries}, - hir::{def::DefKind, def_id::LocalDefId}, middle::{ mir, - query::{ - queries::mir_borrowck::ProvidedValue as MirBorrowck, - ExternProviders, - Providers - }, + query::{queries::mir_borrowck::ProvidedValue as MirBorrowck, ExternProviders, Providers}, ty::TyCtxt, }, session::{EarlyErrorHandler, Session}, @@ -174,7 +170,7 @@ impl prusti_rustc_interface::driver::Callbacks for PrustiCompilerCalls { test_free_pcs(&mir, tcx); } } else {*/ - verify(env, def_spec); + verify(env, def_spec); //} } }); diff --git a/prusti/src/driver.rs b/prusti/src/driver.rs index 358dad21e0c..7a58e953631 100644 --- a/prusti/src/driver.rs +++ b/prusti/src/driver.rs @@ -56,18 +56,20 @@ fn report_prusti_ice(info: &panic::PanicInfo<'_>, bug_report_url: &str) { prusti_rustc_interface::driver::DEFAULT_LOCALE_RESOURCES.to_vec(), false, ); - let emitter = Box::new(prusti_rustc_interface::errors::emitter::EmitterWriter::stderr( - prusti_rustc_interface::errors::ColorConfig::Auto, - None, - None, - fallback_bundle, - false, - false, - None, - false, - false, - prusti_rustc_interface::errors::TerminalUrl::Auto, - )); + let emitter = Box::new( + prusti_rustc_interface::errors::emitter::EmitterWriter::stderr( + prusti_rustc_interface::errors::ColorConfig::Auto, + None, + None, + fallback_bundle, + false, + false, + None, + false, + false, + prusti_rustc_interface::errors::TerminalUrl::Auto, + ), + ); let handler = prusti_rustc_interface::errors::Handler::with_emitter(true, None, emitter); // a .span_bug or .bug call has already printed what it wants to print. diff --git a/prusti/src/verifier.rs b/prusti/src/verifier.rs index 842265c59c7..a047c1fb038 100644 --- a/prusti/src/verifier.rs +++ b/prusti/src/verifier.rs @@ -14,40 +14,36 @@ pub fn verify(env: Environment<'_>, def_spec: typed::DefSpecificationMap) { if env.diagnostic.has_errors() { warn!("The compiler reported an error, so the program will not be verified."); } else { - debug!("Prepare verification task...");/* - // TODO: can we replace `get_annotated_procedures` with information - // that is already in `def_spec`? - let (annotated_procedures, types) = env.get_annotated_procedures_and_types(); - let verification_task = VerificationTask { - procedures: annotated_procedures, - types, - }; - debug!("Verification task: {:?}", &verification_task); + debug!("Prepare verification task..."); /* + // TODO: can we replace `get_annotated_procedures` with information + // that is already in `def_spec`? + let (annotated_procedures, types) = env.get_annotated_procedures_and_types(); + let verification_task = VerificationTask { + procedures: annotated_procedures, + types, + }; + debug!("Verification task: {:?}", &verification_task); - user::message(format!( - "Verification of {} items...", - verification_task.procedures.len() - )); + user::message(format!( + "Verification of {} items...", + verification_task.procedures.len() + )); - if config::print_collected_verification_items() { - println!( - "Collected verification items {}:", - verification_task.procedures.len() - ); - for procedure in &verification_task.procedures { - println!( - "procedure: {} at {:?}", - env.name.get_item_def_path(*procedure), - env.query.get_def_span(procedure) - ); - } - }*/ + if config::print_collected_verification_items() { + println!( + "Collected verification items {}:", + verification_task.procedures.len() + ); + for procedure in &verification_task.procedures { + println!( + "procedure: {} at {:?}", + env.name.get_item_def_path(*procedure), + env.query.get_def_span(procedure) + ); + } + }*/ - let program = prusti_encoder::test_entrypoint( - env.tcx(), - env.body, - def_spec, - ); + let program = prusti_encoder::test_entrypoint(env.tcx(), env.body, def_spec); //viper::verify(program); //let verification_result = diff --git a/task-encoder/src/lib.rs b/task-encoder/src/lib.rs index d92cea432de..fa36c724f9d 100644 --- a/task-encoder/src/lib.rs +++ b/task-encoder/src/lib.rs @@ -8,7 +8,6 @@ impl<'vir> OutputRefAny<'vir> for () {} pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { // None, // indicated by absence in the cache - /// Task was enqueued but not yet started. Enqueued, @@ -29,9 +28,7 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { }, /// An error occurred when enqueing the task. - ErrorEnqueue { - error: TaskEncoderError, - }, + ErrorEnqueue { error: TaskEncoderError }, /// An error occurred when encoding the task. The full "local" encoding is /// not available. However, tasks which depend on this task may still @@ -49,16 +46,12 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { /// Cache for a task encoder. See `TaskEncoderCacheState` for a description of /// the possible values in the encoding process. -pub type Cache<'vir, E> = LinkedHashMap< - ::TaskKey<'vir>, - TaskEncoderCacheState<'vir, E>, ->; +pub type Cache<'vir, E> = + LinkedHashMap<::TaskKey<'vir>, TaskEncoderCacheState<'vir, E>>; pub type CacheRef<'vir, E> = RefCell>; -pub type CacheStatic = LinkedHashMap< - ::TaskKey<'static>, - TaskEncoderCacheState<'static, E>, ->; +pub type CacheStatic = + LinkedHashMap<::TaskKey<'static>, TaskEncoderCacheState<'static, E>>; pub type CacheStaticRef = RefCell>; /* pub struct TaskEncoderOutput<'vir, E: TaskEncoder>( @@ -85,7 +78,8 @@ pub enum TaskEncoderError { } impl std::fmt::Debug for TaskEncoderError - where ::EncodingError: std::fmt::Debug +where + ::EncodingError: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut helper = f.debug_struct("TaskEncoderError"); @@ -118,30 +112,21 @@ impl<'vir> TaskEncoderDependencies<'vir> { pub fn require_ref( &mut self, task: ::TaskDescription<'vir>, - ) -> Result< - ::OutputRef<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputRef<'vir>, TaskEncoderError> { E::encode_ref(task) } pub fn require_local( &mut self, task: ::TaskDescription<'vir>, - ) -> Result< - ::OutputFullLocal<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputFullLocal<'vir>, TaskEncoderError> { E::encode(task).map(|(_output_ref, output_local, _output_dep)| output_local) } pub fn require_dep( &mut self, task: ::TaskDescription<'vir>, - ) -> Result< - ::OutputFullDependency<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputFullDependency<'vir>, TaskEncoderError> { E::encode(task).map(|(_output_ref, _output_local, output_dep)| output_dep) } @@ -150,10 +135,12 @@ impl<'vir> TaskEncoderDependencies<'vir> { task_key: E::TaskKey<'vir>, output_ref: E::OutputRef<'vir>, ) { - assert!(E::with_cache(move |cache| matches!(cache.borrow_mut().insert( - task_key, - TaskEncoderCacheState::Started { output_ref }, - ), Some(TaskEncoderCacheState::Enqueued)))); + assert!(E::with_cache(move |cache| matches!( + cache + .borrow_mut() + .insert(task_key, TaskEncoderCacheState::Started { output_ref },), + Some(TaskEncoderCacheState::Enqueued) + ))); } } @@ -166,7 +153,8 @@ pub trait TaskEncoder { /// for example if the description should be normalised or some non-trivial /// resolution needs to happen. In other words, multiple descriptions may /// lead to the same key and hence the same output. - type TaskKey<'vir>: std::hash::Hash + Eq + Clone + std::fmt::Debug = Self::TaskDescription<'vir>; + type TaskKey<'vir>: std::hash::Hash + Eq + Clone + std::fmt::Debug = + Self::TaskDescription<'vir>; /// A reference to an encoded item. Should be non-unit for tasks which can /// be "referred" to from other parts of a program, as opposed to tasks @@ -190,7 +178,9 @@ pub trait TaskEncoder { /// Enters the given function with a reference to the cache for this /// encoder. fn with_cache<'vir, F, R>(f: F) -> R - where Self: 'vir, F: FnOnce(&'vir CacheRef<'vir, Self>) -> R; + where + Self: 'vir, + F: FnOnce(&'vir CacheRef<'vir, Self>) -> R; //fn get_all_outputs() -> Self::CacheRef<'vir> { // todo!() @@ -199,7 +189,8 @@ pub trait TaskEncoder { //} fn enqueue<'vir>(task: Self::TaskDescription<'vir>) - where Self: 'vir + where + Self: 'vir, { let task_key = Self::task_to_key(&task); let task_key_clone = task_key.clone(); // TODO: remove? @@ -209,28 +200,32 @@ pub trait TaskEncoder { } // enqueue, expecting no entry (we just checked) - assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( - task_key, - TaskEncoderCacheState::Enqueued, - ).is_none())); + assert!(Self::with_cache(move |cache| cache + .borrow_mut() + .insert(task_key, TaskEncoderCacheState::Enqueued,) + .is_none())); } - fn encode_ref<'vir>(task: Self::TaskDescription<'vir>) -> Result< - Self::OutputRef<'vir>, - TaskEncoderError, - > - where Self: 'vir + fn encode_ref<'vir>( + task: Self::TaskDescription<'vir>, + ) -> Result, TaskEncoderError> + where + Self: 'vir, { let task_key = Self::task_to_key(&task); // is there an output ref available already? let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { + if let Some(output_ref) = + Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => { + Some(output_ref.clone()) + } + _ => None, + }) + { return Ok(output_ref); } @@ -252,24 +247,34 @@ pub trait TaskEncoder { Self::encode(task)?; let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { + if let Some(output_ref) = + Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => { + Some(output_ref.clone()) + } + _ => None, + }) + { return Ok(output_ref); } panic!("output ref not found after encoding") // TODO: error? } - fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Result<( - Self::OutputRef<'vir>, - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir + fn encode<'vir>( + task: Self::TaskDescription<'vir>, + ) -> Result< + ( + Self::OutputRef<'vir>, + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + TaskEncoderError, + > + where + Self: 'vir, { let task_key = Self::task_to_key(&task); @@ -290,8 +295,9 @@ pub trait TaskEncoder { output_local.clone(), output_dep.clone(), ))), - TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => - panic!("Encoding already started or enqueued"), + TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => { + panic!("Encoding already started or enqueued") + } }, None => { // enqueue @@ -314,160 +320,172 @@ pub trait TaskEncoder { match encode_result { Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( - output_ref, - output_local, - output_dep, - )) + Self::with_cache(|cache| { + cache.borrow_mut().insert( + task_key, + TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + }, + ) + }); + Ok((output_ref, output_local, output_dep)) } Err((err, maybe_output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { - output_ref: output_ref.clone(), - deps, - error: TaskEncoderError::EncodingError(err.clone()), - output_dep: maybe_output_dep, - })); + Self::with_cache(|cache| { + cache.borrow_mut().insert( + task_key, + TaskEncoderCacheState::ErrorEncode { + output_ref: output_ref.clone(), + deps, + error: TaskEncoderError::EncodingError(err.clone()), + output_dep: maybe_output_dep, + }, + ) + }); Err(TaskEncoderError::EncodingError(err)) } } } /* - /// Given a task description for this encoder, enqueue it and return the - /// reference to the output. If the task is already enqueued, the output - /// reference already exists. - fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Self::OutputRef<'vir> - where Self: 'vir - { - let task_key = Self::task_to_key(&task); - let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Enqueued { output_ref }) - | Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { - return output_ref; + /// Given a task description for this encoder, enqueue it and return the + /// reference to the output. If the task is already enqueued, the output + /// reference already exists. + fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Self::OutputRef<'vir> + where Self: 'vir + { + let task_key = Self::task_to_key(&task); + let task_key_clone = task_key.clone(); + if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Enqueued { output_ref }) + | Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), + _ => None, + }) { + return output_ref; + } + let task_ref = Self::task_to_output_ref(&task); + let task_key_clone = task_key.clone(); + let task_ref_clone = task_ref.clone(); + assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( + task_key_clone, + TaskEncoderCacheState::Enqueued { output_ref: task_ref_clone }, + ).is_none())); + task_ref } - let task_ref = Self::task_to_output_ref(&task); - let task_key_clone = task_key.clone(); - let task_ref_clone = task_ref.clone(); - assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( - task_key_clone, - TaskEncoderCacheState::Enqueued { output_ref: task_ref_clone }, - ).is_none())); - task_ref - } - - // TODO: this function should not be needed - fn encode_eager<'vir>(task: Self::TaskDescription<'vir>) -> Result<( - Self::OutputRef<'vir>, - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir - { - let task_key = Self::task_to_key(&task); - // enqueue - let output_ref = Self::encode(task); - // process - Self::encode_full(task_key) - .map(|(output_full_local, output_full_dep)| (output_ref, output_full_local, output_full_dep)) - } - /// Given a task key, fully encode the given task. If this task was already - /// finished, the encoding is not repeated. If this task was enqueued, but - /// not finished, return a `CyclicError`. - fn encode_full<'vir>(task_key: Self::TaskKey<'vir>) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir - { - let mut output_ref_opt = None; - let ret = Self::with_cache(|cache| { - // should be queued by now - match cache.borrow().get(&task_key).unwrap() { - TaskEncoderCacheState::Enqueued { output_ref } => { - output_ref_opt = Some(output_ref.clone()); - None - } - TaskEncoderCacheState::Started { .. } => Some(Err(TaskEncoderError::CyclicError)), - TaskEncoderCacheState::Encoded { output_local, output_dep, .. } => - Some(Ok(( - output_local.clone(), - output_dep.clone(), - ))), - TaskEncoderCacheState::ErrorEncode { error, .. } => - Some(Err(error.clone())), - } - }); - if let Some(ret) = ret { - return ret; + // TODO: this function should not be needed + fn encode_eager<'vir>(task: Self::TaskDescription<'vir>) -> Result<( + Self::OutputRef<'vir>, + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), TaskEncoderError> + where Self: 'vir + { + let task_key = Self::task_to_key(&task); + // enqueue + let output_ref = Self::encode(task); + // process + Self::encode_full(task_key) + .map(|(output_full_local, output_full_dep)| (output_ref, output_full_local, output_full_dep)) } - let output_ref = output_ref_opt.unwrap(); - let mut deps: TaskEncoderDependencies<'vir> = Default::default(); - match Self::do_encode_full(&task_key, &mut deps) { - Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( - output_local, - output_dep, - )) + /// Given a task key, fully encode the given task. If this task was already + /// finished, the encoding is not repeated. If this task was enqueued, but + /// not finished, return a `CyclicError`. + fn encode_full<'vir>(task_key: Self::TaskKey<'vir>) -> Result<( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), TaskEncoderError> + where Self: 'vir + { + let mut output_ref_opt = None; + let ret = Self::with_cache(|cache| { + // should be queued by now + match cache.borrow().get(&task_key).unwrap() { + TaskEncoderCacheState::Enqueued { output_ref } => { + output_ref_opt = Some(output_ref.clone()); + None + } + TaskEncoderCacheState::Started { .. } => Some(Err(TaskEncoderError::CyclicError)), + TaskEncoderCacheState::Encoded { output_local, output_dep, .. } => + Some(Ok(( + output_local.clone(), + output_dep.clone(), + ))), + TaskEncoderCacheState::ErrorEncode { error, .. } => + Some(Err(error.clone())), + } + }); + if let Some(ret) = ret { + return ret; } - Err((err, maybe_output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { - output_ref: output_ref.clone(), - deps, - error: TaskEncoderError::EncodingError(err.clone()), - output_dep: maybe_output_dep, - })); - Err(TaskEncoderError::EncodingError(err)) + let output_ref = output_ref_opt.unwrap(); + + let mut deps: TaskEncoderDependencies<'vir> = Default::default(); + match Self::do_encode_full(&task_key, &mut deps) { + Ok((output_local, output_dep)) => { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + })); + Ok(( + output_local, + output_dep, + )) + } + Err((err, maybe_output_dep)) => { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { + output_ref: output_ref.clone(), + deps, + error: TaskEncoderError::EncodingError(err.clone()), + output_dep: maybe_output_dep, + })); + Err(TaskEncoderError::EncodingError(err)) + } } } - } -*/ + */ /// Given a task description, create a key for storing it in the cache. - fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> // Result< - Self::TaskKey<'vir>;//, - // Self::EnqueueingError, - //> -/* - /// Given a task description, create a reference to the output. - fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; -*/ + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir>; //, + // Self::EnqueueingError, + //> + /* + /// Given a task description, create a reference to the output. + fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; + */ fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )>; + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + >; fn all_outputs<'vir>() -> Vec> - where Self: 'vir + where + Self: 'vir, { Self::with_cache(|cache| { let mut ret = vec![]; for (_task_key, cache_state) in cache.borrow().iter() { - match cache_state { // TODO: make this into an iterator chain - TaskEncoderCacheState::Encoded { output_local, .. } => ret.push(output_local.clone()), + match cache_state { + // TODO: make this into an iterator chain + TaskEncoderCacheState::Encoded { output_local, .. } => { + ret.push(output_local.clone()) + } _ => {} } } diff --git a/test-crates/src/main.rs b/test-crates/src/main.rs index ec61eddfb13..bb4ebdddc17 100644 --- a/test-crates/src/main.rs +++ b/test-crates/src/main.rs @@ -4,6 +4,10 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use clap::Parser; +use log::{error, info, warn, LevelFilter}; +use rustwide::{cmd, logging, logging::LogStorage, Crate, Toolchain, Workspace, WorkspaceBuilder}; +use serde::Deserialize; use std::{ env, error::Error, @@ -11,10 +15,6 @@ use std::{ path::{Path, PathBuf}, process::Command, }; -use log::{error, info, warn, LevelFilter}; -use rustwide::{cmd, logging, logging::LogStorage, Crate, Toolchain, Workspace, WorkspaceBuilder}; -use serde::Deserialize; -use clap::Parser; /// How a crate should be tested. All tests use `check_panics=false`, `check_overflows=false` and /// `skip_unsupported_features=true`. @@ -145,17 +145,21 @@ struct Args { shard_index: usize, } -fn attempt_fetch(krate: &Crate, workspace: &Workspace, num_retries: u8) -> Result<(), failure::Error> { +fn attempt_fetch( + krate: &Crate, + workspace: &Workspace, + num_retries: u8, +) -> Result<(), failure::Error> { let mut i = 0; while i < num_retries + 1 { if let Err(err) = krate.fetch(workspace) { warn!("Error fetching crate {}: {}", krate, err); if i == num_retries { // Last attempt failed, return the error - return Err(err) + return Err(err); } } else { - return Ok(()) + return Ok(()); } i += 1; } @@ -228,7 +232,12 @@ fn main() -> Result<(), Box> { .collect::, _>>()? .into_iter() .filter(|record| record.name.contains(&args.filter_crate_name)) - .map(|record| (Crate::crates_io(&record.name, &record.version), record.test_kind)) + .map(|record| { + ( + Crate::crates_io(&record.name, &record.version), + record.test_kind, + ) + }) .collect(); info!("There are {} crates in total.", crates_list.len()); @@ -239,8 +248,11 @@ fn main() -> Result<(), Box> { // List of crates on which Prusti succeed. let mut successful_crates = vec![]; - let shard_crates_list: Vec<&(Crate, TestKind)> = crates_list.iter().skip(args.shard_index) - .step_by(args.num_shards).collect(); + let shard_crates_list: Vec<&(Crate, TestKind)> = crates_list + .iter() + .skip(args.shard_index) + .step_by(args.num_shards) + .collect(); info!( "Iterate over the {} crates of the shard {}/{}...", shard_crates_list.len(), @@ -248,7 +260,13 @@ fn main() -> Result<(), Box> { args.num_shards, ); for (index, (krate, test_kind)) in shard_crates_list.iter().enumerate() { - info!("Crate {}/{}: {}, test kind: {:?}", index, shard_crates_list.len(), krate, test_kind); + info!( + "Crate {}/{}: {}, test kind: {:?}", + index, + shard_crates_list.len(), + krate, + test_kind + ); if let TestKind::Skip = test_kind { info!("Skip crate"); @@ -300,11 +318,7 @@ fn main() -> Result<(), Box> { guest_prusti_home, cmd::MountKind::ReadOnly, ) - .mount( - host_viper_home, - guest_viper_home, - cmd::MountKind::ReadOnly, - ) + .mount(host_viper_home, guest_viper_home, cmd::MountKind::ReadOnly) .mount(host_z3_home, guest_z3_home, cmd::MountKind::ReadOnly) .mount(&host_java_home, &guest_java_home, cmd::MountKind::ReadOnly); for java_policy_path in &host_java_policies { @@ -317,7 +331,8 @@ fn main() -> Result<(), Box> { let verification_status = build_dir.build(&toolchain, krate, sandbox).run(|build| { logging::capture(&storage, || { - let mut command = build.cmd(&cargo_prusti) + let mut command = build + .cmd(&cargo_prusti) .env("RUST_BACKTRACE", "1") .env("PRUSTI_ASSERT_TIMEOUT", "60000") .env("PRUSTI_LOG_DIR", "/tmp/prusti_log") @@ -328,13 +343,14 @@ fn main() -> Result<(), Box> { .env("PRUSTI_SKIP_UNSUPPORTED_FEATURES", "true"); match test_kind { TestKind::NoErrorsWithUnreachableUnsupportedCode => { - command = command.env("PRUSTI_ALLOW_UNREACHABLE_UNSUPPORTED_CODE", "true"); + command = + command.env("PRUSTI_ALLOW_UNREACHABLE_UNSUPPORTED_CODE", "true"); } - TestKind::NoErrors => {}, + TestKind::NoErrors => {} TestKind::NoCrash => { // Report internal errors as warnings command = command.env("PRUSTI_INTERNAL_ERRORS_AS_WARNINGS", "true"); - }, + } TestKind::Skip => { unreachable!(); } @@ -383,7 +399,11 @@ fn main() -> Result<(), Box> { } // Panic - assert!(failed_crates.is_empty(), "Failed to verify {} crates", failed_crates.len()); + assert!( + failed_crates.is_empty(), + "Failed to verify {} crates", + failed_crates.len() + ); Ok(()) } diff --git a/tracing/Cargo.toml b/tracing/Cargo.toml index d96a0d59852..3c9026321eb 100644 --- a/tracing/Cargo.toml +++ b/tracing/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" doctest = false [dependencies] -tracing = "0.1" +tracing = { version = "0.1", features = ["log"] } proc-macro-tracing = { path = "proc-macro-tracing" } diff --git a/tracing/proc-macro-tracing/src/lib.rs b/tracing/proc-macro-tracing/src/lib.rs index 044b9e8b47c..a7687d87613 100644 --- a/tracing/proc-macro-tracing/src/lib.rs +++ b/tracing/proc-macro-tracing/src/lib.rs @@ -9,7 +9,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{ItemFn, ReturnType, token::Paren, Type, TypeParen}; +use syn::{token::Paren, ItemFn, ReturnType, Type, TypeParen}; // Using `tracing::instrument` without this crate (from the vanilla tracing crate) // causes RA to not pick up the return type properly atm - it colours wrong and @@ -21,17 +21,22 @@ pub fn instrument(attr: TokenStream, tokens: TokenStream) -> TokenStream { let (attr, tokens): (TokenStream2, TokenStream2) = (attr.into(), tokens.into()); if let Ok(mut item) = syn::parse2::(tokens.clone()) { if let ReturnType::Type(a, ty) = item.sig.output { - let new_ty = Type::Paren(TypeParen { paren_token: Paren::default(), elem: ty }); + let new_ty = Type::Paren(TypeParen { + paren_token: Paren::default(), + elem: ty, + }); item.sig.output = ReturnType::Type(a, Box::new(new_ty)); } quote! { #[tracing::tracing_instrument(#attr)] #item - }.into() + } + .into() } else { quote! { #[tracing::tracing_instrument(#attr)] #tokens - }.into() + } + .into() } } diff --git a/tracing/src/lib.rs b/tracing/src/lib.rs index e03deef40ed..e9705a3c56b 100644 --- a/tracing/src/lib.rs +++ b/tracing/src/lib.rs @@ -4,5 +4,5 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -pub use tracing::{*, instrument as tracing_instrument}; pub use proc_macro_tracing::instrument; +pub use tracing::{instrument as tracing_instrument, *}; diff --git a/vir-proc-macro/src/lib.rs b/vir-proc-macro/src/lib.rs index b99512fc16e..a072ea905be 100644 --- a/vir-proc-macro/src/lib.rs +++ b/vir-proc-macro/src/lib.rs @@ -3,7 +3,9 @@ use quote::quote; use syn::{parse_macro_input, DeriveInput}; fn is_reify_copy(field: &syn::Field) -> bool { - field.attrs.iter() + field + .attrs + .iter() .filter_map(|attr| match &attr.meta { syn::Meta::Path(p) => Some(&p.segments), _ => None, @@ -29,13 +31,11 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { }; TokenStream::from(match input.data { syn::Data::Struct(syn::DataStruct { - fields: syn::Fields::Named(syn::FieldsNamed { - named, - .. - }), + fields: syn::Fields::Named(syn::FieldsNamed { named, .. }), .. }) => { - let compute_fields = named.iter() + let compute_fields = named + .iter() .filter_map(|field| { let name = field.ident.as_ref().unwrap(); if is_reify_copy(field) { @@ -47,7 +47,8 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { } }) .collect::>(); - let fields = named.iter() + let fields = named + .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); if is_reify_copy(field) { @@ -70,25 +71,21 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { #slice_impl } } - syn::Data::Enum(syn::DataEnum { - variants, - .. - }) => { - let variants = variants.iter() + syn::Data::Enum(syn::DataEnum { variants, .. }) => { + let variants = variants + .iter() .map(|variant| { let variant_name = &variant.ident; match &variant.fields { - syn::Fields::Unnamed(syn::FieldsUnnamed { - unnamed, - .. - }) => { + syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => { let vbinds = (0..unnamed.len()) .map(|idx| quote::format_ident!("v{idx}")) .collect::>(); let obinds = (0..unnamed.len()) .map(|idx| quote::format_ident!("opt{idx}")) .collect::>(); - let compute_fields = unnamed.iter() + let compute_fields = unnamed + .iter() .enumerate() .filter_map(|(idx, field)| { if is_reify_copy(field) { @@ -102,7 +99,8 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { } }) .collect::>(); - let fields = unnamed.iter() + let fields = unnamed + .iter() .enumerate() .map(|(idx, field)| { let vbind = &vbinds[idx]; @@ -120,7 +118,7 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { vcx.alloc(#name::#variant_name(#(#fields),*)) } } - }, + } syn::Fields::Unit => quote! { #name::#variant_name => &#name::#variant_name }, diff --git a/vir/src/context.rs b/vir/src/context.rs index be684a11d76..ef0778c30eb 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -1,11 +1,8 @@ -use std::cell::RefCell; use prusti_interface::environment::EnvBody; use prusti_rustc_interface::middle::ty; +use std::cell::RefCell; -use crate::data::*; -use crate::gendata::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{data::*, gendata::*, genrefs::*, refs::*}; /// The VIR context is a data structure used throughout the encoding process. pub struct VirCtxt<'tcx> { @@ -20,13 +17,11 @@ pub struct VirCtxt<'tcx> { pub span_stack: Vec, // TODO: span stack // TODO: error positions? - /// The compiler's typing context. This allows convenient access to most /// of the compiler's APIs. pub tcx: ty::TyCtxt<'tcx>, pub body: RefCell>, - } impl<'tcx> VirCtxt<'tcx> { @@ -48,25 +43,23 @@ impl<'tcx> VirCtxt<'tcx> { &*self.arena.alloc_str(val) } -/* pub fn alloc_slice<'a, T: Copy>(&'tcx self, val: &'a [T]) -> &'tcx [T] { - &*self.arena.alloc_slice_copy(val) - }*/ + /* pub fn alloc_slice<'a, T: Copy>(&'tcx self, val: &'a [T]) -> &'tcx [T] { + &*self.arena.alloc_slice_copy(val) + }*/ pub fn alloc_slice(&self, val: &[T]) -> &[T] { &*self.arena.alloc_slice_copy(val) } pub fn mk_local<'vir>(&'vir self, name: &'vir str) -> Local<'vir> { - self.arena.alloc(LocalData { - name, - }) + self.arena.alloc(LocalData { name }) } pub fn mk_local_decl(&'tcx self, name: &'tcx str, ty: Type<'tcx>) -> LocalDecl<'tcx> { - self.arena.alloc(LocalDeclData { - name, - ty, - }) + self.arena.alloc(LocalDeclData { name, ty }) } - pub fn mk_local_ex_local(&'tcx self, local: Local<'tcx>) -> ExprGen<'tcx, Curr, Next> { + pub fn mk_local_ex_local( + &'tcx self, + local: Local<'tcx>, + ) -> ExprGen<'tcx, Curr, Next> { self.arena.alloc(ExprGenData::Local(local)) } pub fn mk_local_ex(&'tcx self, name: &'tcx str) -> ExprGen<'tcx, Curr, Next> { @@ -77,16 +70,18 @@ impl<'tcx> VirCtxt<'tcx> { target: &'tcx str, src_args: &[ExprGen<'tcx, Curr, Next>], ) -> ExprGen<'tcx, Curr, Next> { - self.arena.alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { - target, - args: self.alloc_slice(src_args), - }))) + self.arena + .alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_pred_app(&'tcx self, target: &'tcx str, src_args: &[Expr<'tcx>]) -> Expr<'tcx> { - self.arena.alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { - target, - args: self.alloc_slice(src_args), - }))) + self.arena + .alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_true(&'tcx self) -> Expr<'tcx> { diff --git a/vir/src/data.rs b/vir/src/data.rs index b4a6f21cf9d..d9b42b131ef 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; -use prusti_rustc_interface::middle::mir; use crate::refs::*; +use prusti_rustc_interface::middle::mir; pub struct LocalData<'vir> { pub name: &'vir str, // TODO: identifiers @@ -41,6 +41,7 @@ pub enum BinOpKind { CmpLe, And, Add, + Sub, // ... } impl From for BinOpKind { diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 7ba9cf57145..8475f38fce6 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -1,13 +1,14 @@ use std::fmt::{Debug, Display, Formatter, Result as FmtResult}; -use crate::data::*; -use crate::gendata::*; +use crate::{data::*, gendata::*}; fn fmt_comma_sep_display(f: &mut Formatter<'_>, els: &[T]) -> FmtResult { els.iter() .enumerate() .map(|(idx, el)| { - if idx > 0 { write!(f, ", ")? } + if idx > 0 { + write!(f, ", ")? + } el.fmt(f) }) .collect::() @@ -16,7 +17,9 @@ fn fmt_comma_sep(f: &mut Formatter<'_>, els: &[T]) -> FmtResult { els.iter() .enumerate() .map(|(idx, el)| { - if idx > 0 { write!(f, ", ")? } + if idx > 0 { + write!(f, ", ")? + } el.fmt(f) }) .collect::() @@ -32,10 +35,7 @@ fn fmt_comma_sep_lines(f: &mut Formatter<'_>, els: &[T]) -> FmtResult Ok(()) } fn indent(s: String) -> String { - s - .split("\n") - .intersperse("\n ") - .collect::() + s.split("\n").intersperse("\n ").collect::() } impl<'vir, Curr, Next> Debug for AccFieldGenData<'vir, Curr, Next> { @@ -48,16 +48,21 @@ impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "(")?; self.lhs.fmt(f)?; - write!(f, ") {} (", match self.kind { - BinOpKind::CmpEq => "==", - BinOpKind::CmpNe => "!=", - BinOpKind::CmpGt => ">", - BinOpKind::CmpGe => ">=", - BinOpKind::CmpLt => "<", - BinOpKind::CmpLe => "<=", - BinOpKind::And => "&&", - BinOpKind::Add => "+", - })?; + write!( + f, + ") {} (", + match self.kind { + BinOpKind::CmpEq => "==", + BinOpKind::CmpNe => "!=", + BinOpKind::CmpGt => ">", + BinOpKind::CmpGe => ">=", + BinOpKind::CmpLt => "<", + BinOpKind::CmpLe => "<=", + BinOpKind::And => "&&", + BinOpKind::Add => "+", + BinOpKind::Sub => "-", + } + )?; self.rhs.fmt(f)?; write!(f, ")") } @@ -92,8 +97,14 @@ impl<'vir, Curr, Next> Debug for DomainGenData<'vir, Curr, Next> { write!(f, "]")?; } writeln!(f, " {{")?; - self.axioms.iter().map(|el| el.fmt(f)).collect::()?; - self.functions.iter().map(|el| el.fmt(f)).collect::()?; + self.axioms + .iter() + .map(|el| el.fmt(f)) + .collect::()?; + self.functions + .iter() + .map(|el| el.fmt(f)) + .collect::()?; writeln!(f, "}}") } } @@ -173,8 +184,14 @@ impl<'vir, Curr, Next> Debug for FunctionGenData<'vir, Curr, Next> { writeln!(f, "function {}(", self.name)?; fmt_comma_sep_lines(f, &self.args)?; writeln!(f, "): {:?}", self.ret)?; - self.pres.iter().map(|el| writeln!(f, " requires {:?}", el)).collect::()?; - self.posts.iter().map(|el| writeln!(f, " ensures {:?}", el)).collect::()?; + self.pres + .iter() + .map(|el| writeln!(f, " requires {:?}", el)) + .collect::()?; + self.posts + .iter() + .map(|el| writeln!(f, " ensures {:?}", el)) + .collect::()?; if let Some(expr) = self.expr { write!(f, "{{\n ")?; expr.fmt(f)?; @@ -221,8 +238,14 @@ impl<'vir, Curr, Next> Debug for MethodGenData<'vir, Curr, Next> { } else { writeln!(f, ")")?; } - self.pres.iter().map(|el| writeln!(f, " requires {:?}", el)).collect::()?; - self.posts.iter().map(|el| writeln!(f, " ensures {:?}", el)).collect::()?; + self.pres + .iter() + .map(|el| writeln!(f, " requires {:?}", el)) + .collect::()?; + self.posts + .iter() + .map(|el| writeln!(f, " ensures {:?}", el)) + .collect::()?; if let Some(blocks) = self.blocks.as_ref() { writeln!(f, "{{")?; for block in blocks.iter() { @@ -301,7 +324,11 @@ impl<'vir, Curr, Next> Debug for TerminatorStmtGenData<'vir, Curr, Next> { write!(f, "goto {:?}", data.otherwise) } else { for target in data.targets { - write!(f, "if ({:?} == {:?}) {{ goto {:?} }}\n else", data.value, target.0, target.1)?; + write!( + f, + "if ({:?} == {:?}) {{ goto {:?} }}\n else", + data.value, target.0, target.1 + )?; } write!(f, " {{ goto {:?} }}", data.otherwise) } @@ -342,10 +369,15 @@ impl<'vir> Debug for TypeData<'vir> { impl<'vir, Curr, Next> Debug for UnOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "{}({:?})", match self.kind { - UnOpKind::Neg => "-", - UnOpKind::Not => "!", - }, self.expr) + write!( + f, + "{}({:?})", + match self.kind { + UnOpKind::Neg => "-", + UnOpKind::Not => "!", + }, + self.expr + ) } } diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index 2893c45992f..e66fe87b499 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -1,20 +1,20 @@ use std::fmt::Debug; -use crate::data::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{data::*, genrefs::*, refs::*}; use vir_proc_macro::*; #[derive(Reify)] pub struct UnOpGenData<'vir, Curr, Next> { - #[reify_copy] pub kind: UnOpKind, + #[reify_copy] + pub kind: UnOpKind, pub expr: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct BinOpGenData<'vir, Curr, Next> { - #[reify_copy] pub kind: BinOpKind, + #[reify_copy] + pub kind: BinOpKind, pub lhs: ExprGen<'vir, Curr, Next>, pub rhs: ExprGen<'vir, Curr, Next>, } @@ -28,20 +28,23 @@ pub struct TernaryGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct ForallGenData<'vir, Curr, Next> { - #[reify_copy] pub qvars: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub qvars: &'vir [LocalDecl<'vir>], pub triggers: &'vir [&'vir [ExprGen<'vir, Curr, Next>]], pub body: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct FuncAppGenData<'vir, Curr, Next> { - #[reify_copy] pub target: &'vir str, // TODO: identifiers + #[reify_copy] + pub target: &'vir str, // TODO: identifiers pub args: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(Reify)] pub struct PredicateAppGenData<'vir, Curr, Next> { - #[reify_copy] pub target: &'vir str, // TODO: identifiers + #[reify_copy] + pub target: &'vir str, // TODO: identifiers pub args: &'vir [ExprGen<'vir, Curr, Next>], } @@ -54,13 +57,15 @@ pub struct UnfoldingGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct AccFieldGenData<'vir, Curr, Next> { pub recv: ExprGen<'vir, Curr, Next>, - #[reify_copy] pub field: &'vir str, // TODO: identifiers - // TODO: permission amount + #[reify_copy] + pub field: &'vir str, // TODO: identifiers + // TODO: permission amount } #[derive(Reify)] pub struct LetGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, + #[reify_copy] + pub name: &'vir str, pub val: ExprGen<'vir, Curr, Next>, pub expr: ExprGen<'vir, Curr, Next>, } @@ -103,8 +108,10 @@ pub enum ExprGenData<'vir, Curr: 'vir, Next: 'vir> { PredicateApp(PredicateAppGen<'vir, Curr, Next>), // TODO: this should not be used instead of acc? // domain func app // inhale/exhale - - Lazy(&'vir str, Box, Curr) -> Next + 'vir>), + Lazy( + &'vir str, + Box, Curr) -> Next + 'vir>, + ), Todo(&'vir str), } @@ -120,30 +127,39 @@ impl<'vir, Curr, Next> ExprGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct DomainAxiomGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // ? or comment, then auto-gen the names? + #[reify_copy] + pub name: &'vir str, // ? or comment, then auto-gen the names? pub expr: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct DomainGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub typarams: &'vir [&'vir str], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub typarams: &'vir [&'vir str], pub axioms: &'vir [DomainAxiomGen<'vir, Curr, Next>], - #[reify_copy] pub functions: &'vir [DomainFunction<'vir>], + #[reify_copy] + pub functions: &'vir [DomainFunction<'vir>], } #[derive(Reify)] pub struct PredicateGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], pub expr: Option>, } #[derive(Reify)] pub struct FunctionGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], - #[reify_copy] pub ret: Type<'vir>, + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub ret: Type<'vir>, pub pres: &'vir [ExprGen<'vir, Curr, Next>], pub posts: &'vir [ExprGen<'vir, Curr, Next>], pub expr: Option>, @@ -160,14 +176,19 @@ pub struct PureAssignGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct MethodCallGenData<'vir, Curr, Next> { - #[reify_copy] pub targets: &'vir [Local<'vir>], - #[reify_copy] pub method: &'vir str, + #[reify_copy] + pub targets: &'vir [Local<'vir>], + #[reify_copy] + pub method: &'vir str, pub args: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(Reify)] pub enum StmtGenData<'vir, Curr, Next> { - LocalDecl(#[reify_copy] LocalDecl<'vir>, Option>), + LocalDecl( + #[reify_copy] LocalDecl<'vir>, + Option>, + ), PureAssign(PureAssignGen<'vir, Curr, Next>), Inhale(ExprGen<'vir, Curr, Next>), Exhale(ExprGen<'vir, Curr, Next>), @@ -182,7 +203,8 @@ pub enum StmtGenData<'vir, Curr, Next> { pub struct GotoIfGenData<'vir, Curr, Next> { pub value: ExprGen<'vir, Curr, Next>, pub targets: &'vir [(ExprGen<'vir, Curr, Next>, CfgBlockLabel<'vir>)], - #[reify_copy] pub otherwise: CfgBlockLabel<'vir>, + #[reify_copy] + pub otherwise: CfgBlockLabel<'vir>, } #[derive(Reify)] @@ -196,16 +218,20 @@ pub enum TerminatorStmtGenData<'vir, Curr, Next> { #[derive(Debug, Reify)] pub struct CfgBlockGenData<'vir, Curr, Next> { - #[reify_copy] pub label: CfgBlockLabel<'vir>, + #[reify_copy] + pub label: CfgBlockLabel<'vir>, pub stmts: &'vir [StmtGen<'vir, Curr, Next>], pub terminator: TerminatorStmtGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct MethodGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], - #[reify_copy] pub rets: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub rets: &'vir [LocalDecl<'vir>], // TODO: pre/post - add a comment variant pub pres: &'vir [ExprGen<'vir, Curr, Next>], pub posts: &'vir [ExprGen<'vir, Curr, Next>], @@ -214,7 +240,8 @@ pub struct MethodGenData<'vir, Curr, Next> { #[derive(Debug, Reify)] pub struct ProgramGenData<'vir, Curr, Next> { - #[reify_copy] pub fields: &'vir [Field<'vir>], + #[reify_copy] + pub fields: &'vir [Field<'vir>], pub domains: &'vir [DomainGen<'vir, Curr, Next>], pub predicates: &'vir [PredicateGen<'vir, Curr, Next>], pub functions: &'vir [FunctionGen<'vir, Curr, Next>], diff --git a/vir/src/genrefs.rs b/vir/src/genrefs.rs index f21ca3c8ef8..f98eb0bef1a 100644 --- a/vir/src/genrefs.rs +++ b/vir/src/genrefs.rs @@ -1,7 +1,8 @@ pub type AccFieldGen<'vir, Curr, Next> = &'vir crate::gendata::AccFieldGenData<'vir, Curr, Next>; pub type BinOpGen<'vir, Curr, Next> = &'vir crate::gendata::BinOpGenData<'vir, Curr, Next>; pub type CfgBlockGen<'vir, Curr, Next> = &'vir crate::gendata::CfgBlockGenData<'vir, Curr, Next>; -pub type DomainAxiomGen<'vir, Curr, Next> = &'vir crate::gendata::DomainAxiomGenData<'vir, Curr, Next>; +pub type DomainAxiomGen<'vir, Curr, Next> = + &'vir crate::gendata::DomainAxiomGenData<'vir, Curr, Next>; pub type DomainGen<'vir, Curr, Next> = &'vir crate::gendata::DomainGenData<'vir, Curr, Next>; pub type ExprGen<'vir, Curr, Next> = &'vir crate::gendata::ExprGenData<'vir, Curr, Next>; pub type ForallGen<'vir, Curr, Next> = &'vir crate::gendata::ForallGenData<'vir, Curr, Next>; @@ -10,13 +11,17 @@ pub type FunctionGen<'vir, Curr, Next> = &'vir crate::gendata::FunctionGenData<' pub type GotoIfGen<'vir, Curr, Next> = &'vir crate::gendata::GotoIfGenData<'vir, Curr, Next>; pub type LetGen<'vir, Curr, Next> = &'vir crate::gendata::LetGenData<'vir, Curr, Next>; pub type MethodGen<'vir, Curr, Next> = &'vir crate::gendata::MethodGenData<'vir, Curr, Next>; -pub type MethodCallGen<'vir, Curr, Next> = &'vir crate::gendata::MethodCallGenData<'vir, Curr, Next>; +pub type MethodCallGen<'vir, Curr, Next> = + &'vir crate::gendata::MethodCallGenData<'vir, Curr, Next>; pub type PredicateGen<'vir, Curr, Next> = &'vir crate::gendata::PredicateGenData<'vir, Curr, Next>; -pub type PredicateAppGen<'vir, Curr, Next> = &'vir crate::gendata::PredicateAppGenData<'vir, Curr, Next>; +pub type PredicateAppGen<'vir, Curr, Next> = + &'vir crate::gendata::PredicateAppGenData<'vir, Curr, Next>; pub type ProgramGen<'vir, Curr, Next> = &'vir crate::gendata::ProgramGenData<'vir, Curr, Next>; -pub type PureAssignGen<'vir, Curr, Next> = &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; +pub type PureAssignGen<'vir, Curr, Next> = + &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; pub type StmtGen<'vir, Curr, Next> = &'vir crate::gendata::StmtGenData<'vir, Curr, Next>; -pub type TerminatorStmtGen<'vir, Curr, Next> = &'vir crate::gendata::TerminatorStmtGenData<'vir, Curr, Next>; +pub type TerminatorStmtGen<'vir, Curr, Next> = + &'vir crate::gendata::TerminatorStmtGenData<'vir, Curr, Next>; pub type TernaryGen<'vir, Curr, Next> = &'vir crate::gendata::TernaryGenData<'vir, Curr, Next>; pub type UnOpGen<'vir, Curr, Next> = &'vir crate::gendata::UnOpGenData<'vir, Curr, Next>; pub type UnfoldingGen<'vir, Curr, Next> = &'vir crate::gendata::UnfoldingGenData<'vir, Curr, Next>; diff --git a/vir/src/macros.rs b/vir/src/macros.rs index d9210baf377..80476b45dbe 100644 --- a/vir/src/macros.rs +++ b/vir/src/macros.rs @@ -1,6 +1,6 @@ //#[macro_export] //macro_rules! vir_expr_nopos { -// +// //} //#[macro_export] @@ -112,8 +112,12 @@ macro_rules! vir_expr { #[macro_export] macro_rules! vir_ident { - ($vcx:expr; [ $name:expr ]) => { $name }; - ($vcx:expr; $name:ident ) => { $vcx.alloc_str(stringify!($name)) }; + ($vcx:expr; [ $name:expr ]) => { + $name + }; + ($vcx:expr; $name:ident ) => { + $vcx.alloc_str(stringify!($name)) + }; } #[macro_export] @@ -123,10 +127,18 @@ macro_rules! vir_format { #[macro_export] macro_rules! vir_type { - ($vcx:expr; Int) => { $vcx.alloc($crate::TypeData::Int) }; - ($vcx:expr; Bool) => { $vcx.alloc($crate::TypeData::Bool) }; - ($vcx:expr; Ref) => { $vcx.alloc($crate::TypeData::Ref) }; - ($vcx:expr; [ $ty:expr ]) => { $ty }; + ($vcx:expr; Int) => { + $vcx.alloc($crate::TypeData::Int) + }; + ($vcx:expr; Bool) => { + $vcx.alloc($crate::TypeData::Bool) + }; + ($vcx:expr; Ref) => { + $vcx.alloc($crate::TypeData::Ref) + }; + ($vcx:expr; [ $ty:expr ]) => { + $ty + }; ($vcx:expr; $name:ident) => { $vcx.alloc($crate::TypeData::Domain($vcx.alloc_str(stringify!($name)))) }; diff --git a/vir/src/reify.rs b/vir/src/reify.rs index 6c1106a2731..6e40b4d69ae 100644 --- a/vir/src/reify.rs +++ b/vir/src/reify.rs @@ -1,18 +1,11 @@ -use crate::VirCtxt; -use crate::gendata::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{gendata::*, genrefs::*, refs::*, VirCtxt}; pub use vir_proc_macro::*; pub trait Reify<'vir, Curr, NextA, NextB> { type Next: Sized; - fn reify( - &self, - vcx: &'vir VirCtxt<'vir>, - lctx: Curr, - ) -> Self::Next; + fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next; } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> @@ -31,7 +24,9 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> ExprGenData::Forall(v) => vcx.alloc(ExprGenData::Forall(v.reify(vcx, lctx))), ExprGenData::Let(v) => vcx.alloc(ExprGenData::Let(v.reify(vcx, lctx))), ExprGenData::FuncApp(v) => vcx.alloc(ExprGenData::FuncApp(v.reify(vcx, lctx))), - ExprGenData::PredicateApp(v) => vcx.alloc(ExprGenData::PredicateApp(v.reify(vcx, lctx))), + ExprGenData::PredicateApp(v) => { + vcx.alloc(ExprGenData::PredicateApp(v.reify(vcx, lctx))) + } ExprGenData::Local(v) => vcx.alloc(ExprGenData::Local(v)), ExprGenData::Const(v) => vcx.alloc(ExprGenData::Const(v)), @@ -51,9 +46,12 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> { type Next = &'vir [ExprGen<'vir, NextA, NextB>]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|elem| elem.reify(vcx, lctx)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|elem| elem.reify(vcx, lctx)) + .collect::>(), + ) } } @@ -62,20 +60,29 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> { type Next = &'vir [&'vir [ExprGen<'vir, NextA, NextB>]]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|elem| elem.reify(vcx, lctx)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|elem| elem.reify(vcx, lctx)) + .collect::>(), + ) } } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> - for [(ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, CfgBlockLabel<'vir>)] + for [( + ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, + CfgBlockLabel<'vir>, + )] { type Next = &'vir [(ExprGen<'vir, NextA, NextB>, CfgBlockLabel<'vir>)]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) + .collect::>(), + ) } } @@ -97,7 +104,6 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> } } - /* impl< 'vir,