From ed6cec1941f6ea21ae45f0933a10a897b1acce49 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Wed, 15 Nov 2023 19:22:13 +0100 Subject: [PATCH] Run rustfmt --- prusti-encoder/src/encoders/generic.rs | 87 ++- prusti-encoder/src/encoders/local_def.rs | 72 +- prusti-encoder/src/encoders/mir_builtin.rs | 238 +++--- prusti-encoder/src/encoders/mir_impure.rs | 509 +++++++------ prusti-encoder/src/encoders/mir_pure.rs | 687 +++++++++++------- .../src/encoders/mir_pure_function.rs | 73 +- prusti-encoder/src/encoders/mod.rs | 28 +- prusti-encoder/src/encoders/pure/spec.rs | 133 ++-- prusti-encoder/src/encoders/spec.rs | 39 +- prusti-encoder/src/encoders/typ.rs | 632 +++++++++------- prusti-encoder/src/encoders/viper_tuple.rs | 135 ++-- prusti-encoder/src/lib.rs | 39 +- task-encoder/src/lib.rs | 398 +++++----- vir/src/callable_idents.rs | 34 +- vir/src/context.rs | 59 +- vir/src/data.rs | 7 +- vir/src/debug.rs | 95 ++- vir/src/gendata.rs | 89 ++- vir/src/genrefs.rs | 15 +- vir/src/lib.rs | 2 +- vir/src/macros.rs | 36 +- vir/src/reify.rs | 48 +- 22 files changed, 1977 insertions(+), 1478 deletions(-) diff --git a/prusti-encoder/src/encoders/generic.rs b/prusti-encoder/src/encoders/generic.rs index 26e2905bfd3..d44426337aa 100644 --- a/prusti-encoder/src/encoders/generic.rs +++ b/prusti-encoder/src/encoders/generic.rs @@ -1,8 +1,5 @@ -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; -use vir::{FunctionIdent, CallableIdent, NullaryArity}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{CallableIdent, FunctionIdent, NullaryArity}; pub struct GenericEncoder; @@ -40,7 +37,8 @@ impl TaskEncoder for GenericEncoder { type EncodingError = GenericEncoderError; fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, GenericEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, GenericEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -59,18 +57,24 @@ impl TaskEncoder for GenericEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(*task_key, GenericEncoderOutputRef { - snapshot_param_name: "s_Param", - predicate_param_name: "p_Param", - domain_type_name: "s_Type", - }); + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { + deps.emit_output_ref::( + *task_key, + GenericEncoderOutputRef { + snapshot_param_name: "s_Param", + predicate_param_name: "p_Param", + domain_type_name: "s_Type", + }, + ); let s_Type_Bool = FunctionIdent::new("s_Type_Bool", NullaryArity::new([])); let s_Type_Int_isize = FunctionIdent::new("s_Type_Int_isize", NullaryArity::new([])); let s_Type_Int_i8 = FunctionIdent::new("s_Type_Int_i8", NullaryArity::new([])); @@ -84,26 +88,31 @@ impl TaskEncoder for GenericEncoder { let s_Type_Uint_u32 = FunctionIdent::new("s_Type_Uint_u32", NullaryArity::new([])); let s_Type_Uint_u64 = FunctionIdent::new("s_Type_Uint_u64", NullaryArity::new([])); let s_Type_Uint_u128 = FunctionIdent::new("s_Type_Uint_u128", NullaryArity::new([])); - vir::with_vcx(|vcx| Ok((GenericEncoderOutput { - snapshot_param: vir::vir_domain! { vcx; domain s_Param {} }, - predicate_param: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref/*, self_s: s_Param*/) }, - domain_type: vir::vir_domain! { vcx; domain s_Type { - // TODO: only emit these when the types are actually used? - // emit instead from type encoder, to be consistent with the ADT case? - unique function s_Type_Bool(): s_Type; - unique function s_Type_Int_isize(): s_Type; - unique function s_Type_Int_i8(): s_Type; - unique function s_Type_Int_i16(): s_Type; - unique function s_Type_Int_i32(): s_Type; - unique function s_Type_Int_i64(): s_Type; - unique function s_Type_Int_i128(): s_Type; - unique function s_Type_Uint_usize(): s_Type; - unique function s_Type_Uint_u8(): s_Type; - unique function s_Type_Uint_u16(): s_Type; - unique function s_Type_Uint_u32(): s_Type; - unique function s_Type_Uint_u64(): s_Type; - unique function s_Type_Uint_u128(): s_Type; - } }, - }, ()))) + vir::with_vcx(|vcx| { + Ok(( + GenericEncoderOutput { + snapshot_param: vir::vir_domain! { vcx; domain s_Param {} }, + predicate_param: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref/*, self_s: s_Param*/) }, + domain_type: vir::vir_domain! { vcx; domain s_Type { + // TODO: only emit these when the types are actually used? + // emit instead from type encoder, to be consistent with the ADT case? + unique function s_Type_Bool(): s_Type; + unique function s_Type_Int_isize(): s_Type; + unique function s_Type_Int_i8(): s_Type; + unique function s_Type_Int_i16(): s_Type; + unique function s_Type_Int_i32(): s_Type; + unique function s_Type_Int_i64(): s_Type; + unique function s_Type_Int_i128(): s_Type; + unique function s_Type_Uint_usize(): s_Type; + unique function s_Type_Uint_u8(): s_Type; + unique function s_Type_Uint_u16(): s_Type; + unique function s_Type_Uint_u32(): s_Type; + unique function s_Type_Uint_u64(): s_Type; + unique function s_Type_Uint_u128(): s_Type; + } }, + }, + (), + )) + }) } } diff --git a/prusti-encoder/src/encoders/local_def.rs b/prusti-encoder/src/encoders/local_def.rs index 7f2e09f9119..e76d982dd39 100644 --- a/prusti-encoder/src/encoders/local_def.rs +++ b/prusti-encoder/src/encoders/local_def.rs @@ -1,11 +1,11 @@ use prusti_rustc_interface::{ index::IndexVec, middle::{mir, ty}, - span::def_id::DefId + span::def_id::DefId, }; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use crate::encoders::TypeEncoderOutputRef; @@ -32,9 +32,9 @@ thread_local! { impl TaskEncoder for MirLocalDefEncoder { type TaskDescription<'tcx> = ( - DefId, // ID of the function + DefId, // ID of the function ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs - Option, // ID of the caller function, if any + Option, // ID of the caller function, if any ); type OutputFullLocal<'vir> = MirLocalDefEncoderOutput<'vir>; @@ -73,12 +73,16 @@ impl TaskEncoder for MirLocalDefEncoder { > { let (def_id, substs, caller_def_id) = *task_key; deps.emit_output_ref::(*task_key, ()); - fn mk_local_def<'vir, 'tcx>(vcx: &'vir vir::VirCtxt<'tcx>, name: &'vir str, ty: TypeEncoderOutputRef<'vir>) -> LocalDef<'vir> { + fn mk_local_def<'vir, 'tcx>( + vcx: &'vir vir::VirCtxt<'tcx>, + name: &'vir str, + ty: TypeEncoderOutputRef<'vir>, + ) -> LocalDef<'vir> { let local = vcx.mk_local(name); let local_ex = vcx.mk_local_ex_local(local); let impure_snap = ty.ref_to_snap.apply(vcx, [local_ex]); let impure_pred = vcx.alloc(vir::ExprData::PredicateApp( - ty.ref_to_pred.apply(vcx, [local_ex]) + ty.ref_to_pred.apply(vcx, [local_ex]), )); LocalDef { local, @@ -91,36 +95,48 @@ impl TaskEncoder for MirLocalDefEncoder { vir::with_vcx(|vcx| { let data = if let Some(local_def_id) = def_id.as_local() { - let body = vcx.body.borrow_mut().get_impure_fn_body(local_def_id, substs, caller_def_id); - let locals = IndexVec::from_fn_n(|arg: mir::Local| { - let local = vir::vir_format!(vcx, "_{}p", arg.index()); - let ty = deps.require_ref::( - body.local_decls[arg].ty, - ).unwrap(); - mk_local_def(vcx, local, ty) - }, body.local_decls.len()); + let body = + vcx.body + .borrow_mut() + .get_impure_fn_body(local_def_id, substs, caller_def_id); + let locals = IndexVec::from_fn_n( + |arg: mir::Local| { + let local = vir::vir_format!(vcx, "_{}p", arg.index()); + let ty = deps + .require_ref::(body.local_decls[arg].ty) + .unwrap(); + mk_local_def(vcx, local, ty) + }, + body.local_decls.len(), + ); MirLocalDefEncoderOutput { locals: vcx.alloc(locals), arg_count: body.arg_count, } } else { let param_env = vcx.tcx.param_env(caller_def_id.unwrap_or(def_id)); - let sig = vcx.tcx - .subst_and_normalize_erasing_regions(substs, param_env, vcx.tcx.fn_sig(def_id)); + let sig = vcx.tcx.subst_and_normalize_erasing_regions( + substs, + param_env, + vcx.tcx.fn_sig(def_id), + ); let sig = sig.skip_binder(); - let locals = IndexVec::from_fn_n(|arg: mir::Local| { - let local = vir::vir_format!(vcx, "_{}p", arg.index()); - let ty = if arg.index() == 0 { - sig.output() - } else { - sig.inputs()[arg.index() - 1] - }; - let ty = deps.require_ref::( - ty, - ).unwrap(); - mk_local_def(vcx, local, ty) - }, sig.inputs_and_output.len()); + let locals = IndexVec::from_fn_n( + |arg: mir::Local| { + let local = vir::vir_format!(vcx, "_{}p", arg.index()); + let ty = if arg.index() == 0 { + sig.output() + } else { + sig.inputs()[arg.index() - 1] + }; + let ty = deps + .require_ref::(ty) + .unwrap(); + mk_local_def(vcx, local, ty) + }, + sig.inputs_and_output.len(), + ); MirLocalDefEncoderOutput { locals: vcx.alloc(locals), diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index 903b32b9421..490eaf5dc5b 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -1,12 +1,6 @@ -use prusti_rustc_interface::{ - middle::ty, - middle::mir, -}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; -use vir::{UnknownArity, FunctionIdent, CallableIdent}; +use prusti_rustc_interface::middle::{mir, ty}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{CallableIdent, FunctionIdent, UnknownArity}; pub struct MirBuiltinEncoder; @@ -48,7 +42,8 @@ impl TaskEncoder for MirBuiltinEncoder { type EncodingError = MirBuiltinEncoderError; fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, MirBuiltinEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, MirBuiltinEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -66,28 +61,30 @@ impl TaskEncoder for MirBuiltinEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - vir::with_vcx(|vcx| { - match *task_key { - MirBuiltinEncoderTask::UnOp(res_ty, op, operand_ty) => { - assert_eq!(res_ty, operand_ty); - let function = Self::handle_un_op(vcx, deps, *task_key, op, operand_ty); - Ok((MirBuiltinEncoderOutput { function }, ())) - } - MirBuiltinEncoderTask::BinOp(res_ty, op, l_ty, r_ty) => { - let function = Self::handle_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty); - Ok((MirBuiltinEncoderOutput { function }, ())) - } - MirBuiltinEncoderTask::CheckedBinOp(res_ty, op, l_ty, r_ty) => { - let function = Self::handle_checked_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty); - Ok((MirBuiltinEncoderOutput { function }, ())) - } + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { + vir::with_vcx(|vcx| match *task_key { + MirBuiltinEncoderTask::UnOp(res_ty, op, operand_ty) => { + assert_eq!(res_ty, operand_ty); + let function = Self::handle_un_op(vcx, deps, *task_key, op, operand_ty); + Ok((MirBuiltinEncoderOutput { function }, ())) + } + MirBuiltinEncoderTask::BinOp(res_ty, op, l_ty, r_ty) => { + let function = Self::handle_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty); + Ok((MirBuiltinEncoderOutput { function }, ())) + } + MirBuiltinEncoderTask::CheckedBinOp(res_ty, op, l_ty, r_ty) => { + let function = + Self::handle_checked_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty); + Ok((MirBuiltinEncoderOutput { function }, ())) } }) } @@ -109,29 +106,28 @@ impl MirBuiltinEncoder { deps: &mut TaskEncoderDependencies<'vir>, key: ::TaskKey<'tcx>, op: mir::UnOp, - ty: ty::Ty<'tcx> + ty: ty::Ty<'tcx>, ) -> vir::Function<'vir> { - let e_ty = deps.require_ref::( - ty, - ).unwrap(); + let e_ty = deps + .require_ref::(ty) + .unwrap(); let name = vir::vir_format!(vcx, "mir_unop_{op:?}_{}", int_name(ty)); let arity = UnknownArity::new(vcx.alloc_slice(&[e_ty.snapshot])); let function = FunctionIdent::new(name, arity); - deps.emit_output_ref::(key, MirBuiltinEncoderOutputRef { - function, - }); + deps.emit_output_ref::(key, MirBuiltinEncoderOutputRef { function }); let e_res_ty = &e_ty; let prim_res_ty = e_res_ty.expect_prim(); let snap_arg = vcx.mk_local_ex("arg"); let prim_arg = e_ty.expect_prim().snap_to_prim.apply(vcx, [snap_arg]); // `prim_to_snap(-snap_to_prim(arg))` - let mut val = prim_res_ty.prim_to_snap.apply(vcx, + let mut val = prim_res_ty.prim_to_snap.apply( + vcx, [vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { kind: vir::UnOpKind::from(op), expr: prim_arg, - })))] + })))], ); // Can overflow when doing `- iN::MIN -> iN::MIN`. There is no // `CheckedUnOp`, instead the compiler puts an `TerminatorKind::Assert` @@ -174,29 +170,34 @@ impl MirBuiltinEncoder { r_ty: ty::Ty<'tcx>, ) -> vir::Function<'vir> { use mir::BinOp::*; - let e_l_ty = deps.require_ref::( - l_ty, - ).unwrap(); - let e_r_ty = deps.require_ref::( - r_ty, - ).unwrap(); - let e_res_ty = deps.require_ref::( - res_ty, - ).unwrap(); + let e_l_ty = deps + .require_ref::(l_ty) + .unwrap(); + let e_r_ty = deps + .require_ref::(r_ty) + .unwrap(); + let e_res_ty = deps + .require_ref::(res_ty) + .unwrap(); let prim_res_ty = e_res_ty.expect_prim(); - let name = vir::vir_format!(vcx, "mir_binop_{op:?}_{}_{}", int_name(l_ty), int_name(r_ty)); + let name = vir::vir_format!( + vcx, + "mir_binop_{op:?}_{}_{}", + int_name(l_ty), + int_name(r_ty) + ); let arity = UnknownArity::new(vcx.alloc_slice(&[e_l_ty.snapshot, e_r_ty.snapshot])); let function = FunctionIdent::new(name, arity); - deps.emit_output_ref::(key, MirBuiltinEncoderOutputRef { - function, - }); - let lhs = e_l_ty.expect_prim().snap_to_prim.apply(vcx, - [vcx.mk_local_ex("arg1")], - ); - let mut rhs = e_r_ty.expect_prim().snap_to_prim.apply(vcx, - [vcx.mk_local_ex("arg2")], - ); + deps.emit_output_ref::(key, MirBuiltinEncoderOutputRef { function }); + let lhs = e_l_ty + .expect_prim() + .snap_to_prim + .apply(vcx, [vcx.mk_local_ex("arg1")]); + let mut rhs = e_r_ty + .expect_prim() + .snap_to_prim + .apply(vcx, [vcx.mk_local_ex("arg2")]); if matches!(op, Shl | Shr) { // RHS must be smaller than the bit width of the LHS, this is // implicit in the `Shl` and `Shr` operators. @@ -206,18 +207,21 @@ impl MirBuiltinEncoder { rhs: vcx.get_bit_width_int(e_l_ty.expect_prim().prim_type, l_ty.kind()), }))); } - let val = prim_res_ty.prim_to_snap.apply(vcx, + let val = prim_res_ty.prim_to_snap.apply( + vcx, [vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::from(op), lhs, rhs, - })))] + })))], ); let (pres, val) = match op { // Overflow well defined as wrapping (implicit) and for the shifts // the RHS will be masked to the bit width. - Add | Sub | Mul | Shl | Shr => - (Vec::new(), Self::get_wrapped_val(vcx, val, prim_res_ty.prim_type, res_ty)), + Add | Sub | Mul | Shl | Shr => ( + Vec::new(), + Self::get_wrapped_val(vcx, val, prim_res_ty.prim_type, res_ty), + ), // Undefined behavior to overflow (need precondition) AddUnchecked | SubUnchecked | MulUnchecked => { let min = vcx.get_min_int(prim_res_ty.prim_type, res_ty.kind()); @@ -253,7 +257,10 @@ impl MirBuiltinEncoder { lhs: rhs, rhs: max, }))); - (vec![lower_bound, upper_bound], Self::get_wrapped_val(vcx, val, prim_res_ty.prim_type, res_ty)) + ( + vec![lower_bound, upper_bound], + Self::get_wrapped_val(vcx, val, prim_res_ty.prim_type, res_ty), + ) } // Could divide by zero or overflow if divisor is `-1` Div | Rem => { @@ -289,8 +296,7 @@ impl MirBuiltinEncoder { (pres, val) } // Cannot overflow and no undefined behavior - BitXor | BitAnd | BitOr | Eq | Lt | Le | Ne | Ge | Gt | Offset => - (Vec::new(), val), + BitXor | BitAnd | BitOr | Eq | Lt | Le | Ne | Ge | Gt | Offset => (Vec::new(), val), }; vcx.alloc(vir::FunctionData { name, @@ -315,56 +321,74 @@ impl MirBuiltinEncoder { r_ty: ty::Ty<'tcx>, ) -> vir::Function<'vir> { // `op` can only be `Add`, `Sub` or `Mul` - assert!(matches!(op, mir::BinOp::Add | mir::BinOp::Sub | mir::BinOp::Mul)); - let e_l_ty = deps.require_ref::( - l_ty, - ).unwrap(); - let e_r_ty = deps.require_ref::( - r_ty, - ).unwrap(); + assert!(matches!( + op, + mir::BinOp::Add | mir::BinOp::Sub | mir::BinOp::Mul + )); + let e_l_ty = deps + .require_ref::(l_ty) + .unwrap(); + let e_r_ty = deps + .require_ref::(r_ty) + .unwrap(); - let name = vir::vir_format!(vcx, "mir_checkedbinop_{op:?}_{}_{}", int_name(l_ty), int_name(r_ty)); + let name = vir::vir_format!( + vcx, + "mir_checkedbinop_{op:?}_{}_{}", + int_name(l_ty), + int_name(r_ty) + ); let arity = UnknownArity::new(vcx.alloc_slice(&[e_l_ty.snapshot, e_r_ty.snapshot])); let function = FunctionIdent::new(name, arity); - deps.emit_output_ref::(key, MirBuiltinEncoderOutputRef { - function, - }); + deps.emit_output_ref::(key, MirBuiltinEncoderOutputRef { function }); - let e_res_ty = deps.require_ref::( - res_ty, - ).unwrap(); + let e_res_ty = deps + .require_ref::(res_ty) + .unwrap(); // The result of a checked add will always be `(T, bool)`, get the `T` // type let rvalue_pure_ty = res_ty.tuple_fields()[0]; let bool_ty = res_ty.tuple_fields()[1]; assert!(bool_ty.is_bool()); - let e_rvalue_pure_ty = deps.require_ref::( - rvalue_pure_ty, - ).unwrap(); - let bool_cons = deps.require_ref::( - bool_ty, - ).unwrap().expect_prim().prim_to_snap; + let e_rvalue_pure_ty = deps + .require_ref::(rvalue_pure_ty) + .unwrap(); + let bool_cons = deps + .require_ref::(bool_ty) + .unwrap() + .expect_prim() + .prim_to_snap; // Unbounded value - let val_exp = vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::from(op), - lhs: e_l_ty.expect_prim().snap_to_prim.apply(vcx, - [vcx.mk_local_ex("arg1")], - ), - rhs: e_r_ty.expect_prim().snap_to_prim.apply(vcx, - [vcx.mk_local_ex("arg2")], - ), - }))); + let val_exp = vcx.alloc(vir::ExprData::BinOp( + vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::from(op), + lhs: e_l_ty + .expect_prim() + .snap_to_prim + .apply(vcx, [vcx.mk_local_ex("arg1")]), + rhs: e_r_ty + .expect_prim() + .snap_to_prim + .apply(vcx, [vcx.mk_local_ex("arg2")]), + }), + )); let val_str = vir::vir_format!(vcx, "val"); let val = vcx.mk_local_ex(val_str); // Wrapped value - let wrapped_val_exp = Self::get_wrapped_val(vcx, val, e_rvalue_pure_ty.expect_prim().prim_type, rvalue_pure_ty); + let wrapped_val_exp = Self::get_wrapped_val( + vcx, + val, + e_rvalue_pure_ty.expect_prim().prim_type, + rvalue_pure_ty, + ); let wrapped_val_str = vir::vir_format!(vcx, "wrapped_val"); let wrapped_val = vcx.mk_local_ex(wrapped_val_str); - let wrapped_val_snap = e_rvalue_pure_ty.expect_prim().prim_to_snap.apply(vcx, - [wrapped_val], - ); + let wrapped_val_snap = e_rvalue_pure_ty + .expect_prim() + .prim_to_snap + .apply(vcx, [wrapped_val]); // Overflowed? let overflowed = vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpNe, @@ -373,9 +397,10 @@ impl MirBuiltinEncoder { }))); let overflowed_snap = bool_cons.apply(vcx, [overflowed]); // `tuple(prim_to_snap(wrapped_val), wrapped_val != val)` - let tuple = e_res_ty.expect_structlike().field_snaps_to_snap.apply(vcx, - &[wrapped_val_snap, overflowed_snap] - ); + let tuple = e_res_ty + .expect_structlike() + .field_snaps_to_snap + .apply(vcx, &[wrapped_val_snap, overflowed_snap]); // `let wrapped_val == (val ..) in $tuple` let inner_let = vcx.alloc(vir::ExprData::Let(vcx.alloc(vir::LetGenData { name: wrapped_val_str, @@ -403,7 +428,12 @@ impl MirBuiltinEncoder { /// Wrap the value in the range of the type, e.g. `uN` is wrapped in the /// range `uN::MIN..=uN::MAX` using modulo. For signed integers, the range /// is `iN::MIN..=iN::MAX` and the value is wrapped using two's complement. - fn get_wrapped_val<'vir, 'tcx>(vcx: &'vir vir::VirCtxt<'tcx>, mut exp: &'vir vir::ExprData<'vir>, ty: vir::Type, rust_ty: ty::Ty) -> &'vir vir::ExprData<'vir> { + fn get_wrapped_val<'vir, 'tcx>( + vcx: &'vir vir::VirCtxt<'tcx>, + mut exp: &'vir vir::ExprData<'vir>, + ty: vir::Type, + rust_ty: ty::Ty, + ) -> &'vir vir::ExprData<'vir> { let shift_amount = vcx.get_signed_shift_int(ty, rust_ty.kind()); if let Some(half) = shift_amount { exp = vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 5ef30d0da70..86c36c11c65 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -1,18 +1,16 @@ +use mir_state_analysis::{ + free_pcs::{CapabilityKind, FreePcsAnalysis, FreePcsBasicBlock, FreePcsLocation, RepackOp}, + utils::Place, +}; use prusti_rustc_interface::{ middle::{mir, ty}, span::def_id::DefId, }; -use mir_state_analysis::{ - free_pcs::{FreePcsAnalysis, FreePcsBasicBlock, FreePcsLocation, RepackOp, CapabilityKind}, utils::Place, -}; //use mir_ssa_analysis::{ // SsaAnalysis, //}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; -use vir::{MethodIdent, UnknownArity, CallableIdent}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{CallableIdent, MethodIdent, UnknownArity}; pub struct MirImpureEncoder; @@ -42,9 +40,9 @@ const ENCODE_REACH_BB: bool = false; impl TaskEncoder for MirImpureEncoder { // TODO: local def id (+ promoted, substs, etc) type TaskDescription<'tcx> = ( - DefId, // ID of the function + DefId, // ID of the function ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs - Option, // ID of the caller function, if any + Option, // ID of the caller function, if any ); type OutputRef<'vir> = MirImpureEncoderOutputRef<'vir>; @@ -53,7 +51,8 @@ impl TaskEncoder for MirImpureEncoder { type EncodingError = MirImpureEncoderError; fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, MirImpureEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, MirImpureEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -71,24 +70,28 @@ impl TaskEncoder for MirImpureEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { let (def_id, substs, caller_def_id) = *task_key; - let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| + let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| { def_spec.trusted.extract_inherit().unwrap_or_default() - ).unwrap_or_default(); + }) + .unwrap_or_default(); vir::with_vcx(|vcx| { use mir::visit::Visitor; - let local_defs = deps.require_local::( - *task_key, - ).unwrap(); + let local_defs = deps + .require_local::(*task_key) + .unwrap(); // Argument count for the Viper method: // - one (`Ref`) for the return place; @@ -103,20 +106,29 @@ impl TaskEncoder for MirImpureEncoder { let arg_count = local_defs.arg_count + 1; let extra: String = substs.iter().map(|s| format!("_{s}")).collect(); - let caller = caller_def_id.map(|id| format!("{}_{}", id.krate, id.index.index())).unwrap_or_default(); - let method_name = vir::vir_format!(vcx, "m_{}{extra}_CALLER_{caller}", vcx.tcx.item_name(def_id)); + let caller = caller_def_id + .map(|id| format!("{}_{}", id.krate, id.index.index())) + .unwrap_or_default(); + let method_name = vir::vir_format!( + vcx, + "m_{}{extra}_CALLER_{caller}", + vcx.tcx.item_name(def_id) + ); let args = vec![&vir::TypeData::Ref; arg_count]; let args = UnknownArity::new(vcx.alloc_slice(&args)); let method_ref = MethodIdent::new(method_name, args); - deps.emit_output_ref::(*task_key, MirImpureEncoderOutputRef { - method_ref, - }); + deps.emit_output_ref::(*task_key, MirImpureEncoderOutputRef { method_ref }); // Do not encode the method body if it is external, trusted or just // a call stub. - let local_def_id = def_id.as_local().filter(|_| !trusted && caller_def_id.is_none()); + let local_def_id = def_id + .as_local() + .filter(|_| !trusted && caller_def_id.is_none()); let blocks = if let Some(local_def_id) = local_def_id { - let body = vcx.body.borrow_mut().get_impure_fn_body(local_def_id, substs, caller_def_id); + let body = + vcx.body + .borrow_mut() + .get_impure_fn_body(local_def_id, substs, caller_def_id); // let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); let fpcs_analysis = mir_state_analysis::run_free_pcs(&body, vcx.tcx); @@ -143,14 +155,13 @@ impl TaskEncoder for MirImpureEncoder { ))) } if ENCODE_REACH_BB { - start_stmts.extend((0..block_count) - .map(|block| { - let name = vir::vir_format!(vcx, "_reach_bb{block}"); - vcx.alloc(vir::StmtData::LocalDecl( - vir::vir_local_decl! { vcx; [name] : Bool }, - Some(vcx.alloc(vir::ExprData::Todo("false"))), - )) - })); + start_stmts.extend((0..block_count).map(|block| { + let name = vir::vir_format!(vcx, "_reach_bb{block}"); + vcx.alloc(vir::StmtData::LocalDecl( + vir::vir_local_decl! { vcx; [name] : Bool }, + Some(vcx.alloc(vir::ExprData::Todo("false"))), + )) + })); } encoded_blocks.push(vcx.alloc(vir::CfgBlockData { label: vcx.alloc(vir::CfgBlockLabelData::Start), @@ -189,9 +200,14 @@ impl TaskEncoder for MirImpureEncoder { None }; - let spec = deps.require_local::( - (def_id, substs, caller_def_id, false) - ).unwrap(); + let spec = deps + .require_local::(( + def_id, + substs, + caller_def_id, + false, + )) + .unwrap(); let (spec_pres, spec_posts) = (spec.pres, spec.posts); let mut pres = Vec::with_capacity(arg_count - 1); @@ -209,22 +225,26 @@ impl TaskEncoder for MirImpureEncoder { posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred); posts.extend(spec_posts); - Ok((MirImpureEncoderOutput { - method: vcx.alloc(vir::MethodData { - name: method_name, - args: vcx.alloc_slice(&args), - rets: &[], - pres: vcx.alloc_slice(&pres), - posts: vcx.alloc_slice(&posts), - blocks, - }), - }, ())) + Ok(( + MirImpureEncoderOutput { + method: vcx.alloc(vir::MethodData { + name: method_name, + args: vcx.alloc_slice(&args), + rets: &[], + pres: vcx.alloc_slice(&pres), + posts: vcx.alloc_slice(&posts), + blocks, + }), + }, + (), + )) }) } } struct EncoderVisitor<'tcx, 'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { vcx: &'vir vir::VirCtxt<'tcx>, deps: &'enc mut TaskEncoderDependencies<'vir>, @@ -262,10 +282,13 @@ impl<'tcx, 'vir, 'enc> EncoderVisitor<'tcx, 'vir, 'enc> { match projection { mir::ProjectionElem::Field(f, ty) => { let ty_out_struct = ty_out.expect_structlike(); - let field_ty_out = self.deps.require_ref::( - ty, - ).unwrap(); - let field_ref = ty_out_struct.field_access[f.as_usize()].projection_p.apply(self.vcx, [base]); + let field_ty_out = self + .deps + .require_ref::(ty) + .unwrap(); + let field_ref = ty_out_struct.field_access[f.as_usize()] + .projection_p + .apply(self.vcx, [base]); (field_ref, field_ty_out) } _ => panic!("unsupported projection"), @@ -278,8 +301,11 @@ impl<'tcx, 'vir, 'enc> EncoderVisitor<'tcx, 'vir, 'enc> { ty_out: crate::encoders::TypeEncoderOutputRef<'vir>, projection: &'vir [mir::PlaceElem<'tcx>], ) -> (vir::Expr<'vir>, crate::encoders::TypeEncoderOutputRef<'vir>) { - projection.iter() - .fold((base, ty_out), |(base, ty_out), proj| self.project_one(base, ty_out, *proj)) + projection + .iter() + .fold((base, ty_out), |(base, ty_out), proj| { + self.project_one(base, ty_out, *proj) + }) } /* @@ -343,13 +369,17 @@ impl<'tcx, 'vir, 'enc> EncoderVisitor<'tcx, 'vir, 'enc> { let place_ty = (*place).ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); - let place_ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let place_ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let ref_p = self.encode_place(place); let predicate = place_ty_out.ref_to_pred.apply(self.vcx, [ref_p]); - if matches!(repack_op, mir_state_analysis::free_pcs::RepackOp::Expand(..)) { + if matches!( + repack_op, + mir_state_analysis::free_pcs::RepackOp::Expand(..) + ) { self.stmt(vir::StmtData::Unfold(predicate)); } else { self.stmt(vir::StmtData::Fold(predicate)); @@ -359,14 +389,17 @@ impl<'tcx, 'vir, 'enc> EncoderVisitor<'tcx, 'vir, 'enc> { let place_ty = (*place).ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); - let place_ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let place_ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let ref_p = self.encode_place(place); - self.stmt(vir::StmtData::Exhale(self.vcx.alloc(vir::ExprData::PredicateApp( - place_ty_out.ref_to_pred.apply(self.vcx, [ref_p]) - )))); + self.stmt(vir::StmtData::Exhale(self.vcx.alloc( + vir::ExprData::PredicateApp( + place_ty_out.ref_to_pred.apply(self.vcx, [ref_p]), + ), + ))); } unsupported_op => panic!("unsupported repack op: {unsupported_op:?}"), } @@ -374,61 +407,65 @@ impl<'tcx, 'vir, 'enc> EncoderVisitor<'tcx, 'vir, 'enc> { self.current_fpcs = Some(current_fpcs); } - fn encode_operand_snap( - &mut self, - operand: &mir::Operand<'tcx>, - ) -> vir::Expr<'vir> { + fn encode_operand_snap(&mut self, operand: &mir::Operand<'tcx>) -> vir::Expr<'vir> { match operand { &mir::Operand::Move(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let place_exp = self.encode_place(Place::from(source)); let snap_val = ty_out.ref_to_snap.apply(self.vcx, [place_exp]); let tmp_exp = self.new_tmp(ty_out.snapshot).1; - self.stmt(vir::StmtData::PureAssign(self.vcx.alloc(vir::PureAssignData { - lhs: tmp_exp, - rhs: snap_val, - }))); - self.stmt(vir::StmtData::Exhale(self.vcx.alloc(vir::ExprData::PredicateApp( - ty_out.ref_to_pred.apply(self.vcx, [place_exp]) - )))); + self.stmt(vir::StmtData::PureAssign(self.vcx.alloc( + vir::PureAssignData { + lhs: tmp_exp, + rhs: snap_val, + }, + ))); + self.stmt(vir::StmtData::Exhale(self.vcx.alloc( + vir::ExprData::PredicateApp(ty_out.ref_to_pred.apply(self.vcx, [place_exp])), + ))); tmp_exp } &mir::Operand::Copy(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); - ty_out.ref_to_snap.apply(self.vcx, [self.encode_place(Place::from(source))]) + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + ty_out + .ref_to_snap + .apply(self.vcx, [self.encode_place(Place::from(source))]) } mir::Operand::Constant(box constant) => self.encode_constant(constant), } } - fn encode_operand( - &mut self, - operand: &mir::Operand<'tcx>, - ) -> vir::Expr<'vir> { + fn encode_operand(&mut self, operand: &mir::Operand<'tcx>) -> vir::Expr<'vir> { let (snap_val, ty_out) = match operand { &mir::Operand::Move(source) => return self.encode_place(Place::from(source)), &mir::Operand::Copy(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); - let source = ty_out.ref_to_snap.apply(self.vcx, [self.encode_place(Place::from(source))]); + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + let source = ty_out + .ref_to_snap + .apply(self.vcx, [self.encode_place(Place::from(source))]); (source, ty_out) } mir::Operand::Constant(box constant) => { - let ty_out = self.deps.require_ref::( - constant.ty(), - ).unwrap(); + let ty_out = self + .deps + .require_ref::(constant.ty()) + .unwrap(); (self.encode_constant(constant), ty_out) } }; @@ -437,49 +474,50 @@ impl<'tcx, 'vir, 'enc> EncoderVisitor<'tcx, 'vir, 'enc> { tmp_exp } - fn encode_place( - &mut self, - place: Place<'tcx>, - ) -> vir::Expr<'vir> { + fn encode_place(&mut self, place: Place<'tcx>) -> vir::Expr<'vir> { //assert!(place.projection.is_empty()); //self.vcx.mk_local_ex(vir::vir_format!(self.vcx, "_{}p", place.local.index())) self.project( self.local_defs.locals[place.local].local_ex, self.local_defs.locals[place.local].ty.clone(), place.projection, - ).0 + ) + .0 } // TODO: this will not work for unevaluated constants (which needs const // MIR evaluation, more like pure fn body encoding) - fn encode_constant( - &self, - constant: &mir::Constant<'tcx>, - ) -> vir::Expr<'vir> { + fn encode_constant(&self, constant: &mir::Constant<'tcx>) -> vir::Expr<'vir> { match constant.literal { - mir::ConstantKind::Val(const_val, const_ty) => { - match const_ty.kind() { - ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), - ty::TyKind::Int(int_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Int_{}_cons({})", int_ty.name_str(), scalar_val.try_to_int(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Uint(uint_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Uint_{}_cons({})", uint_ty.name_str(), scalar_val.try_to_uint(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Bool => self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Bool_cons({})", const_val.try_to_bool().unwrap()), - )), - unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + mir::ConstantKind::Val(const_val, const_ty) => match const_ty.kind() { + ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(vir::ExprData::Todo( + vir::vir_format!(self.vcx, "s_Tuple0_cons()"), + )), + ty::TyKind::Int(int_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Int_{}_cons({})", + int_ty.name_str(), + scalar_val.try_to_int(scalar_val.size()).unwrap() + ))) } - } + ty::TyKind::Uint(uint_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Uint_{}_cons({})", + uint_ty.name_str(), + scalar_val.try_to_uint(scalar_val.size()).unwrap() + ))) + } + ty::TyKind::Bool => self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Bool_cons({})", + const_val.try_to_bool().unwrap() + ))), + unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + }, unsupported_literal => todo!("unsupported constant literal: {unsupported_literal:?}"), } } @@ -500,21 +538,23 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for EncoderVisitor<'tcx, 'vir, // fn visit_body(&mut self, body: &mir::Body<'tcx>) { // println!("visiting body!"); // } - fn visit_basic_block_data( - &mut self, - block: mir::BasicBlock, - data: &mir::BasicBlockData<'tcx>, - ) { + fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'tcx>) { self.current_fpcs = Some(self.fpcs_analysis.get_all_for_bb(block)); self.current_stmts = Some(Vec::with_capacity( data.statements.len(), // TODO: not exact? )); if ENCODE_REACH_BB { - self.stmt(vir::StmtData::PureAssign(self.vcx.alloc(vir::PureAssignData { - lhs: self.vcx.mk_local_ex(vir::vir_format!(self.vcx, "_reach_bb{}", block.as_usize())), - rhs: self.vcx.mk_bool::(), - }))); + self.stmt(vir::StmtData::PureAssign(self.vcx.alloc( + vir::PureAssignData { + lhs: self.vcx.mk_local_ex(vir::vir_format!( + self.vcx, + "_reach_bb{}", + block.as_usize() + )), + rhs: self.vcx.mk_bool::(), + }, + ))); } /* @@ -548,26 +588,25 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for EncoderVisitor<'tcx, 'vir, self.super_basic_block_data(block, data); let stmts = self.current_stmts.take().unwrap(); let terminator = self.current_terminator.take().unwrap(); - self.encoded_blocks.push(self.vcx.alloc(vir::CfgBlockData { - label: self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), - stmts: self.vcx.alloc_slice(&stmts), - terminator, - })); + self.encoded_blocks.push( + self.vcx.alloc(vir::CfgBlockData { + label: self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), + stmts: self.vcx.alloc_slice(&stmts), + terminator, + }), + ); } - fn visit_statement( - &mut self, - statement: &mir::Statement<'tcx>, - location: mir::Location, - ) { + fn visit_statement(&mut self, statement: &mir::Statement<'tcx>, location: mir::Location) { // TODO: these should not be ignored, but should havoc the local instead // This clears up the noise a bit, making sure StorageLive and other // kinds do not show up in the comments. let IGNORE_NOP_STMTS = true; if IGNORE_NOP_STMTS { match &statement.kind { - mir::StatementKind::StorageLive(..) - | mir::StatementKind::StorageDead(..) => { + mir::StatementKind::StorageLive(..) | mir::StatementKind::StorageDead(..) => { return; } _ => {} @@ -715,11 +754,7 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for EncoderVisitor<'tcx, 'vir, } } - fn visit_terminator( - &mut self, - terminator: &mir::Terminator<'tcx>, - location: mir::Location, - ) { + fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, location: mir::Location) { self.stmt(vir::StmtData::Comment( // TODO: also add bb and location for better debugging? vir::vir_format!(self.vcx, "{:?}", terminator.kind), @@ -730,34 +765,49 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for EncoderVisitor<'tcx, 'vir, self.fpcs_repacks(location, |loc| &loc.repacks_middle); let terminator = match &terminator.kind { mir::TerminatorKind::Goto { target } - | mir::TerminatorKind::FalseUnwind { real_target: target, .. } => - self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )), + | mir::TerminatorKind::FalseUnwind { + real_target: target, + .. + } => self.vcx.alloc(vir::TerminatorStmtData::Goto( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + )), mir::TerminatorKind::SwitchInt { discr, targets } => { //let discr_version = self.ssa_analysis.version.get(&(location, discr.local)).unwrap(); //let discr_name = vir::vir_format!(self.vcx, "_{}s_{}", discr.local.index(), discr_version); - let ty_out = self.deps.require_ref::( - discr.ty(self.local_decls, self.vcx.tcx), - ).unwrap(); - - let goto_targets = self.vcx.alloc_slice(&targets.iter() - .map(|(value, target)| ( - ty_out.expr_from_u128(value), - //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )) - .collect::>()); + let ty_out = self + .deps + .require_ref::( + discr.ty(self.local_decls, self.vcx.tcx), + ) + .unwrap(); + + let goto_targets = self.vcx.alloc_slice( + &targets + .iter() + .map(|(value, target)| { + ( + ty_out.expr_from_u128(value), + //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + ) + }) + .collect::>(), + ); let goto_otherwise = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock( targets.otherwise().as_usize(), )); let discr_ex = self.encode_operand_snap(discr); - self.vcx.alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc(vir::GotoIfData { - value: discr_ex, // self.vcx.mk_local_ex(discr_name), - targets: goto_targets, - otherwise: goto_otherwise, - }))) + self.vcx + .alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc( + vir::GotoIfData { + value: discr_ex, // self.vcx.mk_local_ex(discr_name), + targets: goto_targets, + otherwise: goto_otherwise, + }, + ))) } mir::TerminatorKind::Return => self.vcx.alloc(vir::TerminatorStmtData::Goto( self.vcx.alloc(vir::CfgBlockLabelData::End), @@ -772,74 +822,103 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for EncoderVisitor<'tcx, 'vir, // TODO: extracting FnDef given func could be extracted? (duplication in pure) let func_ty = func.ty(self.local_decls, self.vcx.tcx); let (func_def_id, arg_tys) = match func_ty.kind() { - &ty::TyKind::FnDef(def_id, arg_tys) => { - (def_id, arg_tys) - } + &ty::TyKind::FnDef(def_id, arg_tys) => (def_id, arg_tys), _ => todo!(), }; - let is_pure = crate::encoders::with_proc_spec(func_def_id, |spec| + let is_pure = crate::encoders::with_proc_spec(func_def_id, |spec| { spec.kind.is_pure().unwrap_or_default() - ).unwrap_or_default(); + }) + .unwrap_or_default(); let dest = self.encode_place(Place::from(*destination)); let task = (func_def_id, arg_tys, self.def_id); if is_pure { - let pure_func = self.deps.require_ref::( - task - ).unwrap(); + let pure_func = self + .deps + .require_ref::(task) + .unwrap(); - let func_args: Vec<_> = args.iter().map(|op| self.encode_operand_snap(op)).collect(); + let func_args: Vec<_> = + args.iter().map(|op| self.encode_operand_snap(op)).collect(); let pure_func_app = pure_func.function_ref.apply(self.vcx, &func_args); - self.stmt(pure_func.return_type.method_assign.apply(self.vcx, [dest, pure_func_app])); + self.stmt( + pure_func + .return_type + .method_assign + .apply(self.vcx, [dest, pure_func_app]), + ); } else { - let func_out = self.deps.require_ref::( - (task.0, task.1, Some(task.2)), - ).unwrap(); + let func_out = self + .deps + .require_ref::(( + task.0, + task.1, + Some(task.2), + )) + .unwrap(); let method_in = args.iter().map(|op| self.encode_operand(op)); - let method_args = std::iter::once(dest) - .chain(method_in) - .collect::>(); + let method_args = std::iter::once(dest).chain(method_in).collect::>(); self.stmt(func_out.method_ref.apply(self.vcx, &method_args)); } - self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize())), - )) + self.vcx.alloc(vir::TerminatorStmtData::Goto(self.vcx.alloc( + vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize()), + ))) } - mir::TerminatorKind::Assert { cond, expected, msg, target, unwind } => { - let e_bool = self.deps.require_ref::( - self.vcx.tcx.types.bool, - ).unwrap(); + mir::TerminatorKind::Assert { + cond, + expected, + msg, + target, + unwind, + } => { + let e_bool = self + .deps + .require_ref::(self.vcx.tcx.types.bool) + .unwrap(); let enc = self.encode_operand_snap(cond); let enc = e_bool.expect_prim().snap_to_prim.apply(self.vcx, [enc]); - let expected = self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Bool(*expected)))); - let assert = self.vcx.alloc(vir::ExprData::BinOp(self.vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: enc, - rhs: expected, - }))); + let expected = self.vcx.alloc(vir::ExprData::Const( + self.vcx.alloc(vir::ConstData::Bool(*expected)), + )); + let assert = self + .vcx + .alloc(vir::ExprData::BinOp(self.vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: enc, + rhs: expected, + }))); self.stmt(vir::StmtData::Exhale(assert)); - let target_bb = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())); + let target_bb = self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())); let otherwise = match unwind { - mir::UnwindAction::Cleanup(bb) => - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(bb.as_usize())), - _ => todo!() + mir::UnwindAction::Cleanup(bb) => self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(bb.as_usize())), + _ => todo!(), }; - self.vcx.alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc(vir::GotoIfData { - value: enc, - targets: self.vcx.alloc_slice(&[(expected, &target_bb)]), - otherwise, - }))) + self.vcx + .alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc( + vir::GotoIfData { + value: enc, + targets: self.vcx.alloc_slice(&[(expected, &target_bb)]), + otherwise, + }, + ))) } - unsupported_kind => self.vcx.alloc(vir::TerminatorStmtData::Dummy( - vir::vir_format!(self.vcx, "terminator {unsupported_kind:?}"), - )), + unsupported_kind => self + .vcx + .alloc(vir::TerminatorStmtData::Dummy(vir::vir_format!( + self.vcx, + "terminator {unsupported_kind:?}" + ))), }; assert!(self.current_terminator.replace(terminator).is_none()); } diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 06f78f10458..5c5bda86e3d 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -1,15 +1,13 @@ +use crate::encoders::{TypeEncoder, ViperTupleEncoder}; use prusti_rustc_interface::{ + ast, index::IndexVec, middle::{mir, ty}, span::def_id::DefId, - type_ir::sty::TyKind, ast, -}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, + type_ir::sty::TyKind, }; use std::collections::HashMap; -use crate::encoders::{ViperTupleEncoder, TypeEncoder}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirPureEncoder; @@ -49,23 +47,23 @@ pub struct MirPureEncoderTask<'tcx> { // can we integrate the lazy context into the identifier system? pub encoding_depth: usize, pub kind: PureKind, - pub parent_def_id: DefId, // ID of the function - pub promoted: Option, // ID of a constant within the function - pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site + pub parent_def_id: DefId, // ID of the function + pub promoted: Option, // ID of a constant within the function + pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site pub substs: ty::GenericArgsRef<'tcx>, // type substitutions at the usage site - pub caller_def_id: DefId, // Caller/Use DefID + pub caller_def_id: DefId, // Caller/Use DefID } impl TaskEncoder for MirPureEncoder { type TaskDescription<'tcx> = MirPureEncoderTask<'tcx>; type TaskKey<'tcx> = ( - usize, // encoding depth - PureKind, // encoding a pure function? - DefId, // ID of the function + usize, // encoding depth + PureKind, // encoding a pure function? + DefId, // ID of the function Option, // ID of a constant within the function, or `None` if encoding the function itself ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs - DefId, // Caller/Use DefID + DefId, // Caller/Use DefID ); type OutputFullLocal<'vir> = MirPureEncoderOutput<'vir>; @@ -73,7 +71,8 @@ impl TaskEncoder for MirPureEncoder { type EncodingError = MirPureEncoderError; fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, MirPureEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, MirPureEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -99,13 +98,16 @@ impl TaskEncoder for MirPureEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { deps.emit_output_ref::(*task_key, ()); let (_, kind, def_id, promoted, substs, caller_def_id) = *task_key; @@ -114,9 +116,21 @@ impl TaskEncoder for MirPureEncoder { let expr = vir::with_vcx(move |vcx| { //let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); let body = match kind { - PureKind::Closure => vcx.body.borrow_mut().get_closure_body(def_id, substs, caller_def_id), - PureKind::Spec => vcx.body.borrow_mut().get_spec_body(def_id, substs, caller_def_id), - PureKind::Pure => vcx.body.borrow_mut().get_pure_fn_body(def_id, substs, caller_def_id), + PureKind::Closure => { + vcx.body + .borrow_mut() + .get_closure_body(def_id, substs, caller_def_id) + } + PureKind::Spec => { + vcx.body + .borrow_mut() + .get_spec_body(def_id, substs, caller_def_id) + } + PureKind::Pure => { + vcx.body + .borrow_mut() + .get_pure_fn_body(def_id, substs, caller_def_id) + } }; let expr_inner = Encoder::new(vcx, task_key.0, def_id, &body, deps).encode_body(); @@ -164,8 +178,16 @@ impl<'vir> Update<'vir> { fn merge(self, newer: Self) -> Self { Self { - binds: self.binds.into_iter().chain(newer.binds.into_iter()).collect(), - versions: self.versions.into_iter().chain(newer.versions.into_iter()).collect(), + binds: self + .binds + .into_iter() + .chain(newer.binds.into_iter()) + .collect(), + versions: self + .versions + .into_iter() + .chain(newer.versions.into_iter()) + .collect(), } } @@ -176,8 +198,7 @@ impl<'vir> Update<'vir> { } } -struct Encoder<'tcx, 'vir: 'enc, 'enc> -{ +struct Encoder<'tcx, 'vir: 'enc, 'enc> { vcx: &'vir vir::VirCtxt<'tcx>, encoding_depth: usize, def_id: DefId, @@ -188,8 +209,7 @@ struct Encoder<'tcx, 'vir: 'enc, 'enc> phi_ctr: usize, } -impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> -{ +impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> { fn new( vcx: &'vir vir::VirCtxt<'tcx>, encoding_depth: usize, @@ -197,7 +217,10 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> body: &'enc mir::Body<'tcx>, deps: &'enc mut TaskEncoderDependencies<'vir>, ) -> Self { - assert!(!body.basic_blocks.is_cfg_cyclic(), "MIR pure encoding does not support loops"); + assert!( + !body.basic_blocks.is_cfg_cyclic(), + "MIR pure encoding does not support loops" + ); Self { vcx, @@ -211,26 +234,21 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> } } - fn mk_local( - &self, - local: mir::Local, - version: usize, - ) -> &'vir str { - vir::vir_format!(self.vcx, "_{}_{}s_{}", self.encoding_depth, local.as_usize(), version) + fn mk_local(&self, local: mir::Local, version: usize) -> &'vir str { + vir::vir_format!( + self.vcx, + "_{}_{}s_{}", + self.encoding_depth, + local.as_usize(), + version + ) } - fn mk_local_ex( - &self, - local: mir::Local, - version: usize, - ) -> ExprRet<'vir> { + fn mk_local_ex(&self, local: mir::Local, version: usize) -> ExprRet<'vir> { self.vcx.mk_local_ex(self.mk_local(local, version)) } - fn mk_phi( - &self, - idx: usize, - ) -> &'vir str { + fn mk_phi(&self, idx: usize) -> &'vir str { vir::vir_format!(self.vcx, "_{}_phi_{}", self.encoding_depth, idx) } @@ -243,36 +261,34 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> tuple_ref.mk_elem(self.vcx, self.vcx.mk_local_ex(self.mk_phi(idx)), elem_idx) } - fn bump_version( - &mut self, - update: &mut Update<'vir>, - local: mir::Local, - expr: ExprRet<'vir>, - ) { + fn bump_version(&mut self, update: &mut Update<'vir>, local: mir::Local, expr: ExprRet<'vir>) { let new_version = self.version_ctr[local]; self.version_ctr[local] += 1; - update.binds.push(UpdateBind::Local(local, new_version, expr)); + update + .binds + .push(UpdateBind::Local(local, new_version, expr)); update.versions.insert(local, new_version); } - fn reify_binds( - &self, - update: Update<'vir>, - expr: ExprRet<'vir>, - ) -> ExprRet<'vir> { - update.binds.iter() - .rfold(expr, |expr, bind| match bind { - UpdateBind::Local(local, ver, val) => self.vcx.alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { - name: self.mk_local(*local, *ver), - val, - expr, - }))), - UpdateBind::Phi(idx, val) => self.vcx.alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { - name: self.mk_phi(*idx), - val, - expr, - }))), - }) + fn reify_binds(&self, update: Update<'vir>, expr: ExprRet<'vir>) -> ExprRet<'vir> { + update.binds.iter().rfold(expr, |expr, bind| match bind { + UpdateBind::Local(local, ver, val) => { + self.vcx + .alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { + name: self.mk_local(*local, *ver), + val, + expr, + }))) + } + UpdateBind::Phi(idx, val) => { + self.vcx + .alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { + name: self.mk_phi(*idx), + val, + expr, + }))) + } + }) } fn reify_branch( @@ -282,21 +298,23 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> curr_ver: &HashMap, update: Update<'vir>, ) -> ExprRet<'vir> { - let tuple_args = mod_locals.iter().map(|local| self.mk_local_ex( - *local, - update.versions.get(local).copied().unwrap_or_else(|| { - // TODO: remove (debug) - if !curr_ver.contains_key(&local) { - tracing::error!("unknown version of local! {}", local.as_usize()); - return 0xff - } - curr_ver[local] - }), - )).collect::>(); - self.reify_binds( - update, - tuple_ref.mk_cons(self.vcx, &tuple_args), - ) + let tuple_args = mod_locals + .iter() + .map(|local| { + self.mk_local_ex( + *local, + update.versions.get(local).copied().unwrap_or_else(|| { + // TODO: remove (debug) + if !curr_ver.contains_key(&local) { + tracing::error!("unknown version of local! {}", local.as_usize()); + return 0xff; + } + curr_ver[local] + }), + ) + }) + .collect::>(); + self.reify_binds(update, tuple_ref.mk_cons(self.vcx, &tuple_args)) } fn encode_body(&mut self) -> ExprRet<'vir> { @@ -307,19 +325,16 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> vir::vir_format!(self.vcx, "pure in _{local}"), Box::new(move |_vcx, lctx: ExprInput<'vir>| lctx.1[local - 1]), )); - init.binds.push(UpdateBind::Local(local.into(), 0, local_ex)); + init.binds + .push(UpdateBind::Local(local.into(), 0, local_ex)); init.versions.insert(local.into(), 0); } - let (_, update) = self.encode_cfg( - &init.versions, - mir::START_BLOCK, - None, - ); + let (_, update) = self.encode_cfg(&init.versions, mir::START_BLOCK, None); let res = init.merge(update); let ret_version = res.versions.get(&mir::RETURN_PLACE).copied().unwrap_or(0); - + self.reify_binds(res, self.mk_local_ex(mir::RETURN_PLACE, ret_version)) } @@ -333,17 +348,23 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // We should never actually reach the join point bb: we should catch // this case and stop recursion in the `Goto` branch below. If this // assert fails we we may need to add catches in the other branches. - debug_assert!(match (dominators.immediate_dominator(curr), branch_point) { - (Some(immediate_dominator), Some(branch_point)) if immediate_dominator == branch_point => + debug_assert!( + match (dominators.immediate_dominator(curr), branch_point) { + (Some(immediate_dominator), Some(branch_point)) + if immediate_dominator == branch_point => // We could also be immediately dominated by the join point if we // are the next bb right after the `SwitchInt`. - self.body.basic_blocks.predecessors()[curr].contains(&branch_point), - _ => true, - }, "reached join point bb {curr:?} (bp {branch_point:?})"); + self.body.basic_blocks.predecessors()[curr].contains(&branch_point), + _ => true, + }, + "reached join point bb {curr:?} (bp {branch_point:?})" + ); // walk block statements first let mut new_curr_ver = curr_ver.clone(); - let stmt_update = self.body[curr].statements.iter() + let stmt_update = self.body[curr] + .statements + .iter() .fold(Update::new(), |update, stmt| { let newer = self.encode_stmt(&new_curr_ver, stmt); newer.add_to_map(&mut new_curr_ver); @@ -360,9 +381,11 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // Walking the rest of the CFG is handled in a parent call. (Some(immediate_dominator), Some(branch_point)) if immediate_dominator == branch_point => - // We are done with the current fragment of the CFG, the - // rest is handled in a parent call. - (target, stmt_update), + // We are done with the current fragment of the CFG, the + // rest is handled in a parent call. + { + (target, stmt_update) + } _ => { // If you hit this then the join point algorithm // probably not working correctly. @@ -374,22 +397,26 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> mir::TerminatorKind::SwitchInt { discr, targets } => { // encode the discriminant operand let discr_expr = self.encode_operand(&new_curr_ver, discr); - let discr_ty_out = self.deps.require_ref::( - discr.ty(self.body, self.vcx.tcx), - ).unwrap(); + let discr_ty_out = self + .deps + .require_ref::(discr.ty(self.body, self.vcx.tcx)) + .unwrap(); // walk `curr` -> `targets[i]` -> `join` for each target. The // join point is identified by reaching a bb where // `dominators.immediate_dominator(bb)` is equal to the bb of // the branch point (so pass `branch_point: Some(curr)`). // TODO: indexvec? - let mut updates = targets.all_targets().iter() + let mut updates = targets + .all_targets() + .iter() .map(|target| self.encode_cfg(&new_curr_ver, *target, Some(curr))) .collect::>(); // find locals updated in any of the results, which were also // defined before the branch - let mut mod_locals = updates.iter() + let mut mod_locals = updates + .iter() .map(|(_, update)| update.versions.keys()) .flatten() .filter(|local| new_curr_ver.contains_key(&local)) @@ -399,26 +426,35 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> mod_locals.dedup(); // for each branch, create a Viper tuple of the updated locals - let tuple_ref = self.deps.require_ref::( - mod_locals.len(), - ).unwrap(); + let tuple_ref = self + .deps + .require_ref::(mod_locals.len()) + .unwrap(); let (join, otherwise_update) = updates.pop().unwrap(); println!("join: {curr:?} -> {join:?}"); debug_assert!(updates.iter().all(|(other, _)| join == *other)); - let phi_expr = targets.iter() - .zip(updates.into_iter()) - .fold( - self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, otherwise_update), - |expr, ((cond_val, target), (_, branch_update))| self.vcx.alloc(ExprRetData::Ternary(self.vcx.alloc(vir::TernaryGenData { - cond: self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: vir::BinOpKind::CmpEq, - lhs: discr_expr, - rhs: discr_ty_out.expr_from_u128(cond_val).lift(), - }))), - then: self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, branch_update), - else_: expr, - }))), - ); + let phi_expr = targets.iter().zip(updates.into_iter()).fold( + self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, otherwise_update), + |expr, ((cond_val, target), (_, branch_update))| { + self.vcx + .alloc(ExprRetData::Ternary(self.vcx.alloc(vir::TernaryGenData { + cond: self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc( + vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs: discr_expr, + rhs: discr_ty_out.expr_from_u128(cond_val).lift(), + }, + ))), + then: self.reify_branch( + &tuple_ref, + &mod_locals, + &new_curr_ver, + branch_update, + ), + else_: expr, + }))) + }, + ); // assign tuple into a `phi` variable let phi_idx = self.phi_ctr; @@ -459,14 +495,22 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // A fn call in pure can only be one of two kinds: a // call to another pure function, or a call to a prusti // builtin function. - let is_pure = crate::encoders::with_proc_spec(def_id, |def_spec| + let is_pure = crate::encoders::with_proc_spec(def_id, |def_spec| { def_spec.kind.is_pure().unwrap_or_default() - ).unwrap_or_default(); + }) + .unwrap_or_default(); if is_pure { - let pure_func = self.deps.require_ref::( - (def_id, arg_tys, self.def_id) - ).unwrap().function_ref; - let encoded_args = args.iter() + let pure_func = self + .deps + .require_ref::(( + def_id, + arg_tys, + self.def_id, + )) + .unwrap() + .function_ref; + let encoded_args = args + .iter() .map(|oper| self.encode_operand(&new_curr_ver, oper)) .collect::>(); pure_func.apply(self.vcx, &encoded_args) @@ -483,9 +527,10 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> term_update.add_to_map(&mut new_curr_ver); // walk rest of CFG - let (end, end_update) = self.encode_cfg(&new_curr_ver, target.unwrap(), branch_point); + let (end, end_update) = + self.encode_cfg(&new_curr_ver, target.unwrap(), branch_point); - (end, stmt_update.merge(term_update).merge(end_update)) + (end, stmt_update.merge(term_update).merge(end_update)) } k => todo!("terminator kind {k:?}"), @@ -503,11 +548,11 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> let new_version = self.version_ctr[local]; self.version_ctr[local] += 1; update.versions.insert(local, new_version); - }, + } mir::StatementKind::StorageDead(..) | mir::StatementKind::FakeRead(..) - | mir::StatementKind::AscribeUserType(..) - | mir::StatementKind::PlaceMention(..) => {}, // nop + | mir::StatementKind::AscribeUserType(..) + | mir::StatementKind::PlaceMention(..) => {} // nop mir::StatementKind::Assign(box (dest, rvalue)) => { assert!(dest.projection.is_empty()); let expr = self.encode_rvalue(curr_ver, rvalue); @@ -532,72 +577,108 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // Len // Cast mir::Rvalue::BinaryOp(op, box (l, r)) => { - let ty_l = self.deps.require_ref::( - l.ty(self.body, self.vcx.tcx), - ).unwrap().expect_prim().snap_to_prim; - let ty_r = self.deps.require_ref::( - r.ty(self.body, self.vcx.tcx), - ).unwrap().expect_prim().snap_to_prim; - let ty_rvalue = self.deps.require_ref::( - rvalue.ty(self.body, self.vcx.tcx), - ).unwrap().expect_prim().prim_to_snap; - - ty_rvalue.apply(self.vcx, - [self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: op.into(), - lhs: ty_l.apply(self.vcx, [self.encode_operand(curr_ver, l)]), - rhs: ty_r.apply(self.vcx, [self.encode_operand(curr_ver, r)]), - })))], + let ty_l = self + .deps + .require_ref::(l.ty(self.body, self.vcx.tcx)) + .unwrap() + .expect_prim() + .snap_to_prim; + let ty_r = self + .deps + .require_ref::(r.ty(self.body, self.vcx.tcx)) + .unwrap() + .expect_prim() + .snap_to_prim; + let ty_rvalue = self + .deps + .require_ref::(rvalue.ty(self.body, self.vcx.tcx)) + .unwrap() + .expect_prim() + .prim_to_snap; + + ty_rvalue.apply( + self.vcx, + [self + .vcx + .alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { + kind: op.into(), + lhs: ty_l.apply(self.vcx, [self.encode_operand(curr_ver, l)]), + rhs: ty_r.apply(self.vcx, [self.encode_operand(curr_ver, r)]), + })))], ) } // CheckedBinaryOp // NullaryOp mir::Rvalue::UnaryOp(op, expr) => { - let ty_expr = self.deps.require_ref::( - expr.ty(self.body, self.vcx.tcx), - ).unwrap().expect_prim().snap_to_prim; - let ty_rvalue = self.deps.require_ref::( - rvalue.ty(self.body, self.vcx.tcx), - ).unwrap().expect_prim().prim_to_snap; - - ty_rvalue.apply(self.vcx, - [self.vcx.alloc(ExprRetData::UnOp(self.vcx.alloc(vir::UnOpGenData { - kind: op.into(), - expr: ty_expr.apply(self.vcx, [self.encode_operand(curr_ver, expr)]), - })))] + let ty_expr = self + .deps + .require_ref::(expr.ty(self.body, self.vcx.tcx)) + .unwrap() + .expect_prim() + .snap_to_prim; + let ty_rvalue = self + .deps + .require_ref::(rvalue.ty(self.body, self.vcx.tcx)) + .unwrap() + .expect_prim() + .prim_to_snap; + + ty_rvalue.apply( + self.vcx, + [self + .vcx + .alloc(ExprRetData::UnOp(self.vcx.alloc(vir::UnOpGenData { + kind: op.into(), + expr: ty_expr.apply(self.vcx, [self.encode_operand(curr_ver, expr)]), + })))], ) } // Discriminant mir::Rvalue::Aggregate(box kind, fields) => match kind { mir::AggregateKind::Tuple if fields.len() == 0 => - // TODO: why is this not a constant? - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), + // TODO: why is this not a constant? + { + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Tuple0_cons()" + ))) + } mir::AggregateKind::Closure(..) => { // TODO: only when this is a spec closure? - let tuple_ref = self.deps.require_ref::( - fields.len(), - ).unwrap(); - tuple_ref.mk_cons(self.vcx, &fields.iter() - .map(|field| self.encode_operand(curr_ver, field)) - .collect::>()) + let tuple_ref = self + .deps + .require_ref::(fields.len()) + .unwrap(); + tuple_ref.mk_cons( + self.vcx, + &fields + .iter() + .map(|field| self.encode_operand(curr_ver, field)) + .collect::>(), + ) } _ => todo!("Unsupported Rvalue::AggregateKind: {kind:?}"), - } + }, mir::Rvalue::CheckedBinaryOp(binop, box (l, r)) => { - let binop_function = self.deps.require_ref::( - crate::encoders::MirBuiltinEncoderTask::CheckedBinOp( - rvalue.ty(self.body, self.vcx.tcx), - *binop, - l.ty(self.body, self.vcx.tcx), - r.ty(self.body, self.vcx.tcx), - ), - ).unwrap().function; - binop_function.apply(self.vcx, &[ - self.encode_operand(curr_ver, l), - self.encode_operand(curr_ver, r), - ]) + let binop_function = self + .deps + .require_ref::( + crate::encoders::MirBuiltinEncoderTask::CheckedBinOp( + rvalue.ty(self.body, self.vcx.tcx), + *binop, + l.ty(self.body, self.vcx.tcx), + r.ty(self.body, self.vcx.tcx), + ), + ) + .unwrap() + .function; + binop_function.apply( + self.vcx, + &[ + self.encode_operand(curr_ver, l), + self.encode_operand(curr_ver, r), + ], + ) } // ShallowInitBox // CopyForDeref @@ -614,35 +695,46 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> operand: &mir::Operand<'tcx>, ) -> ExprRet<'vir> { match operand { - mir::Operand::Copy(place) - | mir::Operand::Move(place) => self.encode_place(curr_ver, place), + mir::Operand::Copy(place) | mir::Operand::Move(place) => { + self.encode_place(curr_ver, place) + } mir::Operand::Constant(box constant) => { // TODO: duplicated from mir_impure! match constant.literal { - mir::ConstantKind::Val(const_val, const_ty) => { - match const_ty.kind() { - ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), - ty::TyKind::Int(int_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Int_{}_cons({})", int_ty.name_str(), scalar_val.try_to_int(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Uint(uint_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Uint_{}_cons({})", uint_ty.name_str(), scalar_val.try_to_uint(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Bool => self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Bool_cons({})", const_val.try_to_bool().unwrap()), - )), - unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + mir::ConstantKind::Val(const_val, const_ty) => match const_ty.kind() { + ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc( + ExprRetData::Todo(vir::vir_format!(self.vcx, "s_Tuple0_cons()")), + ), + ty::TyKind::Int(int_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Int_{}_cons({})", + int_ty.name_str(), + scalar_val.try_to_int(scalar_val.size()).unwrap() + ))) + } + ty::TyKind::Uint(uint_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Uint_{}_cons({})", + uint_ty.name_str(), + scalar_val.try_to_uint(scalar_val.size()).unwrap() + ))) } + ty::TyKind::Bool => self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Bool_cons({})", + const_val.try_to_bool().unwrap() + ))), + unsupported_ty => { + todo!("unsupported constant literal type: {unsupported_ty:?}") + } + }, + unsupported_literal => { + todo!("unsupported constant literal: {unsupported_literal:?}") } - unsupported_literal => todo!("unsupported constant literal: {unsupported_literal:?}"), } } } @@ -656,12 +748,14 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // TODO: remove (debug) if !curr_ver.contains_key(&place.local) { tracing::error!("unknown version of local! {}", place.local.as_usize()); - return self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "unknown_version_{}", place.local.as_usize()), - )); + return self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "unknown_version_{}", + place.local.as_usize() + ))); } - let mut place_ty = mir::tcx::PlaceTy::from_ty(self.body.local_decls[place.local].ty); + let mut place_ty = mir::tcx::PlaceTy::from_ty(self.body.local_decls[place.local].ty); let mut expr = self.mk_local_ex(place.local, curr_ver[&place.local]); for elem in place.projection { expr = self.encode_place_element(place_ty.ty, elem, expr); @@ -670,18 +764,25 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> expr } - fn encode_place_element(&mut self, ty: ty::Ty<'tcx>, elem: mir::PlaceElem<'tcx>, expr: ExprRet<'vir>) -> ExprRet<'vir> { - match elem { - mir::ProjectionElem::Deref => - expr, + fn encode_place_element( + &mut self, + ty: ty::Ty<'tcx>, + elem: mir::PlaceElem<'tcx>, + expr: ExprRet<'vir>, + ) -> ExprRet<'vir> { + match elem { + mir::ProjectionElem::Deref => expr, mir::ProjectionElem::Field(field_idx, _) => { - let field_idx= field_idx.as_usize(); + let field_idx = field_idx.as_usize(); match ty.kind() { TyKind::Closure(_def_id, args) => { - let upvars = args.as_closure().upvar_tys().iter().collect::>().len(); - let tuple_ref = self.deps.require_ref::( - upvars, - ).unwrap(); + let upvars = args + .as_closure() + .upvar_tys() + .iter() + .collect::>() + .len(); + let tuple_ref = self.deps.require_ref::(upvars).unwrap(); tuple_ref.mk_elem(self.vcx, expr, field_idx) } _ => { @@ -696,7 +797,13 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> } } - fn encode_prusti_builtin(&mut self, curr_ver: &HashMap, def_id: DefId, arg_tys: ty::GenericArgsRef<'tcx>, args: &Vec>) -> ExprRet<'vir> { + fn encode_prusti_builtin( + &mut self, + curr_ver: &HashMap, + def_id: DefId, + arg_tys: ty::GenericArgsRef<'tcx>, + args: &Vec>, + ) -> ExprRet<'vir> { #[derive(Debug)] enum PrustiBuiltin { Forall, @@ -705,19 +812,19 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // TODO: this attribute extraction should be done elsewhere? let attrs = self.vcx.tcx.get_attrs_unchecked(def_id); - let normal_attrs = attrs.iter() + let normal_attrs = attrs + .iter() .filter(|attr| !attr.is_doc_comment()) - .map(|attr| attr.get_normal_item()).collect::>(); + .map(|attr| attr.get_normal_item()) + .collect::>(); let mut builtin = None; - for attr in normal_attrs.iter() - .filter(|item| item.path.segments.len() == 2 + for attr in normal_attrs.iter().filter(|item| { + item.path.segments.len() == 2 && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "builtin") { + && item.path.segments[1].ident.as_str() == "builtin" + }) { match &attr.args { - ast::AttrArgs::Eq( - _, - ast::AttrArgsEq::Hir(lit), - ) => { + ast::AttrArgs::Eq(_, ast::AttrArgsEq::Hir(lit)) => { assert!(builtin.is_none(), "multiple prusti::builtin"); builtin = Some(match lit.symbol.as_str() { "forall" => PrustiBuiltin::Forall, @@ -728,21 +835,26 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> _ => panic!("illegal prusti::builtin"), } } - + match builtin.expect("call to unknown non-pure function in pure code") { PrustiBuiltin::SnapshotEquality => { assert_eq!(args.len(), 2); let lhs = self.encode_operand(&curr_ver, &args[0]); let rhs = self.encode_operand(&curr_ver, &args[1]); - let eq_expr = self.vcx.alloc(vir::ExprGenData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: vir::BinOpKind::CmpEq, - lhs, - rhs, - }))); - - let bool_cons = self.deps.require_ref::( - self.vcx.tcx.types.bool, - ).unwrap().expect_prim().prim_to_snap; + let eq_expr = + self.vcx + .alloc(vir::ExprGenData::BinOp(self.vcx.alloc(vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs, + rhs, + }))); + + let bool_cons = self + .deps + .require_ref::(self.vcx.tcx.types.bool) + .unwrap() + .expect_prim() + .prim_to_snap; bool_cons.apply(self.vcx, [eq_expr]) } PrustiBuiltin::Forall => { @@ -759,18 +871,22 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> _ => panic!("illegal prusti::forall"), }; - let qvars = self.vcx.alloc_slice(&qvar_tys.iter() - .enumerate() - .map(|(idx, qvar_ty)| { - let ty_out = self.deps.require_ref::( - qvar_ty, - ).unwrap(); - self.vcx.mk_local_decl( - vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), - ty_out.snapshot, - ) - }) - .collect::>()); + let qvars = self.vcx.alloc_slice( + &qvar_tys + .iter() + .enumerate() + .map(|(idx, qvar_ty)| { + let ty_out = self + .deps + .require_ref::(qvar_ty) + .unwrap(); + self.vcx.mk_local_decl( + vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), + ty_out.snapshot, + ) + }) + .collect::>(), + ); //let qvar_tuple_ref = self.deps.require_ref::( // qvars.len(), //).unwrap(); @@ -787,21 +903,24 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> // alternatively, can we have an "unlift" // operation, which will work like reify // but panicking on a Lazy(..)? - reify_args.push(unsafe { - std::mem::transmute(self.encode_operand(&curr_ver, &args[1])) - }); - reify_args.extend((0..qvars.len()) - .map(|idx| self.vcx.mk_local_ex( - vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), - ))); + reify_args + .push(unsafe { std::mem::transmute(self.encode_operand(&curr_ver, &args[1])) }); + reify_args.extend((0..qvars.len()).map(|idx| { + self.vcx.mk_local_ex(vir::vir_format!( + self.vcx, + "qvar_{}_{idx}", + self.encoding_depth + )) + })); // TODO: recursively invoke MirPure encoder to encode // the body of the closure; pass the closure as the // variable to use, then closure access = tuple access // (then hope to optimise this away later ...?) use vir::Reify; - let body = self.deps.require_local::( - MirPureEncoderTask { + let body = self + .deps + .require_local::(MirPureEncoderTask { encoding_depth: self.encoding_depth + 1, kind: PureKind::Closure, parent_def_id: cl_def_id, @@ -809,27 +928,31 @@ impl<'tcx, 'vir: 'enc, 'enc> Encoder<'tcx, 'vir, 'enc> param_env: self.vcx.tcx.param_env(cl_def_id), substs: ty::List::identity_for_item(self.vcx.tcx, cl_def_id), caller_def_id: self.def_id, - } - ).unwrap().expr - // arguments to the closure are - // - the closure itself - // - the qvars - .reify(self.vcx, ( - cl_def_id, - self.vcx.alloc_slice(&reify_args), - )) + }) + .unwrap() + .expr + // arguments to the closure are + // - the closure itself + // - the qvars + .reify(self.vcx, (cl_def_id, self.vcx.alloc_slice(&reify_args))) .lift(); - let bool = self.deps.require_ref::( - self.vcx.tcx.types.bool, - ).unwrap(); + let bool = self + .deps + .require_ref::(self.vcx.tcx.types.bool) + .unwrap(); let bool = bool.expect_prim(); - let forall = bool.prim_to_snap.apply(self.vcx, [self.vcx.alloc(ExprRetData::Forall(self.vcx.alloc(vir::ForallGenData { - qvars, - triggers: &[], // TODO - body: bool.snap_to_prim.apply(self.vcx, [body]), - })))]); + let forall = bool.prim_to_snap.apply( + self.vcx, + [self + .vcx + .alloc(ExprRetData::Forall(self.vcx.alloc(vir::ForallGenData { + qvars, + triggers: &[], // TODO + body: bool.snap_to_prim.apply(self.vcx, [body]), + })))], + ); forall } diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index 00a15b3edfe..baee1e6af09 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -1,11 +1,15 @@ -use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; +use prusti_rustc_interface::{ + middle::{mir, ty}, + span::def_id::DefId, +}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; -use vir::{Reify, FunctionIdent, UnknownArity, CallableIdent}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{CallableIdent, FunctionIdent, Reify, UnknownArity}; use crate::encoders::{ - MirPureEncoder, MirPureEncoderTask, SpecEncoder, SpecEncoderTask, TypeEncoder, mir_pure::PureKind, + mir_pure::PureKind, MirPureEncoder, MirPureEncoderTask, SpecEncoder, SpecEncoderTask, + TypeEncoder, }; use super::TypeEncoderOutputRef; @@ -35,9 +39,9 @@ thread_local! { impl TaskEncoder for MirFunctionEncoder { type TaskDescription<'vir> = ( - DefId, // ID of the function + DefId, // ID of the function ty::GenericArgsRef<'vir>, // ? this should be the "signature", after applying the env/substs - DefId, // Caller DefID + DefId, // Caller DefID ); type OutputRef<'vir> = MirFunctionEncoderOutputRef<'vir>; @@ -76,36 +80,61 @@ impl TaskEncoder for MirFunctionEncoder { ), > { let (def_id, substs, caller_def_id) = *task_key; - let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| + let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| { def_spec.trusted.extract_inherit().unwrap_or_default() - ).unwrap_or_default(); + }) + .unwrap_or_default(); vir::with_vcx(|vcx| { - let local_defs = deps.require_local::( - (def_id, substs, Some(caller_def_id)), - ).unwrap(); + let local_defs = deps + .require_local::(( + def_id, + substs, + Some(caller_def_id), + )) + .unwrap(); tracing::debug!("encoding {def_id:?}"); let extra: String = substs.iter().map(|s| format!("_{s}")).collect(); let (krate, index) = (caller_def_id.krate, caller_def_id.index.index()); - let function_name = vir::vir_format!(vcx, "f_{}{extra}_CALLER_{krate}_{index}", vcx.tcx.item_name(def_id)); + let function_name = vir::vir_format!( + vcx, + "f_{}{extra}_CALLER_{krate}_{index}", + vcx.tcx.item_name(def_id) + ); let args: Vec<_> = (1..=local_defs.arg_count) .map(mir::Local::from) .map(|def_idx| local_defs.locals[def_idx].ty.snapshot) .collect(); let args = UnknownArity::new(vcx.alloc_slice(&args)); let function_ref = FunctionIdent::new(function_name, args); - deps.emit_output_ref::(*task_key, MirFunctionEncoderOutputRef { function_ref, return_type: local_defs.locals[mir::RETURN_PLACE].ty }); - - let spec = deps.require_local::( - (def_id, substs, Some(caller_def_id), true) - ).unwrap(); - - let func_args: Vec<_> = (1..=local_defs.arg_count).map(mir::Local::from).map(|arg| vcx.alloc(vir::LocalDeclData { - name: local_defs.locals[arg].local.name, - ty: local_defs.locals[arg].ty.snapshot, - })).collect(); + deps.emit_output_ref::( + *task_key, + MirFunctionEncoderOutputRef { + function_ref, + return_type: local_defs.locals[mir::RETURN_PLACE].ty, + }, + ); + + let spec = deps + .require_local::(( + def_id, + substs, + Some(caller_def_id), + true, + )) + .unwrap(); + + let func_args: Vec<_> = (1..=local_defs.arg_count) + .map(mir::Local::from) + .map(|arg| { + vcx.alloc(vir::LocalDeclData { + name: local_defs.locals[arg].local.name, + ty: local_defs.locals[arg].ty.snapshot, + }) + }) + .collect(); let expr = if trusted { None diff --git a/prusti-encoder/src/encoders/mod.rs b/prusti-encoder/src/encoders/mod.rs index a61a504c62f..41e26e814ac 100644 --- a/prusti-encoder/src/encoders/mod.rs +++ b/prusti-encoder/src/encoders/mod.rs @@ -10,30 +10,12 @@ pub mod pure; pub mod local_def; pub use generic::GenericEncoder; -pub use mir_builtin::{ - MirBuiltinEncoder, - MirBuiltinEncoderTask, -}; +pub use mir_builtin::{MirBuiltinEncoder, MirBuiltinEncoderTask}; pub use mir_impure::MirImpureEncoder; -pub use mir_pure::{ - MirPureEncoder, - MirPureEncoderTask, -}; -pub use spec::{ - SpecEncoder, - SpecEncoderOutput, - SpecEncoderTask, -}; +pub use mir_pure::{MirPureEncoder, MirPureEncoderTask}; pub(super) use spec::{init_def_spec, with_def_spec, with_proc_spec}; -pub use typ::{ - TypeEncoder, - TypeEncoderOutputRef, - TypeEncoderOutput, -}; -pub use viper_tuple::{ - ViperTupleEncoder, - ViperTupleEncoderOutputRef, - ViperTupleEncoderOutput, -}; +pub use spec::{SpecEncoder, SpecEncoderOutput, SpecEncoderTask}; +pub use typ::{TypeEncoder, TypeEncoderOutput, TypeEncoderOutputRef}; +pub use viper_tuple::{ViperTupleEncoder, ViperTupleEncoderOutput, ViperTupleEncoderOutputRef}; pub use mir_pure_function::MirFunctionEncoder; diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 66ba6aa50fb..e7d0e8604c2 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -1,10 +1,13 @@ -use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; +use prusti_rustc_interface::{ + middle::{mir, ty}, + span::def_id::DefId, +}; +use std::cell::RefCell; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use vir::Reify; -use std::cell::RefCell; -use crate::encoders::{MirPureEncoder, mir_pure::PureKind}; +use crate::encoders::{mir_pure::PureKind, MirPureEncoder}; pub struct MirSpecEncoder; #[derive(Clone)] @@ -21,10 +24,10 @@ thread_local! { impl TaskEncoder for MirSpecEncoder { type TaskDescription<'tcx> = ( - DefId, // The function annotated with specs + DefId, // The function annotated with specs ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs - Option, // ID of the caller function, if any - bool, // If to encode as pure or not + Option, // ID of the caller function, if any + bool, // If to encode as pure or not ); type OutputFullLocal<'vir> = MirSpecEncoderOutput<'vir>; @@ -64,22 +67,26 @@ impl TaskEncoder for MirSpecEncoder { let (def_id, substs, caller_def_id, pure) = *task_key; deps.emit_output_ref::(*task_key, ()); - let local_defs = deps.require_local::( - (def_id, substs, caller_def_id), - ).unwrap(); - let specs = deps.require_local::( - crate::encoders::SpecEncoderTask { + let local_defs = deps + .require_local::(( + def_id, + substs, + caller_def_id, + )) + .unwrap(); + let specs = deps + .require_local::(crate::encoders::SpecEncoderTask { def_id, - } - ).unwrap(); + }) + .unwrap(); vir::with_vcx(|vcx| { let local_iter = (1..=local_defs.arg_count).map(mir::Local::from); let all_args: Vec<_> = if pure { - local_iter - .map(|local| local_defs.locals[local].local_ex) - .chain([vcx.mk_local_ex(vir::vir_format!(vcx, "result"))]) - .collect() + local_iter + .map(|local| local_defs.locals[local].local_ex) + .chain([vcx.mk_local_ex(vir::vir_format!(vcx, "result"))]) + .collect() } else { local_iter .map(|local| local_defs.locals[local].impure_snap) @@ -92,53 +99,69 @@ impl TaskEncoder for MirSpecEncoder { all_args }; - let to_bool = deps.require_ref::( - vcx.tcx.types.bool, - ).unwrap().expect_prim().snap_to_prim; + let to_bool = deps + .require_ref::(vcx.tcx.types.bool) + .unwrap() + .expect_prim() + .snap_to_prim; - let pres = specs.pres.iter().map(|spec_def_id| { - let expr = deps.require_local::( - crate::encoders::MirPureEncoderTask { - encoding_depth: 0, - kind: PureKind::Spec, - parent_def_id: *spec_def_id, - promoted: None, - param_env: vcx.tcx.param_env(spec_def_id), - substs, - // TODO: should this be `def_id` or `caller_def_id` - caller_def_id: def_id, - } - ).unwrap().expr; - let expr = expr.reify(vcx, (*spec_def_id, pre_args)); - to_bool.apply(vcx, [expr]) - }).collect::>>(); + let pres = specs + .pres + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncoderTask { + encoding_depth: 0, + kind: PureKind::Spec, + parent_def_id: *spec_def_id, + promoted: None, + param_env: vcx.tcx.param_env(spec_def_id), + substs, + // TODO: should this be `def_id` or `caller_def_id` + caller_def_id: def_id, + }, + ) + .unwrap() + .expr; + let expr = expr.reify(vcx, (*spec_def_id, pre_args)); + to_bool.apply(vcx, [expr]) + }) + .collect::>>(); let post_args = if pure { all_args } else { - let post_args: Vec<_> = pre_args.iter().map(|arg| - vcx.alloc(vir::ExprData::Old(arg)) - ) + let post_args: Vec<_> = pre_args + .iter() + .map(|arg| vcx.alloc(vir::ExprData::Old(arg))) .chain([local_defs.locals[mir::RETURN_PLACE].impure_snap]) .collect(); vcx.alloc_slice(&post_args) }; - let posts = specs.posts.iter().map(|spec_def_id| { - let expr = deps.require_local::( - crate::encoders::MirPureEncoderTask { - encoding_depth: 0, - kind: PureKind::Spec, - parent_def_id: *spec_def_id, - promoted: None, - param_env: vcx.tcx.param_env(spec_def_id), - substs, - // TODO: should this be `def_id` or `caller_def_id` - caller_def_id: def_id, - } - ).unwrap().expr; - let expr = expr.reify(vcx, (*spec_def_id, post_args)); - to_bool.apply(vcx, [expr]) - }).collect::>>(); + let posts = specs + .posts + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncoderTask { + encoding_depth: 0, + kind: PureKind::Spec, + parent_def_id: *spec_def_id, + promoted: None, + param_env: vcx.tcx.param_env(spec_def_id), + substs, + // TODO: should this be `def_id` or `caller_def_id` + caller_def_id: def_id, + }, + ) + .unwrap() + .expr; + let expr = expr.reify(vcx, (*spec_def_id, post_args)); + to_bool.apply(vcx, [expr]) + }) + .collect::>>(); let data = MirSpecEncoderOutput { pres, posts, diff --git a/prusti-encoder/src/encoders/spec.rs b/prusti-encoder/src/encoders/spec.rs index 74a3fd64fd8..0367e4a8d35 100644 --- a/prusti-encoder/src/encoders/spec.rs +++ b/prusti-encoder/src/encoders/spec.rs @@ -1,12 +1,6 @@ -use prusti_rustc_interface::{ - //middle::{mir, ty}, - span::def_id::DefId, -}; use prusti_interface::specs::typed::{DefSpecificationMap, ProcedureSpecification}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use prusti_rustc_interface::span::def_id::DefId; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct SpecEncoder; @@ -42,7 +36,10 @@ where DEF_SPEC_MAP.with_borrow(|def_spec: &Option| { let def_spec = def_spec.as_ref().unwrap(); // TODO: handle `SpecGraph` better than simply taking the `base_spec` - def_spec.get_proc_spec(&def_id).map(|spec| &spec.base_spec).map(f) + def_spec + .get_proc_spec(&def_id) + .map(|spec| &spec.base_spec) + .map(f) }) } @@ -53,7 +50,7 @@ pub fn init_def_spec(def_spec: DefSpecificationMap) { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct SpecEncoderTask { pub def_id: DefId, // ID of the function - // TODO: substs here? + // TODO: substs here? } impl TaskEncoder for SpecEncoder { @@ -68,7 +65,8 @@ impl TaskEncoder for SpecEncoder { type EncodingError = SpecEncoderError; fn with_cache<'tcx, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, SpecEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, SpecEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -89,13 +87,16 @@ impl TaskEncoder for SpecEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { deps.emit_output_ref::(task_key.clone(), ()); vir::with_vcx(|vcx| { with_def_spec(|def_spec| { @@ -109,7 +110,7 @@ impl TaskEncoder for SpecEncoder { .and_then(|specs| specs.base_spec.posts.expect_empty_or_inherent()) .map(|specs| vcx.alloc_slice(specs)) .unwrap_or_default(); - Ok((SpecEncoderOutput { pres, posts, }, () )) + Ok((SpecEncoderOutput { pres, posts }, ())) }) }) } diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index cb8419e64c7..898c30f6632 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -1,10 +1,10 @@ use prusti_rustc_interface::middle::ty; use rustc_type_ir::sty::TyKind; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{ + BinaryArity, CallableIdent, FunctionIdent, MethodIdent, NullaryArity, PredicateIdent, + UnaryArity, UnknownArity, }; -use vir::{BinaryArity, UnaryArity, NullaryArity, UnknownArity, FunctionIdent, MethodIdent, PredicateIdent, CallableIdent}; pub struct TypeEncoder; @@ -92,16 +92,23 @@ impl<'vir> TypeEncoderOutputRef<'vir> { pub fn expr_from_u128(&self, val: u128) -> vir::Expr<'vir> { match self.expect_prim().prim_type { vir::TypeData::Bool => vir::with_vcx(|vcx| { - self.expect_prim().prim_to_snap.apply(vcx, [vcx.alloc(vir::ExprData::Const( - vcx.alloc(vir::ConstData::Bool(val != 0)), - ))]) + self.expect_prim().prim_to_snap.apply( + vcx, + [vcx.alloc(vir::ExprData::Const( + vcx.alloc(vir::ConstData::Bool(val != 0)), + ))], + ) }), vir::TypeData::Int { signed: false, .. } => vir::with_vcx(|vcx| { - self.expect_prim().prim_to_snap.apply(vcx, [ - vcx.alloc(vir::ExprData::Const(vcx.alloc(vir::ConstData::Int(val)))) - ]) + self.expect_prim().prim_to_snap.apply( + vcx, + [vcx.alloc(vir::ExprData::Const(vcx.alloc(vir::ConstData::Int(val))))], + ) }), - k => todo!("unsupported type in expr_from_u128: {k:?} ({:?})", self.snapshot), + k => todo!( + "unsupported type in expr_from_u128: {k:?} ({:?})", + self.snapshot + ), } } } @@ -134,7 +141,8 @@ impl TaskEncoder for TypeEncoder { type EncodingError = TypeEncoderError; fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, TypeEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, TypeEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -181,13 +189,16 @@ impl TaskEncoder for TypeEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { fn mk_unreachable<'vir, 'tcx>( vcx: &'vir vir::VirCtxt<'tcx>, unreachable_fn: FunctionIdent<'vir, NullaryArity>, @@ -208,8 +219,8 @@ impl TaskEncoder for TypeEncoder { field_name: &'vir str, ) -> vir::Predicate<'vir> { let predicate_body = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldData { - recv: vcx.mk_local_ex("self_p"), - field: field_name, + recv: vcx.mk_local_ex("self_p"), + field: field_name, }))); vir::vir_predicate! { vcx; predicate [predicate_name](self_p: Ref) { [predicate_body] } } } @@ -295,13 +306,8 @@ impl TaskEncoder for TypeEncoder { }) } - fn mk_predicate_ident<'vir>( - name_p: &'vir str, - ) -> PredicateIdent<'vir, UnaryArity<'vir>> { - PredicateIdent::new( - name_p, - UnaryArity::new([&vir::TypeData::Ref]), - ) + fn mk_predicate_ident<'vir>(name_p: &'vir str) -> PredicateIdent<'vir, UnaryArity<'vir>> { + PredicateIdent::new(name_p, UnaryArity::new([&vir::TypeData::Ref])) } fn mk_from_fields<'tcx, 'vir>( @@ -311,14 +317,14 @@ impl TaskEncoder for TypeEncoder { ) -> FunctionIdent<'vir, UnknownArity<'vir>> { FunctionIdent::new( vir::vir_format!(vcx, "{name_s}_cons"), - UnknownArity::new(args) + UnknownArity::new(args), ) } fn mk_function_snap_identifier<'tcx, 'vir>( vcx: &'vir vir::VirCtxt<'tcx>, name_p: &'vir str, - ty: vir::Type<'vir> + ty: vir::Type<'vir>, ) -> FunctionIdent<'vir, UnaryArity<'vir>> { FunctionIdent::new( vir::vir_format!(vcx, "{name_p}_snap"), @@ -331,7 +337,10 @@ impl TaskEncoder for TypeEncoder { name_s: &'vir str, ty_s: vir::Type<'vir>, ty_prim: vir::Type<'vir>, - ) -> (FunctionIdent<'vir, UnaryArity<'vir>>, FunctionIdent<'vir, UnaryArity<'vir>>) { + ) -> ( + FunctionIdent<'vir, UnaryArity<'vir>>, + FunctionIdent<'vir, UnaryArity<'vir>>, + ) { let val = FunctionIdent::new( vir::vir_format!(vcx, "{name_s}_val"), UnaryArity::new([ty_s]), @@ -373,20 +382,12 @@ impl TaskEncoder for TypeEncoder { field_ty_s: vir::Type<'vir>, ) -> FieldAccessFunctions<'vir> { let read = vir::vir_format!(vcx, "{name_s}_read_{idx}"); - let read = FunctionIdent::new( - &read, - UnaryArity::new([ty_s]), - ); + let read = FunctionIdent::new(&read, UnaryArity::new([ty_s])); let write = vir::vir_format!(vcx, "{name_s}_write_{idx}"); - let write = FunctionIdent::new( - &write, - BinaryArity::new([ty_s, field_ty_s]), - ); + let write = FunctionIdent::new(&write, BinaryArity::new([ty_s, field_ty_s])); let projection_p = vir::vir_format!(vcx, "{name_p}_field_{idx}"); - let projection_p = FunctionIdent::new( - &projection_p, - UnaryArity::new([&vir::TypeData::Ref]), - ); + let projection_p = + FunctionIdent::new(&projection_p, UnaryArity::new([&vir::TypeData::Ref])); FieldAccessFunctions { read, write, @@ -403,27 +404,20 @@ impl TaskEncoder for TypeEncoder { ) -> vir::Function<'vir> { let pred_app = vcx.alloc(vir::PredicateAppData { target: predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("self"), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_ex("self")]), }); vcx.alloc(vir::FunctionData { name: snapshot_fn.name(), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), ret: snapshot_ty, - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(pred_app)), - ]), + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: field_name.map(|field_name| vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.alloc(vir::ExprData::Field( - vcx.mk_local_ex("self"), - field_name, - )), - })))), + expr: field_name.map(|field_name| { + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.alloc(vir::ExprData::Field(vcx.mk_local_ex("self"), field_name)), + }))) + }), }) } fn mk_structlike<'tcx, 'vir>( @@ -433,40 +427,55 @@ impl TaskEncoder for TypeEncoder { name_s: &'vir str, name_p: &'vir str, field_ty_out: Vec>, - ) -> Result<::OutputFullLocal<'vir>, ( - ::EncodingError, - Option<::OutputFullDependency<'vir>>, - )> { + ) -> Result< + ::OutputFullLocal<'vir>, + ( + ::EncodingError, + Option<::OutputFullDependency<'vir>>, + ), + > { let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); let mut field_access = Vec::new(); for idx in 0..field_ty_out.len() { - field_access.push(mk_function_field_projection(vcx, name_p, name_s, idx, ty_s, field_ty_out[idx].snapshot)); + field_access.push(mk_function_field_projection( + vcx, + name_p, + name_s, + idx, + ty_s, + field_ty_out[idx].snapshot, + )); } let field_access = vcx.alloc_slice(&field_access); - let snapshot_constructor_args = - vcx.alloc_slice(&field_ty_out.iter() - .map(|field_ty_out| field_ty_out.snapshot) - .collect::>()); + let snapshot_constructor_args = vcx.alloc_slice( + &field_ty_out + .iter() + .map(|field_ty_out| field_ty_out.snapshot) + .collect::>(), + ); let ref_to_pred = mk_predicate_ident(name_p); let ref_to_snap = mk_function_snap_identifier(vcx, name_p, ty_s); let field_snaps_to_snap = mk_from_fields(vcx, name_s, snapshot_constructor_args); let unreachable_to_snap = mk_function_unreachable_identifier(vcx, name_s); let method_assign = mk_function_assign(vcx, name_p, ty_s); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - ref_to_pred, - snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), - unreachable_to_snap, - ref_to_snap, - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { - field_snaps_to_snap, - field_access, - }), - method_assign, - }); + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + ref_to_pred, + snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), + unreachable_to_snap, + ref_to_snap, + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { + field_snaps_to_snap, + field_access, + }), + method_assign, + }, + ); let mut funcs: Vec> = vec![]; let mut axioms: Vec> = vec![]; @@ -489,9 +498,7 @@ impl TaskEncoder for TypeEncoder { field_projection_p.push(vcx.alloc(vir::FunctionData { name: fa.projection_p.name(), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), ret: &vir::TypeData::Ref, pres: &[], posts: &[], @@ -504,30 +511,31 @@ impl TaskEncoder for TypeEncoder { for (read_idx, _read_ty_out) in field_ty_out.iter().enumerate() { let slf = vcx.mk_local_ex("self"); let val = vcx.mk_local_ex("val"); - let write_read = field_access[read_idx].read.apply(vcx, [ - field_access[write_idx].write.apply(vcx, [slf, val]), - ]); + let write_read = field_access[read_idx] + .read + .apply(vcx, [field_access[write_idx].write.apply(vcx, [slf, val])]); let rhs = if read_idx == write_idx { val } else { field_access[read_idx].read.apply(vcx, [slf]) }; axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!(vcx, "ax_{name_s}_write_{write_idx}_read_{read_idx}"), + name: vir::vir_format!( + vcx, + "ax_{name_s}_write_{write_idx}_read_{read_idx}" + ), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - vcx.mk_local_decl("val", write_ty_out.snapshot), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - write_read - ])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: write_read, - rhs, - }))), - }))) + qvars: vcx.alloc_slice(&[ + vcx.mk_local_decl("self", ty_s), + vcx.mk_local_decl("val", write_ty_out.snapshot), + ]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[write_read])]), + body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: write_read, + rhs, + }))), + }))), })); } } @@ -535,18 +543,21 @@ impl TaskEncoder for TypeEncoder { // constructor { let cons_qvars = vcx.alloc_slice( - &field_ty_out.iter() + &field_ty_out + .iter() .enumerate() - .map(|(idx, field_ty_out)| vcx.mk_local_decl( - vir::vir_format!(vcx, "f{idx}"), - field_ty_out.snapshot, - )) - .collect::>()); - let cons_args = field_ty_out.iter() + .map(|(idx, field_ty_out)| { + vcx.mk_local_decl( + vir::vir_format!(vcx, "f{idx}"), + field_ty_out.snapshot, + ) + }) + .collect::>(), + ); + let cons_args = field_ty_out + .iter() .enumerate() - .map(|(idx, _field_ty_out)| vcx.mk_local_ex( - vir::vir_format!(vcx, "f{idx}"), - )) + .map(|(idx, _field_ty_out)| vcx.mk_local_ex(vir::vir_format!(vcx, "f{idx}"))) .collect::>(); let cons_call = field_snaps_to_snap.apply(vcx, &cons_args); @@ -556,9 +567,7 @@ impl TaskEncoder for TypeEncoder { name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: cons_qvars.clone(), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - cons_read, - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_read])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: cons_read, @@ -569,24 +578,21 @@ impl TaskEncoder for TypeEncoder { } if !field_ty_out.is_empty() { - let cons_call_with_reads = field_snaps_to_snap.apply(vcx, + let cons_call_with_reads = field_snaps_to_snap.apply( + vcx, &field_ty_out .iter() .enumerate() - .map(|(idx, _field_ty_out)| + .map(|(idx, _field_ty_out)| { field_access[idx].read.apply(vcx, [vcx.mk_local_ex("self")]) - ) + }) .collect::>(), ); axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - cons_call_with_reads, - ])]), + qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: cons_call_with_reads, @@ -600,20 +606,27 @@ impl TaskEncoder for TypeEncoder { // predicate let predicate = { let expr = (0..field_ty_out.len()) - .map(|idx| vcx.alloc(vir::ExprData::PredicateApp( - field_ty_out[idx].ref_to_pred.apply(vcx, [field_access[idx].projection_p.apply(vcx, [vcx.mk_local_ex("self_p")])]) - ))) - .reduce(|base, field_expr| vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::And, - lhs: base, - rhs: field_expr, - })))) + .map(|idx| { + vcx.alloc(vir::ExprData::PredicateApp( + field_ty_out[idx].ref_to_pred.apply( + vcx, + [field_access[idx] + .projection_p + .apply(vcx, [vcx.mk_local_ex("self_p")])], + ), + )) + }) + .reduce(|base, field_expr| { + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::And, + lhs: base, + rhs: field_expr, + }))) + }) .unwrap_or_else(|| vcx.mk_bool::()); vcx.alloc(vir::PredicateData { name: name_p, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self_p", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), expr: Some(expr), }) }; @@ -629,33 +642,36 @@ impl TaskEncoder for TypeEncoder { function_snap: { let pred_app = vcx.alloc(vir::PredicateAppData { target: name_p, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("self_p"), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), }); vcx.alloc(vir::FunctionData { name: ref_to_snap.name(), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self_p", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), ret: ty_s, - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(pred_app)), - ]), + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: Some(vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: field_snaps_to_snap.apply( - vcx, - &field_ty_out - .iter() - .enumerate() - .map(|(idx, field_ty_out)| field_ty_out.ref_to_snap.apply(vcx, [ - field_access[idx].projection_p.apply(vcx, [vcx.mk_local_ex("self_p")]) - ])) - .collect::>() - ), - })))), + expr: Some( + vcx.alloc(vir::ExprData::Unfolding( + vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: field_snaps_to_snap.apply( + vcx, + &field_ty_out + .iter() + .enumerate() + .map(|(idx, field_ty_out)| { + field_ty_out.ref_to_snap.apply( + vcx, + [field_access[idx] + .projection_p + .apply(vcx, [vcx.mk_local_ex("self_p")])], + ) + }) + .collect::>(), + ), + }), + )), + ), }) }, //method_refold: mk_refold(vcx, name_p, ty_s), @@ -671,41 +687,47 @@ impl TaskEncoder for TypeEncoder { let ref_to_snap = mk_function_snap_identifier(vcx, "p_Bool", ty_s); let unreachable_to_snap = mk_function_unreachable_identifier(vcx, "s_Bool"); let method_assign = mk_function_assign(vcx, "p_Bool", ty_s); - let (snap_to_prim, prim_to_snap) = mk_primitive(vcx, "s_Bool", ty_s, &vir::TypeData::Bool); + let (snap_to_prim, prim_to_snap) = + mk_primitive(vcx, "s_Bool", ty_s, &vir::TypeData::Bool); let specifics = TypeEncoderOutputRefSub::Primitive(TypeEncoderOutputRefSubPrim { prim_type: &vir::TypeData::Bool, snap_to_prim, prim_to_snap, }); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - ref_to_pred, - snapshot: ty_s, - unreachable_to_snap, - ref_to_snap, - //method_refold: "refold_p_Bool", - specifics, - method_assign, - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: "f_Bool", - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Bool { - function prim_to_snap(Bool): s_Bool; - function snap_to_prim(s_Bool): Bool; - axiom_inverse(snap_to_prim, prim_to_snap, Bool); - } }, - predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), - unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), - function_snap: mk_snap(vcx, "p_Bool", ref_to_snap, Some("f_Bool"), ty_s), - //method_refold: mk_refold(vcx, "p_Bool", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Bool", method_assign, ref_to_snap, ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + ref_to_pred, + snapshot: ty_s, + unreachable_to_snap, + ref_to_snap, + //method_refold: "refold_p_Bool", + specifics, + method_assign, + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: "f_Bool", + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Bool { + function prim_to_snap(Bool): s_Bool; + function snap_to_prim(s_Bool): Bool; + axiom_inverse(snap_to_prim, prim_to_snap, Bool); + } }, + predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), + unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), + function_snap: mk_snap(vcx, "p_Bool", ref_to_snap, Some("f_Bool"), ty_s), + //method_refold: mk_refold(vcx, "p_Bool", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Bool", method_assign, ref_to_snap, ty_s), + }, + (), + )) } - TyKind::Int(_) | - TyKind::Uint(_) => { + TyKind::Int(_) | TyKind::Uint(_) => { let signed = task_key.is_signed(); let (sign, name_str) = match task_key.kind() { TyKind::Int(kind) => ("Int", kind.name_str()), @@ -713,7 +735,10 @@ impl TaskEncoder for TypeEncoder { _ => unreachable!(), }; let bit_width = Self::get_bit_width(vcx.tcx, *task_key); - let prim_type = vcx.alloc(vir::TypeData::Int { bit_width: bit_width as u8, signed }); + let prim_type = vcx.alloc(vir::TypeData::Int { + bit_width: bit_width as u8, + signed, + }); let name_s = vir::vir_format!(vcx, "s_{sign}_{name_str}"); let name_p = vir::vir_format!(vcx, "p_{sign}_{name_str}"); let ref_to_pred = mk_predicate_ident(name_p); @@ -728,50 +753,65 @@ impl TaskEncoder for TypeEncoder { snap_to_prim, prim_to_snap, }); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - ref_to_pred, - snapshot: ty_s, - unreachable_to_snap, - ref_to_snap, - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics, - method_assign, - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: name_field, - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain [name_s] { - function prim_to_snap(Int): [ty_s]; - function snap_to_prim([ty_s]): Int; - axiom_inverse(snap_to_prim, prim_to_snap, Int); - } }, - predicate: mk_simple_predicate(vcx, name_p, name_field), - unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), - function_snap: mk_snap(vcx, name_p, ref_to_snap, Some(name_field), ty_s), - //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, name_p, method_assign, ref_to_snap, ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + ref_to_pred, + snapshot: ty_s, + unreachable_to_snap, + ref_to_snap, + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics, + method_assign, + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: name_field, + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain [name_s] { + function prim_to_snap(Int): [ty_s]; + function snap_to_prim([ty_s]): Int; + axiom_inverse(snap_to_prim, prim_to_snap, Int); + } }, + predicate: mk_simple_predicate(vcx, name_p, name_field), + unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), + function_snap: mk_snap(vcx, name_p, ref_to_snap, Some(name_field), ty_s), + //method_refold: mk_refold(vcx, name_p, ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, name_p, method_assign, ref_to_snap, ty_s), + }, + (), + )) } TyKind::Tuple(tys) => { let field_ty_out = tys .iter() - .map(|ty| deps.require_ref::(ty).unwrap()) + .map(|ty| { + deps.require_ref::(ty) + .unwrap() + }) .collect::>(); // TODO: name the tuple according to its types, or make generic? - let tmp_ty_name: String = field_ty_out.iter().map(|e| format!("_{}", e.snapshot.get_domain().unwrap())).collect(); - - Ok((mk_structlike( - vcx, - deps, - task_key, - vir::vir_format!(vcx, "s_Tuple{}{tmp_ty_name}", tys.len()), - vir::vir_format!(vcx, "p_Tuple{}{tmp_ty_name}", tys.len()), - field_ty_out, - )?, ())) + let tmp_ty_name: String = field_ty_out + .iter() + .map(|e| format!("_{}", e.snapshot.get_domain().unwrap())) + .collect(); + + Ok(( + mk_structlike( + vcx, + deps, + task_key, + vir::vir_format!(vcx, "s_Tuple{}{tmp_ty_name}", tys.len()), + vir::vir_format!(vcx, "p_Tuple{}{tmp_ty_name}", tys.len()), + field_ty_out, + )?, + (), + )) /* let ty_len = tys.len(); @@ -808,49 +848,78 @@ impl TaskEncoder for TypeEncoder { } TyKind::Param(_param) => { - let param_out = deps.require_ref::(()).unwrap(); + let param_out = deps + .require_ref::(()) + .unwrap(); let ty_s = vcx.alloc(vir::TypeData::Domain(param_out.snapshot_param_name)); let ref_to_pred = mk_predicate_ident(param_out.predicate_param_name); - let ref_to_snap = mk_function_snap_identifier(vcx, param_out.predicate_param_name, ty_s); - let unreachable_to_snap = mk_function_unreachable_identifier(vcx, param_out.snapshot_param_name); + let ref_to_snap = + mk_function_snap_identifier(vcx, param_out.predicate_param_name, ty_s); + let unreachable_to_snap = + mk_function_unreachable_identifier(vcx, param_out.snapshot_param_name); let method_assign = mk_function_assign(vcx, param_out.predicate_param_name, ty_s); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - ref_to_pred, - snapshot: ty_s, - unreachable_to_snap, - ref_to_snap, - //method_refold: "refold_p_Param", - specifics: TypeEncoderOutputRefSub::Param, - method_assign, - }); - Ok((TypeEncoderOutput { - fields: &[], - snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors - } }, - predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, - unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), - function_snap: mk_snap(vcx, param_out.predicate_param_name, ref_to_snap, None, ty_s), - //method_refold: mk_refold(vcx, param_out.predicate_param_name, ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, param_out.predicate_param_name, method_assign, ref_to_snap, ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + ref_to_pred, + snapshot: ty_s, + unreachable_to_snap, + ref_to_snap, + //method_refold: "refold_p_Param", + specifics: TypeEncoderOutputRefSub::Param, + method_assign, + }, + ); + Ok(( + TypeEncoderOutput { + fields: &[], + snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors + } }, + predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, + unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), + function_snap: mk_snap( + vcx, + param_out.predicate_param_name, + ref_to_snap, + None, + ty_s, + ), + //method_refold: mk_refold(vcx, param_out.predicate_param_name, ty_s), + field_projection_p: &[], + method_assign: mk_assign( + vcx, + param_out.predicate_param_name, + method_assign, + ref_to_snap, + ty_s, + ), + }, + (), + )) } TyKind::Adt(adt_def, substs) if adt_def.is_struct() => { println!("encoding ADT {adt_def:?} with substs {substs:?}"); let substs = ty::List::identity_for_item(vcx.tcx, adt_def.did()); - let field_ty_out = adt_def.all_fields() - .map(|field| deps.require_ref::(field.ty(vcx.tcx, substs)).unwrap()) + let field_ty_out = adt_def + .all_fields() + .map(|field| { + deps.require_ref::(field.ty(vcx.tcx, substs)) + .unwrap() + }) .collect::>(); let did_name = vcx.tcx.item_name(adt_def.did()).to_ident_string(); - Ok((mk_structlike( - vcx, - deps, - task_key, - vir::vir_format!(vcx, "s_Adt_{did_name}"), - vir::vir_format!(vcx, "p_Adt_{did_name}"), - field_ty_out, - )?, ())) + Ok(( + mk_structlike( + vcx, + deps, + task_key, + vir::vir_format!(vcx, "s_Adt_{did_name}"), + vir::vir_format!(vcx, "p_Adt_{did_name}"), + field_ty_out, + )?, + (), + )) } TyKind::Never => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Never")); @@ -858,30 +927,35 @@ impl TaskEncoder for TypeEncoder { let ref_to_snap = mk_function_snap_identifier(vcx, "p_Never", ty_s); let unreachable_to_snap = mk_function_unreachable_identifier(vcx, "s_Never"); let method_assign = mk_function_assign(vcx, "p_Never", ty_s); - let specifics = TypeEncoderOutputRefSub::EnumLike(TypeEncoderOutputRefSubEnum { - }); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - ref_to_pred, - snapshot: ty_s, - unreachable_to_snap, - ref_to_snap, - //method_refold: "refold_p_Never", - specifics, - method_assign, - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: vir::vir_format!(vcx, "f_Never"), - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Never {} }, - predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, - unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), - function_snap: mk_snap(vcx, "p_Never", ref_to_snap, None, ty_s), - //method_refold: mk_refold(vcx, "p_Never", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Never", method_assign, ref_to_snap, ty_s), - }, ())) + let specifics = TypeEncoderOutputRefSub::EnumLike(TypeEncoderOutputRefSubEnum {}); + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + ref_to_pred, + snapshot: ty_s, + unreachable_to_snap, + ref_to_snap, + //method_refold: "refold_p_Never", + specifics, + method_assign, + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: vir::vir_format!(vcx, "f_Never"), + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Never {} }, + predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, + unreachable_to_snap: mk_unreachable(vcx, unreachable_to_snap, ty_s), + function_snap: mk_snap(vcx, "p_Never", ref_to_snap, None, ty_s), + //method_refold: mk_refold(vcx, "p_Never", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Never", method_assign, ref_to_snap, ty_s), + }, + (), + )) } //_ => Err((TypeEncoderError::UnsupportedType, None)), unsupported_type => todo!("type not supported: {unsupported_type:?}"), diff --git a/prusti-encoder/src/encoders/viper_tuple.rs b/prusti-encoder/src/encoders/viper_tuple.rs index ac32ff3dc9d..b7f08532dd1 100644 --- a/prusti-encoder/src/encoders/viper_tuple.rs +++ b/prusti-encoder/src/encoders/viper_tuple.rs @@ -1,9 +1,6 @@ -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; -use vir::{FunctionIdent, UnknownArity, CallableIdent, UnaryArity}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{CallableIdent, FunctionIdent, UnaryArity, UnknownArity}; pub struct ViperTupleEncoder; @@ -20,7 +17,7 @@ impl<'vir> ViperTupleEncoderOutputRef<'vir> { pub fn mk_cons<'tcx, Curr, Next>( &self, vcx: &'vir vir::VirCtxt<'tcx>, - elems: &[vir::ExprGen<'vir, Curr, Next>] + elems: &[vir::ExprGen<'vir, Curr, Next>], ) -> vir::ExprGen<'vir, Curr, Next> { if self.elem_count == 1 { return elems[0]; @@ -59,7 +56,8 @@ impl TaskEncoder for ViperTupleEncoder { type EncodingError = (); fn with_cache<'tcx, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, ViperTupleEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'tcx, 'vir, ViperTupleEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -77,13 +75,16 @@ impl TaskEncoder for ViperTupleEncoder { fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { let is_unit_tuple = *task_key == 0; vir::with_vcx(|vcx| { let domain_name = vir::vir_format!(vcx, "Tuple_{task_key}"); @@ -94,61 +95,80 @@ impl TaskEncoder for ViperTupleEncoder { let typaram_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "T{idx}")) .collect::>(); - let typaram_tys = vcx.alloc_slice(&typaram_names.iter() - .map(|name| vcx.alloc(vir::TypeData::Domain(name))) - .collect::>()); + let typaram_tys = vcx.alloc_slice( + &typaram_names + .iter() + .map(|name| vcx.alloc(vir::TypeData::Domain(name))) + .collect::>(), + ); let constructor = FunctionIdent::new(cons_name, UnknownArity::new(typaram_tys)); let domain_type = vcx.alloc(vir::TypeData::Domain(domain_name)); - let elem_getters = elem_names.iter().map(|name| { - FunctionIdent::new(name, UnaryArity::new([domain_type])) - }).collect::>(); - deps.emit_output_ref::(*task_key, ViperTupleEncoderOutputRef { - elem_count: *task_key, - domain_type, - constructor, - elem_getters: vcx.alloc_slice(&elem_getters), - }); + let elem_getters = elem_names + .iter() + .map(|name| FunctionIdent::new(name, UnaryArity::new([domain_type]))) + .collect::>(); + deps.emit_output_ref::( + *task_key, + ViperTupleEncoderOutputRef { + elem_count: *task_key, + domain_type, + constructor, + elem_getters: vcx.alloc_slice(&elem_getters), + }, + ); let domain_ty = vcx.alloc(vir::TypeData::DomainParams(domain_name, typaram_tys)); let qvars_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "elem{idx}")) .collect::>(); let mut axioms = Vec::new(); if !is_unit_tuple { - let qvars_decl = vcx.alloc_slice(&(0..*task_key) - .map(|idx| vcx.mk_local_decl(qvars_names[idx], typaram_tys[idx])) - .collect::>()); + let qvars_decl = vcx.alloc_slice( + &(0..*task_key) + .map(|idx| vcx.mk_local_decl(qvars_names[idx], typaram_tys[idx])) + .collect::>(), + ); let qvars_ex = (0..*task_key) .map(|idx| vcx.mk_local_ex(qvars_names[idx])) .collect::>(); - let cons_call = constructor.apply(vcx, - &qvars_names.iter() + let cons_call = constructor.apply( + vcx, + &qvars_names + .iter() .map(|qvar| vcx.mk_local_ex(qvar)) .collect::>(), ); let axiom = vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_Tuple_{task_key}_elem"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: qvars_decl, - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), - body: vcx.mk_conj(&(0..*task_key) - .map(|idx| vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: elem_getters[idx].apply(vcx, [cons_call]), - rhs: qvars_ex[idx], - })))) - .collect::>()), - }))), + expr: vcx.alloc(vir::ExprData::Forall( + vcx.alloc(vir::ForallData { + qvars: qvars_decl, + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + body: vcx.mk_conj( + &(0..*task_key) + .map(|idx| { + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: elem_getters[idx].apply(vcx, [cons_call]), + rhs: qvars_ex[idx], + }))) + }) + .collect::>(), + ), + }), + )), }); axioms.push(axiom); } let elem_args = vcx.alloc_slice(&[domain_ty]); let mut functions = (0..*task_key) - .map(|idx| vcx.alloc(vir::DomainFunctionData { - unique: false, - name: elem_names[idx], - args: elem_args, - ret: typaram_tys[idx], - })) + .map(|idx| { + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: elem_names[idx], + args: elem_args, + ret: typaram_tys[idx], + }) + }) .collect::>(); functions.push(vcx.alloc(vir::DomainFunctionData { unique: false, @@ -156,14 +176,17 @@ impl TaskEncoder for ViperTupleEncoder { args: typaram_tys, ret: domain_ty, })); - Ok((ViperTupleEncoderOutput { - domain: Some(vcx.alloc(vir::DomainData { - name: domain_name, - typarams: vcx.alloc_slice(&typaram_names), - axioms: vcx.alloc_slice(&axioms), - functions: vcx.alloc_slice(&functions), - })), - }, ())) + Ok(( + ViperTupleEncoderOutput { + domain: Some(vcx.alloc(vir::DomainData { + name: domain_name, + typarams: vcx.alloc_slice(&typaram_names), + axioms: vcx.alloc_slice(&axioms), + functions: vcx.alloc_slice(&functions), + })), + }, + (), + )) }) } } diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index 62bc24a5bb8..f69c86ebde5 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -10,10 +10,7 @@ extern crate rustc_type_ir; mod encoders; use prusti_interface::{environment::EnvBody, specs::typed::SpecificationItem}; -use prusti_rustc_interface::{ - middle::ty, - hir, -}; +use prusti_rustc_interface::{hir, middle::ty}; /* struct MirBodyPureEncoder; @@ -100,7 +97,7 @@ impl<'vir, 'tcx> TaskEncoder<'vir, 'tcx> for MirBodyImpureEncoder<'vir, 'tcx> { ); // TaskKey, OutputRef same as above type OutputFull = vir::Method<'vir>; -} +} struct MirTyEncoder<'vir, 'tcx>(PhantomData<&'vir ()>, PhantomData<&'tcx ()>); impl<'vir, 'tcx> TaskEncoder<'vir, 'tcx> for MirTyEncoder<'vir, 'tcx> { @@ -134,18 +131,18 @@ pub fn test_entrypoint<'tcx>( continue; }*/ match kind { - hir::def::DefKind::Fn | - hir::def::DefKind::AssocFn => { + hir::def::DefKind::Fn | hir::def::DefKind::AssocFn => { let def_id = def_id.to_def_id(); if prusti_interface::specs::is_spec_fn(tcx, def_id) { continue; } let (is_pure, is_trusted) = crate::encoders::with_proc_spec(def_id, |proc_spec| { - let is_pure = proc_spec.kind.is_pure().unwrap_or_default(); - let is_trusted = proc_spec.trusted.extract_inherit().unwrap_or_default(); - (is_pure, is_trusted) - }).unwrap_or_default(); + let is_pure = proc_spec.kind.is_pure().unwrap_or_default(); + let is_trusted = proc_spec.trusted.extract_inherit().unwrap_or_default(); + (is_pure, is_trusted) + }) + .unwrap_or_default(); if !(is_trusted && is_pure) { let substs = ty::GenericArgs::identity_for_item(tcx, def_id); @@ -220,20 +217,20 @@ pub fn test_entrypoint<'tcx>( std::fs::write("local-testing/simple.vpr", viper_code).unwrap(); - vir::with_vcx(|vcx| vcx.alloc(vir::ProgramData { - fields: &[], - domains: &[], - predicates: &[], - functions: vcx.alloc_slice(&[ - vcx.alloc(vir::FunctionData { + vir::with_vcx(|vcx| { + vcx.alloc(vir::ProgramData { + fields: &[], + domains: &[], + predicates: &[], + functions: vcx.alloc_slice(&[vcx.alloc(vir::FunctionData { name: "test_function", args: &[], ret: &vir::TypeData::Bool, pres: &[], posts: &[], expr: None, - }), - ]), - methods: &[], - })) + })]), + methods: &[], + }) + }) } diff --git a/task-encoder/src/lib.rs b/task-encoder/src/lib.rs index 7ebc8e1a64c..115b0da1c95 100644 --- a/task-encoder/src/lib.rs +++ b/task-encoder/src/lib.rs @@ -8,7 +8,6 @@ impl OutputRefAny for () {} pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { // None, // indicated by absence in the cache - /// Task was enqueued but not yet started. Enqueued, @@ -29,9 +28,7 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { }, /// An error occurred when enqueing the task. - ErrorEnqueue { - error: TaskEncoderError, - }, + ErrorEnqueue { error: TaskEncoderError }, /// An error occurred when encoding the task. The full "local" encoding is /// not available. However, tasks which depend on this task may still @@ -49,16 +46,12 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { /// Cache for a task encoder. See `TaskEncoderCacheState` for a description of /// the possible values in the encoding process. -pub type Cache<'tcx, 'vir, E> = LinkedHashMap< - ::TaskKey<'tcx>, - TaskEncoderCacheState<'vir, E>, ->; +pub type Cache<'tcx, 'vir, E> = + LinkedHashMap<::TaskKey<'tcx>, TaskEncoderCacheState<'vir, E>>; pub type CacheRef<'tcx, 'vir, E> = RefCell>; -pub type CacheStatic = LinkedHashMap< - ::TaskKey<'static>, - TaskEncoderCacheState<'static, E>, ->; +pub type CacheStatic = + LinkedHashMap<::TaskKey<'static>, TaskEncoderCacheState<'static, E>>; pub type CacheStaticRef = RefCell>; /* pub struct TaskEncoderOutput<'vir, E: TaskEncoder>( @@ -85,7 +78,8 @@ pub enum TaskEncoderError { } impl std::fmt::Debug for TaskEncoderError - where ::EncodingError: std::fmt::Debug +where + ::EncodingError: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut helper = f.debug_struct("TaskEncoderError"); @@ -118,30 +112,21 @@ impl<'a> TaskEncoderDependencies<'a> { pub fn require_ref<'vir, 'tcx: 'vir, E: TaskEncoder>( &mut self, task: ::TaskDescription<'tcx>, - ) -> Result< - ::OutputRef<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputRef<'vir>, TaskEncoderError> { E::encode_ref(task) } pub fn require_local<'vir, 'tcx: 'vir, E: TaskEncoder + 'vir>( &mut self, task: ::TaskDescription<'tcx>, - ) -> Result< - ::OutputFullLocal<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputFullLocal<'vir>, TaskEncoderError> { E::encode(task).map(|(_output_ref, output_local, _output_dep)| output_local) } pub fn require_dep<'vir, 'tcx: 'vir, E: TaskEncoder + 'vir>( &mut self, task: ::TaskDescription<'tcx>, - ) -> Result< - ::OutputFullDependency<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputFullDependency<'vir>, TaskEncoderError> { E::encode(task).map(|(_output_ref, _output_local, output_dep)| output_dep) } @@ -150,10 +135,12 @@ impl<'a> TaskEncoderDependencies<'a> { task_key: E::TaskKey<'tcx>, output_ref: E::OutputRef<'vir>, ) { - assert!(E::with_cache(move |cache| matches!(cache.borrow_mut().insert( - task_key, - TaskEncoderCacheState::Started { output_ref }, - ), Some(TaskEncoderCacheState::Enqueued)))); + assert!(E::with_cache(move |cache| matches!( + cache + .borrow_mut() + .insert(task_key, TaskEncoderCacheState::Started { output_ref },), + Some(TaskEncoderCacheState::Enqueued) + ))); } } @@ -166,7 +153,8 @@ pub trait TaskEncoder { /// for example if the description should be normalised or some non-trivial /// resolution needs to happen. In other words, multiple descriptions may /// lead to the same key and hence the same output. - type TaskKey<'tcx>: std::hash::Hash + Eq + Clone + std::fmt::Debug = Self::TaskDescription<'tcx>; + type TaskKey<'tcx>: std::hash::Hash + Eq + Clone + std::fmt::Debug = + Self::TaskDescription<'tcx>; /// A reference to an encoded item. Should be non-unit for tasks which can /// be "referred" to from other parts of a program, as opposed to tasks @@ -190,7 +178,9 @@ pub trait TaskEncoder { /// Enters the given function with a reference to the cache for this /// encoder. fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where Self: 'vir, F: FnOnce(&'vir CacheRef<'tcx, 'vir, Self>) -> R; + where + Self: 'vir, + F: FnOnce(&'vir CacheRef<'tcx, 'vir, Self>) -> R; //fn get_all_outputs() -> Self::CacheRef<'vir> { // todo!() @@ -199,7 +189,8 @@ pub trait TaskEncoder { //} fn enqueue<'vir>(task: Self::TaskDescription<'vir>) - where Self: 'vir + where + Self: 'vir, { let task_key = Self::task_to_key(&task); let task_key_clone = task_key.clone(); // TODO: remove? @@ -209,28 +200,32 @@ pub trait TaskEncoder { } // enqueue, expecting no entry (we just checked) - assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( - task_key, - TaskEncoderCacheState::Enqueued, - ).is_none())); + assert!(Self::with_cache(move |cache| cache + .borrow_mut() + .insert(task_key, TaskEncoderCacheState::Enqueued,) + .is_none())); } - fn encode_ref<'tcx: 'vir, 'vir>(task: Self::TaskDescription<'tcx>) -> Result< - Self::OutputRef<'vir>, - TaskEncoderError, - > - where Self: 'vir + fn encode_ref<'tcx: 'vir, 'vir>( + task: Self::TaskDescription<'tcx>, + ) -> Result, TaskEncoderError> + where + Self: 'vir, { let task_key = Self::task_to_key(&task); // is there an output ref available already? let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { + if let Some(output_ref) = + Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => { + Some(output_ref.clone()) + } + _ => None, + }) + { return Ok(output_ref); } @@ -252,24 +247,34 @@ pub trait TaskEncoder { Self::encode(task)?; let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { + if let Some(output_ref) = + Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => { + Some(output_ref.clone()) + } + _ => None, + }) + { return Ok(output_ref); } panic!("output ref not found after encoding") // TODO: error? } - fn encode<'tcx: 'vir, 'vir>(task: Self::TaskDescription<'tcx>) -> Result<( - Self::OutputRef<'vir>, - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir + fn encode<'tcx: 'vir, 'vir>( + task: Self::TaskDescription<'tcx>, + ) -> Result< + ( + Self::OutputRef<'vir>, + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + TaskEncoderError, + > + where + Self: 'vir, { let task_key = Self::task_to_key(&task); @@ -290,8 +295,9 @@ pub trait TaskEncoder { output_local.clone(), output_dep.clone(), ))), - TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => - panic!("Encoding already started or enqueued"), + TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => { + panic!("Encoding already started or enqueued") + } }, None => { // enqueue @@ -314,160 +320,172 @@ pub trait TaskEncoder { match encode_result { Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( - output_ref, - output_local, - output_dep, - )) + Self::with_cache(|cache| { + cache.borrow_mut().insert( + task_key, + TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + }, + ) + }); + Ok((output_ref, output_local, output_dep)) } Err((err, maybe_output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { - output_ref: output_ref.clone(), - deps, - error: TaskEncoderError::EncodingError(err.clone()), - output_dep: maybe_output_dep, - })); + Self::with_cache(|cache| { + cache.borrow_mut().insert( + task_key, + TaskEncoderCacheState::ErrorEncode { + output_ref: output_ref.clone(), + deps, + error: TaskEncoderError::EncodingError(err.clone()), + output_dep: maybe_output_dep, + }, + ) + }); Err(TaskEncoderError::EncodingError(err)) } } } /* - /// Given a task description for this encoder, enqueue it and return the - /// reference to the output. If the task is already enqueued, the output - /// reference already exists. - fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Self::OutputRef<'vir> - where Self: 'vir - { - let task_key = Self::task_to_key(&task); - let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Enqueued { output_ref }) - | Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { - return output_ref; + /// Given a task description for this encoder, enqueue it and return the + /// reference to the output. If the task is already enqueued, the output + /// reference already exists. + fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Self::OutputRef<'vir> + where Self: 'vir + { + let task_key = Self::task_to_key(&task); + let task_key_clone = task_key.clone(); + if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Enqueued { output_ref }) + | Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), + _ => None, + }) { + return output_ref; + } + let task_ref = Self::task_to_output_ref(&task); + let task_key_clone = task_key.clone(); + let task_ref_clone = task_ref.clone(); + assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( + task_key_clone, + TaskEncoderCacheState::Enqueued { output_ref: task_ref_clone }, + ).is_none())); + task_ref } - let task_ref = Self::task_to_output_ref(&task); - let task_key_clone = task_key.clone(); - let task_ref_clone = task_ref.clone(); - assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( - task_key_clone, - TaskEncoderCacheState::Enqueued { output_ref: task_ref_clone }, - ).is_none())); - task_ref - } - - // TODO: this function should not be needed - fn encode_eager<'vir>(task: Self::TaskDescription<'vir>) -> Result<( - Self::OutputRef<'vir>, - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir - { - let task_key = Self::task_to_key(&task); - // enqueue - let output_ref = Self::encode(task); - // process - Self::encode_full(task_key) - .map(|(output_full_local, output_full_dep)| (output_ref, output_full_local, output_full_dep)) - } - /// Given a task key, fully encode the given task. If this task was already - /// finished, the encoding is not repeated. If this task was enqueued, but - /// not finished, return a `CyclicError`. - fn encode_full<'vir>(task_key: Self::TaskKey<'vir>) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir - { - let mut output_ref_opt = None; - let ret = Self::with_cache(|cache| { - // should be queued by now - match cache.borrow().get(&task_key).unwrap() { - TaskEncoderCacheState::Enqueued { output_ref } => { - output_ref_opt = Some(output_ref.clone()); - None - } - TaskEncoderCacheState::Started { .. } => Some(Err(TaskEncoderError::CyclicError)), - TaskEncoderCacheState::Encoded { output_local, output_dep, .. } => - Some(Ok(( - output_local.clone(), - output_dep.clone(), - ))), - TaskEncoderCacheState::ErrorEncode { error, .. } => - Some(Err(error.clone())), - } - }); - if let Some(ret) = ret { - return ret; + // TODO: this function should not be needed + fn encode_eager<'vir>(task: Self::TaskDescription<'vir>) -> Result<( + Self::OutputRef<'vir>, + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), TaskEncoderError> + where Self: 'vir + { + let task_key = Self::task_to_key(&task); + // enqueue + let output_ref = Self::encode(task); + // process + Self::encode_full(task_key) + .map(|(output_full_local, output_full_dep)| (output_ref, output_full_local, output_full_dep)) } - let output_ref = output_ref_opt.unwrap(); - let mut deps: TaskEncoderDependencies<'vir> = Default::default(); - match Self::do_encode_full(&task_key, &mut deps) { - Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( - output_local, - output_dep, - )) + /// Given a task key, fully encode the given task. If this task was already + /// finished, the encoding is not repeated. If this task was enqueued, but + /// not finished, return a `CyclicError`. + fn encode_full<'vir>(task_key: Self::TaskKey<'vir>) -> Result<( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), TaskEncoderError> + where Self: 'vir + { + let mut output_ref_opt = None; + let ret = Self::with_cache(|cache| { + // should be queued by now + match cache.borrow().get(&task_key).unwrap() { + TaskEncoderCacheState::Enqueued { output_ref } => { + output_ref_opt = Some(output_ref.clone()); + None + } + TaskEncoderCacheState::Started { .. } => Some(Err(TaskEncoderError::CyclicError)), + TaskEncoderCacheState::Encoded { output_local, output_dep, .. } => + Some(Ok(( + output_local.clone(), + output_dep.clone(), + ))), + TaskEncoderCacheState::ErrorEncode { error, .. } => + Some(Err(error.clone())), + } + }); + if let Some(ret) = ret { + return ret; } - Err((err, maybe_output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { - output_ref: output_ref.clone(), - deps, - error: TaskEncoderError::EncodingError(err.clone()), - output_dep: maybe_output_dep, - })); - Err(TaskEncoderError::EncodingError(err)) + let output_ref = output_ref_opt.unwrap(); + + let mut deps: TaskEncoderDependencies<'vir> = Default::default(); + match Self::do_encode_full(&task_key, &mut deps) { + Ok((output_local, output_dep)) => { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + })); + Ok(( + output_local, + output_dep, + )) + } + Err((err, maybe_output_dep)) => { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { + output_ref: output_ref.clone(), + deps, + error: TaskEncoderError::EncodingError(err.clone()), + output_dep: maybe_output_dep, + })); + Err(TaskEncoderError::EncodingError(err)) + } } } - } -*/ + */ /// Given a task description, create a key for storing it in the cache. - fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> // Result< - Self::TaskKey<'vir>;//, - // Self::EnqueueingError, - //> -/* - /// Given a task description, create a reference to the output. - fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; -*/ + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir>; //, + // Self::EnqueueingError, + //> + /* + /// Given a task description, create a reference to the output. + fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; + */ fn do_encode_full<'tcx: 'vir, 'vir>( task_key: &Self::TaskKey<'tcx>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )>; + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + >; fn all_outputs<'vir>() -> Vec> - where Self: 'vir + where + Self: 'vir, { Self::with_cache(|cache| { let mut ret = vec![]; for (_task_key, cache_state) in cache.borrow().iter() { - match cache_state { // TODO: make this into an iterator chain - TaskEncoderCacheState::Encoded { output_local, .. } => ret.push(output_local.clone()), + match cache_state { + // TODO: make this into an iterator chain + TaskEncoderCacheState::Encoded { output_local, .. } => { + ret.push(output_local.clone()) + } _ => {} } } diff --git a/vir/src/callable_idents.rs b/vir/src/callable_idents.rs index 3b009fd8589..3520068b55f 100644 --- a/vir/src/callable_idents.rs +++ b/vir/src/callable_idents.rs @@ -1,4 +1,6 @@ -use crate::{ExprGen, PredicateAppGen, Type, PredicateAppGenData, StmtGenData, MethodCallGenData, VirCtxt}; +use crate::{ + ExprGen, MethodCallGenData, PredicateAppGen, PredicateAppGenData, StmtGenData, Type, VirCtxt, +}; use sealed::sealed; pub trait CallableIdent<'vir, A: Arity> { @@ -56,7 +58,11 @@ pub trait Arity: Copy { fn check<'vir, Curr: 'vir, Next: 'vir>(&self, name: &str, args: &[ExprGen<'vir, Curr, Next>]) { if cfg!(debug_assertions) { let args_len = args.len(); - assert!(self.len_matches(args_len), "{name} called with {args_len} args (expected {})", self.args().len()); + assert!( + self.len_matches(args_len), + "{name} called with {args_len} args (expected {})", + self.args().len() + ); for (_arg, _ty) in args.iter().zip(self.args()) { // TODO: check that the types match } @@ -108,8 +114,8 @@ impl<'vir, const N: usize> FunctionIdent<'vir, KnownArity<'vir, N>> { pub fn apply<'tcx, Curr: 'vir, Next: 'vir>( &self, vcx: &'vir VirCtxt<'tcx>, - args: [ExprGen<'vir, Curr, Next>; N] - ) -> ExprGen<'vir, Curr, Next>{ + args: [ExprGen<'vir, Curr, Next>; N], + ) -> ExprGen<'vir, Curr, Next> { self.1.check(self.name(), &args); vcx.mk_func_app(self.name(), &args) } @@ -118,8 +124,8 @@ impl<'vir, const N: usize> PredicateIdent<'vir, KnownArity<'vir, N>> { pub fn apply<'tcx, Curr: 'vir, Next: 'vir>( &self, vcx: &'vir VirCtxt<'tcx>, - args: [ExprGen<'vir, Curr, Next>; N] - ) -> PredicateAppGen<'vir, Curr, Next>{ + args: [ExprGen<'vir, Curr, Next>; N], + ) -> PredicateAppGen<'vir, Curr, Next> { self.1.check(self.name(), &args); vcx.alloc(PredicateAppGenData { target: self.name(), @@ -131,8 +137,8 @@ impl<'vir, const N: usize> MethodIdent<'vir, KnownArity<'vir, N>> { pub fn apply<'tcx, Curr: 'vir, Next: 'vir>( &self, vcx: &'vir VirCtxt<'tcx>, - args: [ExprGen<'vir, Curr, Next>; N] - ) -> StmtGenData<'vir, Curr, Next>{ + args: [ExprGen<'vir, Curr, Next>; N], + ) -> StmtGenData<'vir, Curr, Next> { self.1.check(self.name(), &args); StmtGenData::MethodCall(vcx.alloc(MethodCallGenData { targets: &[], @@ -148,8 +154,8 @@ impl<'vir> FunctionIdent<'vir, UnknownArity<'vir>> { pub fn apply<'tcx, Curr: 'vir, Next: 'vir>( &self, vcx: &'vir VirCtxt<'tcx>, - args: &[ExprGen<'vir, Curr, Next>] - ) -> ExprGen<'vir, Curr, Next>{ + args: &[ExprGen<'vir, Curr, Next>], + ) -> ExprGen<'vir, Curr, Next> { self.1.check(self.name(), args); vcx.mk_func_app(self.name(), args) } @@ -158,8 +164,8 @@ impl<'vir> PredicateIdent<'vir, UnknownArity<'vir>> { pub fn apply<'tcx, Curr: 'vir, Next: 'vir>( &self, vcx: &'vir VirCtxt<'tcx>, - args: &[ExprGen<'vir, Curr, Next>] - ) -> PredicateAppGen<'vir, Curr, Next>{ + args: &[ExprGen<'vir, Curr, Next>], + ) -> PredicateAppGen<'vir, Curr, Next> { self.1.check(self.name(), args); vcx.alloc(PredicateAppGenData { target: self.name(), @@ -171,8 +177,8 @@ impl<'vir> MethodIdent<'vir, UnknownArity<'vir>> { pub fn apply<'tcx, Curr: 'vir, Next: 'vir>( &self, vcx: &'vir VirCtxt<'tcx>, - args: &[ExprGen<'vir, Curr, Next>] - ) -> StmtGenData<'vir, Curr, Next>{ + args: &[ExprGen<'vir, Curr, Next>], + ) -> StmtGenData<'vir, Curr, Next> { self.1.check(self.name(), args); StmtGenData::MethodCall(vcx.alloc(MethodCallGenData { targets: &[], diff --git a/vir/src/context.rs b/vir/src/context.rs index d4592b06650..273c8983fcc 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -1,11 +1,8 @@ -use std::cell::RefCell; use prusti_interface::environment::EnvBody; use prusti_rustc_interface::middle::ty; +use std::cell::RefCell; -use crate::data::*; -use crate::gendata::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{data::*, gendata::*, genrefs::*, refs::*}; /// The VIR context is a data structure used throughout the encoding process. pub struct VirCtxt<'tcx> { @@ -20,13 +17,11 @@ pub struct VirCtxt<'tcx> { pub span_stack: Vec, // TODO: span stack // TODO: error positions? - /// The compiler's typing context. This allows convenient access to most /// of the compiler's APIs. pub tcx: ty::TyCtxt<'tcx>, pub body: RefCell>, - } impl<'tcx> VirCtxt<'tcx> { @@ -48,25 +43,23 @@ impl<'tcx> VirCtxt<'tcx> { &*self.arena.alloc_str(val) } -/* pub fn alloc_slice<'a, T: Copy>(&'tcx self, val: &'a [T]) -> &'tcx [T] { - &*self.arena.alloc_slice_copy(val) - }*/ + /* pub fn alloc_slice<'a, T: Copy>(&'tcx self, val: &'a [T]) -> &'tcx [T] { + &*self.arena.alloc_slice_copy(val) + }*/ pub fn alloc_slice(&self, val: &[T]) -> &[T] { &*self.arena.alloc_slice_copy(val) } pub fn mk_local<'vir>(&'vir self, name: &'vir str) -> Local<'vir> { - self.arena.alloc(LocalData { - name, - }) + self.arena.alloc(LocalData { name }) } pub fn mk_local_decl<'vir>(&'vir self, name: &'vir str, ty: Type<'vir>) -> LocalDecl<'vir> { - self.arena.alloc(LocalDeclData { - name, - ty, - }) + self.arena.alloc(LocalDeclData { name, ty }) } - pub fn mk_local_ex_local<'vir, Curr, Next>(&'vir self, local: Local<'vir>) -> ExprGen<'vir, Curr, Next> { + pub fn mk_local_ex_local<'vir, Curr, Next>( + &'vir self, + local: Local<'vir>, + ) -> ExprGen<'vir, Curr, Next> { self.arena.alloc(ExprGenData::Local(local)) } pub fn mk_local_ex<'vir, Curr, Next>(&'vir self, name: &'vir str) -> ExprGen<'vir, Curr, Next> { @@ -77,16 +70,18 @@ impl<'tcx> VirCtxt<'tcx> { target: &'vir str, src_args: &[ExprGen<'vir, Curr, Next>], ) -> ExprGen<'vir, Curr, Next> { - self.arena.alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { - target, - args: self.alloc_slice(src_args), - }))) + self.arena + .alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_pred_app<'vir>(&'vir self, target: &'vir str, src_args: &[Expr<'vir>]) -> Expr<'vir> { - self.arena.alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { - target, - args: self.alloc_slice(src_args), - }))) + self.arena + .alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { + target, + args: self.alloc_slice(src_args), + }))) } pub const fn mk_bool<'vir, const VALUE: bool>(&'vir self) -> Expr<'vir> { @@ -149,12 +144,20 @@ impl<'tcx> VirCtxt<'tcx> { (u128::BITS, _) => { // TODO: make this a `const` once `Expr` isn't invariant in `'vir` so that it can be `'const` instead let half = self.mk_uint::<{ 1_u128 << u64::BITS }>(); - self.alloc(ExprData::BinOp(self.alloc(BinOpGenData { kind: BinOpKind::Add, lhs: half, rhs: half }))) + self.alloc(ExprData::BinOp(self.alloc(BinOpGenData { + kind: BinOpKind::Add, + lhs: half, + rhs: half, + }))) } _ => unreachable!(), } } - pub fn get_signed_shift_int<'vir>(&'vir self, ty: Type, rust_ty: &ty::TyKind) -> Option> { + pub fn get_signed_shift_int<'vir>( + &'vir self, + ty: Type, + rust_ty: &ty::TyKind, + ) -> Option> { let int = match Self::get_int_data(ty, rust_ty) { (_, false) => return None, (u8::BITS, true) => self.mk_uint::<{ 1_u128 << (u8::BITS - 1) }>(), diff --git a/vir/src/data.rs b/vir/src/data.rs index 519f9e4eb91..ff3c27276d6 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; -use prusti_rustc_interface::middle::mir; use crate::refs::*; +use prusti_rustc_interface::middle::mir; pub struct LocalData<'vir> { pub name: &'vir str, // TODO: identifiers @@ -86,10 +86,7 @@ pub enum ConstData { } pub enum TypeData<'vir> { - Int { - bit_width: u8, - signed: bool, - }, + Int { bit_width: u8, signed: bool }, Bool, Domain(&'vir str), // TODO: identifiers DomainParams(&'vir str, &'vir [Type<'vir>]), diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 8edc9e5cf5b..dd8ff417fa4 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -1,13 +1,14 @@ use std::fmt::{Debug, Display, Formatter, Result as FmtResult}; -use crate::data::*; -use crate::gendata::*; +use crate::{data::*, gendata::*}; fn fmt_comma_sep_display(f: &mut Formatter<'_>, els: &[T]) -> FmtResult { els.iter() .enumerate() .map(|(idx, el)| { - if idx > 0 { write!(f, ", ")? } + if idx > 0 { + write!(f, ", ")? + } el.fmt(f) }) .collect::() @@ -16,7 +17,9 @@ fn fmt_comma_sep(f: &mut Formatter<'_>, els: &[T]) -> FmtResult { els.iter() .enumerate() .map(|(idx, el)| { - if idx > 0 { write!(f, ", ")? } + if idx > 0 { + write!(f, ", ")? + } el.fmt(f) }) .collect::() @@ -32,10 +35,7 @@ fn fmt_comma_sep_lines(f: &mut Formatter<'_>, els: &[T]) -> FmtResult Ok(()) } fn indent(s: String) -> String { - s - .split("\n") - .intersperse("\n ") - .collect::() + s.split("\n").intersperse("\n ").collect::() } impl<'vir, Curr, Next> Debug for AccFieldGenData<'vir, Curr, Next> { @@ -48,19 +48,23 @@ impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "(")?; self.lhs.fmt(f)?; - write!(f, ") {} (", match self.kind { - BinOpKind::CmpEq => "==", - BinOpKind::CmpNe => "!=", - BinOpKind::CmpGt => ">", - BinOpKind::CmpGe => ">=", - BinOpKind::CmpLt => "<", - BinOpKind::CmpLe => "<=", - BinOpKind::And => "&&", - BinOpKind::Or => "||", - BinOpKind::Add => "+", - BinOpKind::Sub => "-", - BinOpKind::Mod => "%", - })?; + write!( + f, + ") {} (", + match self.kind { + BinOpKind::CmpEq => "==", + BinOpKind::CmpNe => "!=", + BinOpKind::CmpGt => ">", + BinOpKind::CmpGe => ">=", + BinOpKind::CmpLt => "<", + BinOpKind::CmpLe => "<=", + BinOpKind::And => "&&", + BinOpKind::Or => "||", + BinOpKind::Add => "+", + BinOpKind::Sub => "-", + BinOpKind::Mod => "%", + } + )?; self.rhs.fmt(f)?; write!(f, ")") } @@ -95,8 +99,14 @@ impl<'vir, Curr, Next> Debug for DomainGenData<'vir, Curr, Next> { write!(f, "]")?; } writeln!(f, " {{")?; - self.axioms.iter().map(|el| el.fmt(f)).collect::()?; - self.functions.iter().map(|el| el.fmt(f)).collect::()?; + self.axioms + .iter() + .map(|el| el.fmt(f)) + .collect::()?; + self.functions + .iter() + .map(|el| el.fmt(f)) + .collect::()?; writeln!(f, "}}") } } @@ -176,8 +186,14 @@ impl<'vir, Curr, Next> Debug for FunctionGenData<'vir, Curr, Next> { writeln!(f, "function {}(", self.name)?; fmt_comma_sep_lines(f, &self.args)?; writeln!(f, "): {:?}", self.ret)?; - self.pres.iter().map(|el| writeln!(f, " requires {:?}", el)).collect::()?; - self.posts.iter().map(|el| writeln!(f, " ensures {:?}", el)).collect::()?; + self.pres + .iter() + .map(|el| writeln!(f, " requires {:?}", el)) + .collect::()?; + self.posts + .iter() + .map(|el| writeln!(f, " ensures {:?}", el)) + .collect::()?; if let Some(expr) = self.expr { write!(f, "{{\n ")?; expr.fmt(f)?; @@ -224,8 +240,14 @@ impl<'vir, Curr, Next> Debug for MethodGenData<'vir, Curr, Next> { } else { writeln!(f, ")")?; } - self.pres.iter().map(|el| writeln!(f, " requires {:?}", el)).collect::()?; - self.posts.iter().map(|el| writeln!(f, " ensures {:?}", el)).collect::()?; + self.pres + .iter() + .map(|el| writeln!(f, " requires {:?}", el)) + .collect::()?; + self.posts + .iter() + .map(|el| writeln!(f, " ensures {:?}", el)) + .collect::()?; if let Some(blocks) = self.blocks.as_ref() { writeln!(f, "{{")?; for block in blocks.iter() { @@ -304,7 +326,11 @@ impl<'vir, Curr, Next> Debug for TerminatorStmtGenData<'vir, Curr, Next> { write!(f, "goto {:?}", data.otherwise) } else { for target in data.targets { - write!(f, "if ({:?} == {:?}) {{ goto {:?} }}\n else", data.value, target.0, target.1)?; + write!( + f, + "if ({:?} == {:?}) {{ goto {:?} }}\n else", + data.value, target.0, target.1 + )?; } write!(f, " {{ goto {:?} }}", data.otherwise) } @@ -345,10 +371,15 @@ impl<'vir> Debug for TypeData<'vir> { impl<'vir, Curr, Next> Debug for UnOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "{}({:?})", match self.kind { - UnOpKind::Neg => "-", - UnOpKind::Not => "!", - }, self.expr) + write!( + f, + "{}({:?})", + match self.kind { + UnOpKind::Neg => "-", + UnOpKind::Not => "!", + }, + self.expr + ) } } diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index 4ed17048b23..e69bc2d2f82 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -1,20 +1,20 @@ use std::fmt::Debug; -use crate::data::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{data::*, genrefs::*, refs::*}; use vir_proc_macro::*; #[derive(Reify)] pub struct UnOpGenData<'vir, Curr, Next> { - #[reify_copy] pub kind: UnOpKind, + #[reify_copy] + pub kind: UnOpKind, pub expr: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct BinOpGenData<'vir, Curr, Next> { - #[reify_copy] pub kind: BinOpKind, + #[reify_copy] + pub kind: BinOpKind, pub lhs: ExprGen<'vir, Curr, Next>, pub rhs: ExprGen<'vir, Curr, Next>, } @@ -28,20 +28,23 @@ pub struct TernaryGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct ForallGenData<'vir, Curr, Next> { - #[reify_copy] pub qvars: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub qvars: &'vir [LocalDecl<'vir>], pub triggers: &'vir [&'vir [ExprGen<'vir, Curr, Next>]], pub body: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct FuncAppGenData<'vir, Curr, Next> { - #[reify_copy] pub target: &'vir str, // TODO: identifiers + #[reify_copy] + pub target: &'vir str, // TODO: identifiers pub args: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(Reify)] pub struct PredicateAppGenData<'vir, Curr, Next> { - #[reify_copy] pub target: &'vir str, // TODO: identifiers + #[reify_copy] + pub target: &'vir str, // TODO: identifiers pub args: &'vir [ExprGen<'vir, Curr, Next>], } @@ -54,13 +57,15 @@ pub struct UnfoldingGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct AccFieldGenData<'vir, Curr, Next> { pub recv: ExprGen<'vir, Curr, Next>, - #[reify_copy] pub field: &'vir str, // TODO: identifiers - // TODO: permission amount + #[reify_copy] + pub field: &'vir str, // TODO: identifiers + // TODO: permission amount } #[derive(Reify)] pub struct LetGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, + #[reify_copy] + pub name: &'vir str, pub val: ExprGen<'vir, Curr, Next>, pub expr: ExprGen<'vir, Curr, Next>, } @@ -103,8 +108,10 @@ pub enum ExprGenData<'vir, Curr: 'vir, Next: 'vir> { PredicateApp(PredicateAppGen<'vir, Curr, Next>), // TODO: this should not be used instead of acc? // domain func app // inhale/exhale - - Lazy(&'vir str, Box Fn(&'vir crate::VirCtxt<'a>, Curr) -> Next + 'vir>), + Lazy( + &'vir str, + Box Fn(&'vir crate::VirCtxt<'a>, Curr) -> Next + 'vir>, + ), Todo(&'vir str), } @@ -120,30 +127,39 @@ impl<'vir, Curr, Next> ExprGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct DomainAxiomGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // ? or comment, then auto-gen the names? + #[reify_copy] + pub name: &'vir str, // ? or comment, then auto-gen the names? pub expr: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct DomainGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub typarams: &'vir [&'vir str], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub typarams: &'vir [&'vir str], pub axioms: &'vir [DomainAxiomGen<'vir, Curr, Next>], - #[reify_copy] pub functions: &'vir [DomainFunction<'vir>], + #[reify_copy] + pub functions: &'vir [DomainFunction<'vir>], } #[derive(Reify)] pub struct PredicateGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], pub expr: Option>, } #[derive(Reify)] pub struct FunctionGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], - #[reify_copy] pub ret: Type<'vir>, + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub ret: Type<'vir>, pub pres: &'vir [ExprGen<'vir, Curr, Next>], pub posts: &'vir [ExprGen<'vir, Curr, Next>], pub expr: Option>, @@ -160,14 +176,19 @@ pub struct PureAssignGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct MethodCallGenData<'vir, Curr, Next> { - #[reify_copy] pub targets: &'vir [Local<'vir>], - #[reify_copy] pub method: &'vir str, + #[reify_copy] + pub targets: &'vir [Local<'vir>], + #[reify_copy] + pub method: &'vir str, pub args: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(Reify)] pub enum StmtGenData<'vir, Curr, Next> { - LocalDecl(#[reify_copy] LocalDecl<'vir>, Option>), + LocalDecl( + #[reify_copy] LocalDecl<'vir>, + Option>, + ), PureAssign(PureAssignGen<'vir, Curr, Next>), Inhale(ExprGen<'vir, Curr, Next>), Exhale(ExprGen<'vir, Curr, Next>), @@ -182,7 +203,8 @@ pub enum StmtGenData<'vir, Curr, Next> { pub struct GotoIfGenData<'vir, Curr, Next> { pub value: ExprGen<'vir, Curr, Next>, pub targets: &'vir [(ExprGen<'vir, Curr, Next>, CfgBlockLabel<'vir>)], - #[reify_copy] pub otherwise: CfgBlockLabel<'vir>, + #[reify_copy] + pub otherwise: CfgBlockLabel<'vir>, } #[derive(Reify)] @@ -196,16 +218,20 @@ pub enum TerminatorStmtGenData<'vir, Curr, Next> { #[derive(Debug, Reify)] pub struct CfgBlockGenData<'vir, Curr, Next> { - #[reify_copy] pub label: CfgBlockLabel<'vir>, + #[reify_copy] + pub label: CfgBlockLabel<'vir>, pub stmts: &'vir [StmtGen<'vir, Curr, Next>], pub terminator: TerminatorStmtGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct MethodGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], - #[reify_copy] pub rets: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub rets: &'vir [LocalDecl<'vir>], // TODO: pre/post - add a comment variant pub pres: &'vir [ExprGen<'vir, Curr, Next>], pub posts: &'vir [ExprGen<'vir, Curr, Next>], @@ -214,7 +240,8 @@ pub struct MethodGenData<'vir, Curr, Next> { #[derive(Debug, Reify)] pub struct ProgramGenData<'vir, Curr, Next> { - #[reify_copy] pub fields: &'vir [Field<'vir>], + #[reify_copy] + pub fields: &'vir [Field<'vir>], pub domains: &'vir [DomainGen<'vir, Curr, Next>], pub predicates: &'vir [PredicateGen<'vir, Curr, Next>], pub functions: &'vir [FunctionGen<'vir, Curr, Next>], diff --git a/vir/src/genrefs.rs b/vir/src/genrefs.rs index f21ca3c8ef8..f98eb0bef1a 100644 --- a/vir/src/genrefs.rs +++ b/vir/src/genrefs.rs @@ -1,7 +1,8 @@ pub type AccFieldGen<'vir, Curr, Next> = &'vir crate::gendata::AccFieldGenData<'vir, Curr, Next>; pub type BinOpGen<'vir, Curr, Next> = &'vir crate::gendata::BinOpGenData<'vir, Curr, Next>; pub type CfgBlockGen<'vir, Curr, Next> = &'vir crate::gendata::CfgBlockGenData<'vir, Curr, Next>; -pub type DomainAxiomGen<'vir, Curr, Next> = &'vir crate::gendata::DomainAxiomGenData<'vir, Curr, Next>; +pub type DomainAxiomGen<'vir, Curr, Next> = + &'vir crate::gendata::DomainAxiomGenData<'vir, Curr, Next>; pub type DomainGen<'vir, Curr, Next> = &'vir crate::gendata::DomainGenData<'vir, Curr, Next>; pub type ExprGen<'vir, Curr, Next> = &'vir crate::gendata::ExprGenData<'vir, Curr, Next>; pub type ForallGen<'vir, Curr, Next> = &'vir crate::gendata::ForallGenData<'vir, Curr, Next>; @@ -10,13 +11,17 @@ pub type FunctionGen<'vir, Curr, Next> = &'vir crate::gendata::FunctionGenData<' pub type GotoIfGen<'vir, Curr, Next> = &'vir crate::gendata::GotoIfGenData<'vir, Curr, Next>; pub type LetGen<'vir, Curr, Next> = &'vir crate::gendata::LetGenData<'vir, Curr, Next>; pub type MethodGen<'vir, Curr, Next> = &'vir crate::gendata::MethodGenData<'vir, Curr, Next>; -pub type MethodCallGen<'vir, Curr, Next> = &'vir crate::gendata::MethodCallGenData<'vir, Curr, Next>; +pub type MethodCallGen<'vir, Curr, Next> = + &'vir crate::gendata::MethodCallGenData<'vir, Curr, Next>; pub type PredicateGen<'vir, Curr, Next> = &'vir crate::gendata::PredicateGenData<'vir, Curr, Next>; -pub type PredicateAppGen<'vir, Curr, Next> = &'vir crate::gendata::PredicateAppGenData<'vir, Curr, Next>; +pub type PredicateAppGen<'vir, Curr, Next> = + &'vir crate::gendata::PredicateAppGenData<'vir, Curr, Next>; pub type ProgramGen<'vir, Curr, Next> = &'vir crate::gendata::ProgramGenData<'vir, Curr, Next>; -pub type PureAssignGen<'vir, Curr, Next> = &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; +pub type PureAssignGen<'vir, Curr, Next> = + &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; pub type StmtGen<'vir, Curr, Next> = &'vir crate::gendata::StmtGenData<'vir, Curr, Next>; -pub type TerminatorStmtGen<'vir, Curr, Next> = &'vir crate::gendata::TerminatorStmtGenData<'vir, Curr, Next>; +pub type TerminatorStmtGen<'vir, Curr, Next> = + &'vir crate::gendata::TerminatorStmtGenData<'vir, Curr, Next>; pub type TernaryGen<'vir, Curr, Next> = &'vir crate::gendata::TernaryGenData<'vir, Curr, Next>; pub type UnOpGen<'vir, Curr, Next> = &'vir crate::gendata::UnOpGenData<'vir, Curr, Next>; pub type UnfoldingGen<'vir, Curr, Next> = &'vir crate::gendata::UnfoldingGenData<'vir, Curr, Next>; diff --git a/vir/src/lib.rs b/vir/src/lib.rs index eb9c2f352c7..098f62db01d 100644 --- a/vir/src/lib.rs +++ b/vir/src/lib.rs @@ -13,13 +13,13 @@ mod refs; mod reify; mod callable_idents; +pub use callable_idents::*; pub use context::*; pub use data::*; pub use gendata::*; pub use genrefs::*; pub use refs::*; pub use reify::*; -pub use callable_idents::*; // for all arena-allocated types, there are two type definitions: one with // a `Data` suffix, containing the actual data; and one without the suffix, diff --git a/vir/src/macros.rs b/vir/src/macros.rs index 8441cc466fa..8a8685fde99 100644 --- a/vir/src/macros.rs +++ b/vir/src/macros.rs @@ -1,6 +1,6 @@ //#[macro_export] //macro_rules! vir_expr_nopos { -// +// //} //#[macro_export] @@ -112,8 +112,12 @@ macro_rules! vir_expr { #[macro_export] macro_rules! vir_ident { - ($vcx:expr; [ $name:expr ]) => { $name }; - ($vcx:expr; $name:ident ) => { $vcx.alloc_str(stringify!($name)) }; + ($vcx:expr; [ $name:expr ]) => { + $name + }; + ($vcx:expr; $name:ident ) => { + $vcx.alloc_str(stringify!($name)) + }; } #[macro_export] @@ -123,11 +127,27 @@ macro_rules! vir_format { #[macro_export] macro_rules! vir_type { - ($vcx:expr; Bool) => { & $crate::TypeData::Bool }; - ($vcx:expr; Ref) => { & $crate::TypeData::Ref }; - ($vcx:expr; Uint($bit_width:expr)) => { $vcx.alloc($crate::TypeData::Int { signed: false, bit_width: $bit_width }) }; - ($vcx:expr; Int($bit_width:expr)) => { $vcx.alloc($crate::TypeData::Int { signed: true, bit_width: $bit_width }) }; - ($vcx:expr; [ $ty:expr ]) => { $ty }; + ($vcx:expr; Bool) => { + &$crate::TypeData::Bool + }; + ($vcx:expr; Ref) => { + &$crate::TypeData::Ref + }; + ($vcx:expr; Uint($bit_width:expr)) => { + $vcx.alloc($crate::TypeData::Int { + signed: false, + bit_width: $bit_width, + }) + }; + ($vcx:expr; Int($bit_width:expr)) => { + $vcx.alloc($crate::TypeData::Int { + signed: true, + bit_width: $bit_width, + }) + }; + ($vcx:expr; [ $ty:expr ]) => { + $ty + }; ($vcx:expr; $name:ident) => { $vcx.alloc($crate::TypeData::Domain($vcx.alloc_str(stringify!($name)))) }; diff --git a/vir/src/reify.rs b/vir/src/reify.rs index 9c720b8cc10..6ddb96884fd 100644 --- a/vir/src/reify.rs +++ b/vir/src/reify.rs @@ -1,18 +1,11 @@ -use crate::VirCtxt; -use crate::gendata::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{gendata::*, genrefs::*, refs::*, VirCtxt}; pub use vir_proc_macro::*; pub trait Reify<'vir, Curr> { type Next: Sized; - fn reify<'tcx>( - &self, - vcx: &'vir VirCtxt<'tcx>, - lctx: Curr, - ) -> Self::Next; + fn reify<'tcx>(&self, vcx: &'vir VirCtxt<'tcx>, lctx: Curr) -> Self::Next; } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> @@ -31,7 +24,9 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> ExprGenData::Forall(v) => vcx.alloc(ExprGenData::Forall(v.reify(vcx, lctx))), ExprGenData::Let(v) => vcx.alloc(ExprGenData::Let(v.reify(vcx, lctx))), ExprGenData::FuncApp(v) => vcx.alloc(ExprGenData::FuncApp(v.reify(vcx, lctx))), - ExprGenData::PredicateApp(v) => vcx.alloc(ExprGenData::PredicateApp(v.reify(vcx, lctx))), + ExprGenData::PredicateApp(v) => { + vcx.alloc(ExprGenData::PredicateApp(v.reify(vcx, lctx))) + } ExprGenData::Local(v) => vcx.alloc(ExprGenData::Local(v)), ExprGenData::Const(v) => vcx.alloc(ExprGenData::Const(v)), @@ -51,9 +46,12 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> { type Next = &'vir [ExprGen<'vir, NextA, NextB>]; fn reify<'tcx>(&self, vcx: &'vir VirCtxt<'tcx>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|elem| elem.reify(vcx, lctx)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|elem| elem.reify(vcx, lctx)) + .collect::>(), + ) } } @@ -62,20 +60,29 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> { type Next = &'vir [&'vir [ExprGen<'vir, NextA, NextB>]]; fn reify<'tcx>(&self, vcx: &'vir VirCtxt<'tcx>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|elem| elem.reify(vcx, lctx)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|elem| elem.reify(vcx, lctx)) + .collect::>(), + ) } } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> - for [(ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, CfgBlockLabel<'vir>)] + for [( + ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, + CfgBlockLabel<'vir>, + )] { type Next = &'vir [(ExprGen<'vir, NextA, NextB>, CfgBlockLabel<'vir>)]; fn reify<'tcx>(&self, vcx: &'vir VirCtxt<'tcx>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) + .collect::>(), + ) } } @@ -97,7 +104,6 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> } } - /* impl< 'vir,