From 4f70d403edf9feecb7df8a97bb704f5aa14d9893 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 27 Oct 2023 12:24:52 +0200 Subject: [PATCH 01/18] Miscellaneous small fixes --- Cargo.lock | 1 + prusti-encoder/src/encoders/mir_builtin.rs | 16 +++++-- prusti-encoder/src/encoders/mir_impure.rs | 30 ++++++++++-- prusti-encoder/src/encoders/mir_pure.rs | 46 +++++++++++++++---- .../src/encoders/mir_pure_function.rs | 12 ++--- prusti-encoder/src/encoders/typ.rs | 13 ++++-- prusti-encoder/src/encoders/viper_tuple.rs | 2 +- prusti-encoder/src/lib.rs | 8 ++-- tracing/Cargo.toml | 2 +- vir/src/data.rs | 1 + vir/src/debug.rs | 1 + 11 files changed, 98 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c05b071396d..100df83d4ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2558,6 +2558,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index 032d69fdd45..69902ee0094 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -189,8 +189,16 @@ impl TaskEncoder for MirBuiltinEncoder { }, ())) } - MirBuiltinEncoderTask::CheckedBinOp(mir::BinOp::Add | mir::BinOp::AddUnchecked, ty) => { - let name = vir::vir_format!(vcx, "mir_checkedbinop_add_{}", int_name(*ty)); + MirBuiltinEncoderTask::CheckedBinOp(op @ (mir::BinOp::Add | mir::BinOp::AddUnchecked | mir::BinOp::Sub | mir::BinOp::SubUnchecked), ty) => { + + let (op_name, vir_op) = match op { + mir::BinOp::Add | mir::BinOp::AddUnchecked => ("add", vir::BinOpKind::Add), + mir::BinOp::Sub | mir::BinOp::SubUnchecked => ("sub", vir::BinOpKind::Sub), + + _ => unreachable!() + }; + let name = vir::vir_format!(vcx, "mir_checkedbinop_{}_{}", op_name, int_name(*ty)); + deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { name, }); @@ -220,7 +228,7 @@ impl TaskEncoder for MirBuiltinEncoder { vcx.mk_func_app( ty_in.from_primitive.unwrap(), &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::Add, + kind: vir_op, lhs: vcx.mk_func_app( ty_in.to_primitive.unwrap(), &[vcx.mk_local_ex("arg1")], @@ -242,7 +250,7 @@ impl TaskEncoder for MirBuiltinEncoder { }, ())) } - _ => todo!(), + other => todo!("{other:?}"), } }) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index de64582a988..619bfee2f00 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -613,7 +613,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { rhs: self.encode_operand_snap(r), })))], )), - mir::Rvalue::BinaryOp(mir::BinOp::Lt, box (l, r)) => { + mir::Rvalue::BinaryOp(op, box (l, r)) => { let ty_l = self.deps.require_ref::( l.ty(self.local_decls, self.vcx.tcx), ).unwrap().to_primitive.unwrap(); @@ -624,7 +624,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { Some(self.vcx.mk_func_app( "s_Bool_cons", // TODO: go through type encoder &[self.vcx.alloc(vir::ExprData::BinOp(self.vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpLt, + kind: vir::BinOpKind::from(op), lhs: self.vcx.mk_func_app( ty_l, &[self.encode_operand_snap(l)], @@ -686,7 +686,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { } mir::Rvalue::Aggregate( - box mir::AggregateKind::Adt(..), + box mir::AggregateKind::Adt(..) | box mir::AggregateKind::Tuple, fields, ) => { let dest_ty_struct = dest_ty_out.expect_structlike(); @@ -712,7 +712,8 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { //mir::Rvalue::Discriminant(Place<'tcx>) => {} //mir::Rvalue::ShallowInitBox(Operand<'tcx>, Ty<'tcx>) => {} //mir::Rvalue::CopyForDeref(Place<'tcx>) => {} - _ => { + other => { + tracing::error!("unsupported rvalue {other:?}"); Some(self.vcx.alloc(vir::ExprData::Todo( vir::vir_format!(self.vcx, "rvalue {rvalue:?}"), ))) @@ -735,7 +736,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { // no-ops mir::StatementKind::FakeRead(_) | mir::StatementKind::Retag(..) - //| mir::StatementKind::PlaceMention(_) + | mir::StatementKind::PlaceMention(_) | mir::StatementKind::AscribeUserType(..) | mir::StatementKind::Coverage(_) //| mir::StatementKind::ConstEvalCounter @@ -818,6 +819,25 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize())), )) } + mir::TerminatorKind::Assert { cond, expected, msg, target, unwind } => { + + let otherwise = match unwind { + mir::UnwindAction::Cleanup(bb) => bb, + _ => todo!() + }; + + let enc = self.encode_operand_snap(cond); + let enc = self.vcx.mk_func_app("s_Bool_val", &[enc]); + + let target_bb = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())); + + self.vcx.alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc(vir::GotoIfData { + value: enc, + targets: self.vcx.alloc_slice(&[(self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Bool(*expected)))) + , &target_bb)]), + otherwise: self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(otherwise.as_usize())), + }))) + } unsupported_kind => self.vcx.alloc(vir::TerminatorStmtData::Dummy( vir::vir_format!(self.vcx, "terminator {unsupported_kind:?}"), )), diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 0d312e49adc..bdf59ab05d1 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -268,7 +268,7 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> update.versions.get(local).copied().unwrap_or_else(|| { // TODO: remove (debug) if !curr_ver.contains_key(&local) { - println!("unknown version of local! {}", local.as_usize()); + tracing::error!("unknown version of local! {}", local.as_usize()); return 0xff } curr_ver[local] @@ -416,8 +416,7 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> for (elem_idx, local) in mod_locals.iter().enumerate() { let expr = self.mk_phi_acc(tuple_ref.clone(), phi_idx, elem_idx); self.bump_version(&mut phi_update, *local, expr); - // TODO: add to curr_ver here ? - //new_curr_ver.insert(*local, ); + new_curr_ver.insert(*local, phi_update.versions[local]); } // walk `join` -> `end` @@ -546,6 +545,13 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> )) .lift(); + // TODO: use type encoder + let body = + self.vcx.mk_func_app( + "s_Bool_val", + &[body], + ); + // TODO: use type encoder let forall = self.vcx.mk_func_app( "s_Bool_cons", @@ -566,7 +572,9 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> stmt_update.merge(term_update).merge(end_update) } - None => todo!(), + None => { + todo!("call not supported {func:?}"); + } } } @@ -581,10 +589,15 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> ) -> Update<'vir> { let mut update = Update::new(); match &stmt.kind { - mir::StatementKind::StorageLive(..) - | mir::StatementKind::StorageDead(..) + mir::StatementKind::StorageLive(local) => { + let new_version = self.version_ctr.get(local).copied().unwrap_or(0usize); + self.version_ctr.insert(*local, new_version + 1); + update.versions.insert(*local, new_version); + }, + mir::StatementKind::StorageDead(..) | mir::StatementKind::FakeRead(..) - | mir::StatementKind::AscribeUserType(..) => {}, // nop + | mir::StatementKind::AscribeUserType(..) + | mir::StatementKind::PlaceMention(..)=> {}, // nop mir::StatementKind::Assign(box (dest, rvalue)) => { assert!(dest.projection.is_empty()); let expr = self.encode_rvalue(curr_ver, rvalue); @@ -676,7 +689,22 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> .map(|field| self.encode_operand(curr_ver, field)) .collect::>()) } - _ => todo!(), + _ => todo!("Unsupported Rvalue::AggregateKind: {kind:?}"), + } + mir::Rvalue::CheckedBinaryOp(binop, box (l, r)) => { + let binop_function = self.deps.require_ref::( + crate::encoders::MirBuiltinEncoderTask::CheckedBinOp( + *binop, + l.ty(self.body, self.vcx.tcx), // TODO: ? + ), + ).unwrap().name; + self.vcx.mk_func_app( + binop_function, + &[ + self.encode_operand(curr_ver, l), + self.encode_operand(curr_ver, r), + ], + ) } // ShallowInitBox // CopyForDeref @@ -734,7 +762,7 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> ) -> ExprRet<'vir> { // TODO: remove (debug) if !curr_ver.contains_key(&place.local) { - println!("unknown version of local! {}", place.local.as_usize()); + tracing::error!("unknown version of local! {}", place.local.as_usize()); return self.vcx.alloc(ExprRetData::Todo( vir::vir_format!(self.vcx, "unknown_version_{}", place.local.as_usize()), )); diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index 60cab5ac4a9..b506afd6e3e 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -16,13 +16,13 @@ pub enum MirFunctionEncoderError { #[derive(Clone, Debug)] pub struct MirFunctionEncoderOutputRef<'vir> { - pub method_name: &'vir str, + pub function_name: &'vir str, } impl<'vir> task_encoder::OutputRefAny<'vir> for MirFunctionEncoderOutputRef<'vir> {} #[derive(Clone, Debug)] pub struct MirFunctionEncoderOutput<'vir> { - pub method: vir::Function<'vir>, + pub function: vir::Function<'vir>, } thread_local! { @@ -72,8 +72,8 @@ impl TaskEncoder for MirFunctionEncoder { tracing::debug!("encoding {def_id:?}"); - let method_name = vir::vir_format!(vcx, "f_{}", vcx.tcx.item_name(def_id)); - deps.emit_output_ref::(def_id, MirFunctionEncoderOutputRef { method_name }); + let function_name = vir::vir_format!(vcx, "f_{}", vcx.tcx.item_name(def_id)); + deps.emit_output_ref::(def_id, MirFunctionEncoderOutputRef { function_name }); let local_defs = deps.require_local::( def_id, @@ -104,8 +104,8 @@ impl TaskEncoder for MirFunctionEncoder { Ok(( MirFunctionEncoderOutput { - method: vcx.alloc(vir::FunctionData { - name: method_name, + function: vcx.alloc(vir::FunctionData { + name: function_name, args: vcx.alloc_slice(&func_args), ret: local_defs.locals[mir::RETURN_PLACE].snapshot, pres: vcx.alloc_slice(&spec.pres), diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index cc240f5a7dd..7ba981fc377 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -709,14 +709,19 @@ impl TaskEncoder for TypeEncoder { .iter() .map(|ty| deps.require_ref::(ty).unwrap()) .collect::>(); - // TODO: name the tuple according to its types, or make generic? + + // TODO: Properly name the tuple according to its types, or make generic? + + // Temporary fix to make it possisble to have multiple tuples of the same size with different element types + let tmp_ty_name = field_ty_out.iter().map(|e| e.snapshot_name).collect::>().join("_"); + Ok((mk_structlike( vcx, deps, task_key, - vir::vir_format!(vcx, "s_Tuple{}", tys.len()), - vir::vir_format!(vcx, "p_Tuple{}", tys.len()), + vir::vir_format!(vcx, "s_Tuple{}_{}", tys.len(), tmp_ty_name), + vir::vir_format!(vcx, "p_Tuple{}_{}", tys.len(), tmp_ty_name), field_ty_out, )?, ())) @@ -782,7 +787,7 @@ impl TaskEncoder for TypeEncoder { }, ())) } TyKind::Adt(adt_def, substs) if adt_def.is_struct() => { - println!("encoding ADT {adt_def:?} with substs {substs:?}"); + tracing::debug!("encoding ADT {adt_def:?} with substs {substs:?}"); let substs = ty::List::identity_for_item(vcx.tcx, adt_def.did()); let field_ty_out = adt_def.all_fields() .map(|field| deps.require_ref::(field.ty(vcx.tcx, substs)).unwrap()) diff --git a/prusti-encoder/src/encoders/viper_tuple.rs b/prusti-encoder/src/encoders/viper_tuple.rs index 87bee1488cc..708665d88a1 100644 --- a/prusti-encoder/src/encoders/viper_tuple.rs +++ b/prusti-encoder/src/encoders/viper_tuple.rs @@ -150,7 +150,7 @@ impl TaskEncoder for ViperTupleEncoder { domain: Some(vcx.alloc(vir::DomainData { name: domain_name, typarams: vcx.alloc_slice(&typaram_names), - axioms: vcx.alloc_slice(&[axiom]), + axioms: if task_key == &0 {&[]} else {vcx.alloc_slice(&[axiom])}, functions: vcx.alloc_slice(&functions), })), }, ())) diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index cdf1c061312..37b008342ca 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -9,7 +9,7 @@ extern crate rustc_type_ir; mod encoders; -use prusti_interface::environment::EnvBody; +use prusti_interface::{environment::EnvBody, specs::typed::SpecificationItem}; use prusti_rustc_interface::{ middle::ty, hir, @@ -127,7 +127,7 @@ pub fn test_entrypoint<'tcx>( // TODO: this should be a "crate" encoder, which will deps.require all the methods in the crate for def_id in tcx.hir_crate_items(()).definitions() { - //println!("item: {def_id:?}"); + tracing::debug!("test_entrypoint item: {def_id:?}"); let kind = tcx.def_kind(def_id); //println!(" kind: {:?}", kind); /*if !format!("{def_id:?}").contains("foo") { @@ -162,7 +162,7 @@ pub fn test_entrypoint<'tcx>( }*/ } unsupported_item_kind => { - println!("another item: {unsupported_item_kind:?}"); + tracing::debug!("unsupported item: {unsupported_item_kind:?}"); } } } @@ -182,7 +182,7 @@ pub fn test_entrypoint<'tcx>( header(&mut viper_code, "functions"); for output in crate::encoders::MirFunctionEncoder::all_outputs() { - viper_code.push_str(&format!("{:?}\n", output.method)); + viper_code.push_str(&format!("{:?}\n", output.function)); } header(&mut viper_code, "MIR builtins"); diff --git a/tracing/Cargo.toml b/tracing/Cargo.toml index d96a0d59852..3c9026321eb 100644 --- a/tracing/Cargo.toml +++ b/tracing/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" doctest = false [dependencies] -tracing = "0.1" +tracing = { version = "0.1", features = ["log"] } proc-macro-tracing = { path = "proc-macro-tracing" } diff --git a/vir/src/data.rs b/vir/src/data.rs index b4a6f21cf9d..d201bb83a40 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -41,6 +41,7 @@ pub enum BinOpKind { CmpLe, And, Add, + Sub, // ... } impl From for BinOpKind { diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 7ba9cf57145..e475019abf9 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -57,6 +57,7 @@ impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { BinOpKind::CmpLe => "<=", BinOpKind::And => "&&", BinOpKind::Add => "+", + BinOpKind::Sub => "-", })?; self.rhs.fmt(f)?; write!(f, ")") From 8e28cbfb6fbd3f9ee2e033989d8e3077bee71cce Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 27 Oct 2023 12:44:02 +0200 Subject: [PATCH 02/18] Implement pure, trusted, snapshot_equality and field access --- prusti-encoder/src/encoders/mir_impure.rs | 129 +++++++++++++----- prusti-encoder/src/encoders/mir_pure.rs | 126 ++++++++++++++--- .../src/encoders/mir_pure_function.rs | 77 ++++++++--- prusti-encoder/src/encoders/pure/spec.rs | 19 +-- prusti-encoder/src/lib.rs | 21 ++- 5 files changed, 287 insertions(+), 85 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 619bfee2f00..71533edd1b7 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -161,29 +161,49 @@ impl TaskEncoder for MirImpureEncoder { posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred); posts.extend(spec_posts); - let mut visitor = EncoderVisitor { - vcx, - deps, - local_decls: &body.local_decls, - //ssa_analysis, - fpcs_analysis, - local_defs, - tmp_ctr: 0, + // TODO: dedup with mir_pure + let attrs = vcx.tcx.get_attrs_unchecked(def_id); + let is_trusted = attrs + .iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()) + .any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "trusted" + }); + + let blocks = if is_trusted { + None + } + else { + let mut visitor = EncoderVisitor { + vcx, + deps, + local_decls: &body.local_decls, + //ssa_analysis, + fpcs_analysis, + local_defs, + + tmp_ctr: 0, + + current_fpcs: None, + + current_stmts: None, + current_terminator: None, + encoded_blocks, + }; + visitor.visit_body(&body); - current_fpcs: None, + visitor.encoded_blocks.push(vcx.alloc(vir::CfgBlockData { + label: vcx.alloc(vir::CfgBlockLabelData::End), + stmts: &[], + terminator: vcx.alloc(vir::TerminatorStmtData::Exit), + })); - current_stmts: None, - current_terminator: None, - encoded_blocks, + Some(vcx.alloc_slice(&visitor.encoded_blocks)) }; - visitor.visit_body(&body); - - visitor.encoded_blocks.push(vcx.alloc(vir::CfgBlockData { - label: vcx.alloc(vir::CfgBlockLabelData::End), - stmts: &[], - terminator: vcx.alloc(vir::TerminatorStmtData::Exit), - })); Ok((MirImpureEncoderOutput { method: vcx.alloc(vir::MethodData { @@ -192,7 +212,7 @@ impl TaskEncoder for MirImpureEncoder { rets: &[], pres: vcx.alloc_slice(&pres), posts: vcx.alloc_slice(&posts), - blocks: Some(vcx.alloc_slice(&visitor.encoded_blocks)), + blocks, }), }, ())) }) @@ -803,18 +823,65 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { _ => todo!(), }; - let func_out = self.deps.require_ref::( - *func_def_id, - ).unwrap(); + // TODO: dedup with mir_pure + let attrs = self.vcx.tcx.get_attrs_unchecked(*func_def_id); + let is_pure = attrs.iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()).any(|item| + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + ); + + let dest = self.encode_place(Place::from(*destination)); + let call_args = args.iter().map(|op| + if is_pure { + self.encode_operand_snap(op) + } + else { + self.encode_operand(op) + } + ); + // self.encode_operand(op) + + if is_pure { + let func_args = call_args.collect::>(); + let pure_func_name = self.deps.require_ref::( + *func_def_id, + ).unwrap().function_name; + + let pure_func_app = self.vcx.mk_func_app(pure_func_name, &func_args); + + let method_assign = { + //TODO: can we get the method_assign is a better way? Maybe from the MirFunctionEncoder? + let body = self.vcx.body.borrow_mut().get_impure_fn_body_identity(func_def_id.expect_local()); + let return_type = self.deps + .require_ref::(body.return_ty()) + .unwrap(); + return_type.method_assign + }; + + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { + targets: &[], + method: method_assign, + args: self.vcx.alloc_slice(&[dest, pure_func_app]), + }))); + } + else { + let meth_args = std::iter::once(dest) + .chain(call_args) + .collect::>(); + let func_out = self.deps.require_ref::( + *func_def_id, + ).unwrap(); + + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { + targets: &[], + method: func_out.method_name, + args: self.vcx.alloc_slice(&meth_args), + }))); + } - let destination = self.encode_place(Place::from(*destination)); - let args = args.iter().map(|op| self.encode_operand(op)); - let args: Vec<_> = std::iter::once(destination).chain(args).collect(); - self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { - targets: &[], - method: func_out.method_name, - args: self.vcx.alloc_slice(&args), - }))); self.vcx.alloc(vir::TerminatorStmtData::Goto( self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize())), )) diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index bdf59ab05d1..3e3bbaac4df 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -9,6 +9,7 @@ use task_encoder::{ TaskEncoderDependencies, }; use std::collections::HashMap; +use crate::encoders::{ViperTupleEncoder, TypeEncoder}; pub struct MirPureEncoder; @@ -448,9 +449,10 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> TyKind::FnDef(def_id, arg_tys) => { // TODO: this attribute extraction should be done elsewhere? let attrs = self.vcx.tcx.get_attrs_unchecked(*def_id); - attrs.iter() + let normal_attrs = attrs.iter() .filter(|attr| !attr.is_doc_comment()) - .map(|attr| attr.get_normal_item()) + .map(|attr| attr.get_normal_item()).collect::>(); + normal_attrs.iter() .filter(|item| item.path.segments.len() == 2 && item.path.segments[0].ident.as_str() == "prusti" && item.path.segments[1].ident.as_str() == "builtin") @@ -467,6 +469,66 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> } _ => panic!("illegal prusti::builtin"), }); + + + let is_pure = normal_attrs.iter().any(|item| + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + ); + + // TODO: detect snapshot_equality properly + let is_snapshot_eq = self.vcx.tcx.opt_item_name(*def_id).map(|e| e.as_str() == "snapshot_equality") == Some(true) + && self.vcx.tcx.crate_name(def_id.krate).as_str() == "prusti_contracts"; + + let func_call = if is_pure { + assert!(builtin.is_none(), "Function is pure and builtin?"); + let pure_func = self.deps.require_ref::(*def_id).unwrap().function_name; + + let encoded_args = args.iter() + .map(|oper| self.encode_operand(&new_curr_ver, oper)) + .collect::>(); + + let func_call = self.vcx.mk_func_app(pure_func, &encoded_args); + + Some(func_call) + } + + else if is_snapshot_eq { + assert!(builtin.is_none(), "Function is snapshot_equality and builtin?"); + let encoded_args = args.iter() + .map(|oper| self.encode_operand(&new_curr_ver, oper)) + .collect::>(); + + assert_eq!(encoded_args.len(), 2); + + + let eq_expr = self.vcx.alloc(vir::ExprGenData::BinOp(self.vcx.alloc(vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs: encoded_args[0], + rhs: encoded_args[1], + }))); + + + // TODO: type encoder + Some(self.vcx.mk_func_app("s_Bool_cons", &[eq_expr])) + } + else { + None + }; + + + if let Some(func_call) = func_call { + let mut term_update = Update::new(); + assert!(destination.projection.is_empty()); + self.bump_version(&mut term_update, destination.local, func_call); + term_update.add_to_map(&mut new_curr_ver); + + // walk rest of CFG + let end_update = self.encode_cfg(&new_curr_ver, target.unwrap(), end); + + return stmt_update.merge(term_update).merge(end_update); + } } _ => todo!(), } @@ -769,24 +831,48 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> } let local = self.mk_local_ex(place.local, curr_ver[&place.local]); - if !place.projection.is_empty() { - // TODO: for now, assume this is a closure argument - assert_eq!(place.projection[0], mir::ProjectionElem::Deref); - assert!(matches!(place.projection[1], mir::ProjectionElem::Field(..))); - assert_eq!(place.projection[2], mir::ProjectionElem::Deref); - assert_eq!(place.projection.len(), 3); - let upvars = match self.body.local_decls[place.local].ty.peel_refs().kind() { - TyKind::Closure(_def_id, args) => args.as_closure().upvar_tys().collect::>().len(), - _ => unreachable!(), - }; - let tuple_ref = self.deps.require_ref::( - upvars, - ).unwrap(); - return match place.projection[1] { - mir::ProjectionElem::Field(idx, _) => tuple_ref.mk_elem(self.vcx, local, idx.as_usize()), - _ => todo!(), - }; + let mut partent_ty = self.body.local_decls[place.local].ty; + let mut expr = local; + + for elem in place.projection { + (partent_ty, expr) = self.encode_place_element(partent_ty, elem, expr); + } + + expr + } + + + fn encode_place_element(&mut self, parent_ty: ty::Ty<'vir>, elem: mir::PlaceElem<'vir>, expr: ExprRet<'vir>) -> (ty::Ty<'vir>, ExprRet<'vir>) { + let parent_ty = parent_ty.peel_refs(); + + match elem { + mir::ProjectionElem::Deref => { + (parent_ty, expr) + } + mir::ProjectionElem::Field(field_idx, field_ty) => { + let field_idx= field_idx.as_usize(); + match parent_ty.kind() { + TyKind::Closure(_def_id, args) => { + let upvars = args.as_closure().upvar_tys().collect::>().len(); + let tuple_ref = self.deps.require_ref::( + upvars, + ).unwrap(); + let tup = tuple_ref.mk_elem(self.vcx, expr, field_idx); + + (field_ty, tup) + } + _ => { + let local_encoded_ty = self.deps.require_ref::(parent_ty).unwrap(); + let struct_like = local_encoded_ty.expect_structlike(); + let proj = struct_like.field_read[field_idx]; + + let app = self.vcx.mk_func_app(proj, self.vcx.alloc_slice(&[expr])); + + (field_ty, app) + } + } + } + _ => todo!("Unsupported ProjectionElem {:?}", elem), } - local } } diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index b506afd6e3e..d977550fe2b 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -73,32 +73,65 @@ impl TaskEncoder for MirFunctionEncoder { tracing::debug!("encoding {def_id:?}"); let function_name = vir::vir_format!(vcx, "f_{}", vcx.tcx.item_name(def_id)); - deps.emit_output_ref::(def_id, MirFunctionEncoderOutputRef { function_name }); + deps.emit_output_ref::(*task_key, MirFunctionEncoderOutputRef { function_name }); + + let local_def_id = def_id.expect_local(); + let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); - let local_defs = deps.require_local::( - def_id, - ).unwrap(); let spec = deps.require_local::( (def_id, true) ).unwrap(); - let func_args: Vec<_> = (1..=local_defs.arg_count).map(mir::Local::from).map(|arg| vcx.alloc(vir::LocalDeclData { - name: local_defs.locals[arg].local.name, - ty: local_defs.locals[arg].snapshot, - })).collect(); - - // Encode the body of the function - let expr = deps - .require_local::(MirPureEncoderTask { - encoding_depth: 0, - parent_def_id: def_id, - promoted: None, - param_env: vcx.tcx.param_env(def_id), - substs: ty::List::identity_for_item(vcx.tcx, def_id), - }) + let mut func_args = Vec::with_capacity(body.arg_count); + + for (arg_idx0, arg_local) in body.args_iter().enumerate() { + let arg_idx = arg_idx0 + 1; // enumerate is 0 based but we want to start at 1 + + let arg_decl = body.local_decls.get(arg_local).unwrap(); + let arg_type_ref = deps.require_ref::(arg_decl.ty).unwrap(); + + let name_p = vir::vir_format!(vcx, "_{arg_idx}p"); + func_args.push(vcx.alloc(vir::LocalDeclData { + name: name_p, + ty: arg_type_ref.snapshot, + })); + } + + // TODO: dedup with mir_pure + let attrs = vcx.tcx.get_attrs_unchecked(def_id); + let is_trusted = attrs + .iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()) + .any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "trusted" + }); + + let expr = if is_trusted { + None + } else { + // Encode the body of the function + let expr = deps + .require_local::(MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: def_id, + promoted: None, + param_env: vcx.tcx.param_env(def_id), + substs: ty::List::identity_for_item(vcx.tcx, def_id), + }) + .unwrap() + .expr; + + Some(expr.reify(vcx, (def_id, &spec.pre_args.split_last().unwrap().1))) + }; + + // Snapshot type of the return type + let ret = deps + .require_ref::(body.return_ty()) .unwrap() - .expr; - let expr = expr.reify(vcx, (def_id, &spec.pre_args[1..])); + .snapshot; tracing::debug!("finished {def_id:?}"); @@ -107,10 +140,10 @@ impl TaskEncoder for MirFunctionEncoder { function: vcx.alloc(vir::FunctionData { name: function_name, args: vcx.alloc_slice(&func_args), - ret: local_defs.locals[mir::RETURN_PLACE].snapshot, + ret, pres: vcx.alloc_slice(&spec.pres), posts: vcx.alloc_slice(&spec.posts), - expr: Some(expr), + expr, }), }, (), diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 462783f8f74..8c3a88a77d1 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -71,20 +71,23 @@ impl TaskEncoder for MirSpecEncoder { ).unwrap(); vir::with_vcx(|vcx| { - let pre_args: Vec<_> = (0..=local_defs.arg_count) + let mut pre_args: Vec<_> = (1..=local_defs.arg_count) .map(mir::Local::from) .map(|local| { if pure { - if local.index() == 0 { - vcx.mk_local_ex(vir::vir_format!(vcx, "result")) - } else { - local_defs.locals[local].local_ex - } + local_defs.locals[local].local_ex } else { local_defs.locals[local].impure_snap } }) .collect(); + + pre_args.push(if pure { + vcx.mk_local_ex(vir::vir_format!(vcx, "result")) + } else { + local_defs.locals[mir::Local::from(0u32)].impure_snap + }); + let pre_args = vcx.alloc_slice(&pre_args); let to_bool = deps.require_ref::( @@ -105,14 +108,14 @@ impl TaskEncoder for MirSpecEncoder { to_bool, &[expr], ); - expr.reify(vcx, (*spec_def_id, &pre_args[1..])) + expr.reify(vcx, (*spec_def_id, &pre_args.split_last().unwrap().1)) }).collect::>>(); let post_args = if pure { pre_args } else { let post_args: Vec<_> = pre_args.iter().enumerate().map(|(idx, arg)| { - if idx == 0 { + if idx == pre_args.len() - 1 { arg } else { vcx.alloc(vir::ExprData::Old(arg)) diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index 37b008342ca..6ca82e1d8a9 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -142,13 +142,26 @@ pub fn test_entrypoint<'tcx>( let res = crate::encoders::MirImpureEncoder::encode(def_id.to_def_id()); assert!(res.is_ok()); - let kind = crate::encoders::with_def_spec(|def_spec| - def_spec + let (is_pure, is_trusted) = crate::encoders::with_def_spec(|def_spec| { + let base_spec = def_spec .get_proc_spec(&def_id.to_def_id()) - .map(|e| e.base_spec.kind) + .map(|e| &e.base_spec); + + let is_pure = base_spec.and_then(|kind| kind.kind.is_pure().ok()).unwrap_or_default(); + let is_trusted = matches!(base_spec.map(|spec| spec.trusted), Some(SpecificationItem::Inherent( + true, + ))); + (is_pure, is_trusted) + } ); - if kind.and_then(|kind| kind.is_pure().ok()).unwrap_or_default() { + if ! (is_trusted && is_pure) { + let res = crate::encoders::MirImpureEncoder::encode(def_id.to_def_id()); + assert!(res.is_ok()); + } + + + if is_pure { tracing::debug!("Encoding {def_id:?} as a pure function because it is labeled as pure"); let res = crate::encoders::MirFunctionEncoder::encode(def_id.to_def_id()); assert!(res.is_ok()); From 3f7a7e6d4e35da646d6169fe751c3f57e47f6777 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 27 Oct 2023 12:57:10 +0200 Subject: [PATCH 03/18] Change axioms and temporarily fix some triggering issues --- prusti-encoder/src/encoders/mir_pure.rs | 131 +++++++++++++++++++++++- prusti-encoder/src/encoders/typ.rs | 9 +- 2 files changed, 133 insertions(+), 7 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 3e3bbaac4df..369f48edfd2 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -30,6 +30,122 @@ pub struct MirPureEncoderOutput<'vir> { pub expr: ExprRet<'vir>, } +/// Optimize a vir expresison +/// +/// This is a temporary fix for the issue where variables that are quantified over and are then stored in a let binding +/// cause issues with triggering somehow. +/// +/// This was also intended to make debugging easier by making the resulting viper code a bit more readable +/// +/// This should be replaced with a proper solution +fn opt<'vir, Cur, Next> (expr: vir::ExprGen<'vir, Cur, Next>, rename: &mut HashMap) -> vir::ExprGen<'vir, Cur, Next>{ + match expr { + + vir::ExprGenData::Local(d) => { + let nam = rename.get(d.name).map(|e| e.as_str()).unwrap_or(d.name).to_owned(); + vir::with_vcx(move |vcx| { + vcx.mk_local_ex(vcx.alloc_str(&nam)) + }) + }, + + vir::ExprGenData::Let(vir::LetGenData{name, val, expr}) => { + + let val = opt(val, rename); + + match val { + // let name = loc.name + vir::ExprGenData::Local(loc) => { + let t = rename.get(loc.name).map(|e| e.to_owned()).unwrap_or(loc.name.to_string()); + assert!(rename.insert(name.to_string(), t).is_none()); + return opt(expr, rename) + } + _ => {} + } + + let expr = opt(expr, rename); + + match expr { + vir::ExprGenData::Local(inner_local) => { + if &inner_local.name == name { + return val + } + } + _ => {} + } + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::Let(vcx.alloc(vir::LetGenData{name, val, expr}))) + }) + }, + vir::ExprGenData::FuncApp( vir::FuncAppGenData{target, args}) => { + let n_args = args.iter().map(|arg| opt(arg, rename)).collect::>(); + vir::with_vcx(move |vcx| { + vcx.mk_func_app(target, &n_args) + }) + } + + + vir::ExprGenData::PredicateApp( vir::PredicateAppGenData{target, args}) => { + let n_args = args.iter().map(|arg| opt(arg, rename)).collect::>(); + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::PredicateApp(vcx.alloc(vir::PredicateAppGenData { + target, + args: vcx.alloc_slice(&n_args), + }))) + }) + } + + vir::ExprGenData::Forall(vir::ForallGenData{qvars, triggers, body}) => { + let body = opt(body, rename); + + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::Forall(vcx.alloc(vir::ForallGenData{qvars, triggers, body}))) + }) + } + + + vir::ExprGenData::Ternary(vir::TernaryGenData{cond, then, else_}) => { + let cond = opt(cond, rename); + let then = opt(then, rename); + let else_ = opt(else_, rename); + + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::Ternary(vcx.alloc(vir::TernaryGenData{cond, then, else_}))) + }) + } + + vir::ExprGenData::BinOp(vir::BinOpGenData {kind, lhs, rhs}) => { + let lhs = opt(lhs, rename); + let rhs = opt(rhs, rename); + + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::BinOp(vcx.alloc(vir::BinOpGenData{kind: kind.clone(), lhs, rhs}))) + }) + } + + + vir::ExprGenData::UnOp(vir::UnOpGenData {kind, expr}) =>{ + let expr = opt(expr, rename); + vir::with_vcx(move |vcx| { + vcx.alloc(vir::ExprGenData::UnOp(vcx.alloc(vir::UnOpGenData{kind: kind.clone(), expr}))) + }) + } + + todo@( + vir::ExprGenData::Unfolding(_) | + vir::ExprGenData::Field(_, _) | + vir::ExprGenData::Old(_) | + vir::ExprGenData::AccField(_)| + vir::ExprGenData::Lazy(_, _)) + => todo, + + other @ ( + vir::ExprGenData::Const(_) | + vir::ExprGenData::Todo(_) ) => other + + + } +} + use std::cell::RefCell; thread_local! { static CACHE: task_encoder::CacheStaticRef = RefCell::new(Default::default()); @@ -118,7 +234,20 @@ impl TaskEncoder for MirPureEncoder { assert_eq!(lctx.1.len(), body.arg_count); use vir::Reify; - expr_inner.reify(vcx, lctx) + let expr_inner = expr_inner.reify(vcx, lctx); + + let expr_inner = if true { + tracing::warn!("before opt {expr_inner:?}"); + let mut rename = HashMap::new(); + let opted = opt(expr_inner, &mut rename); + tracing::warn!("after opt {opted:?}"); + opted + } + else { + expr_inner + }; + + expr_inner }), )) }); diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 7ba981fc377..48b8a2ba5bd 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -468,12 +468,7 @@ impl TaskEncoder for TypeEncoder { name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: cons_qvars.clone(), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[cons_call], - ), - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: vcx.mk_func_app( @@ -620,6 +615,7 @@ impl TaskEncoder for TypeEncoder { function s_Bool_cons(Bool): s_Bool; function s_Bool_val(s_Bool): Bool; axiom_inverse(s_Bool_val, s_Bool_cons, Bool); + axiom_inverse(s_Bool_cons, s_Bool_val, s_Bool); } }, predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), function_unreachable: mk_unreachable(vcx, "s_Bool", ty_s), @@ -663,6 +659,7 @@ impl TaskEncoder for TypeEncoder { function [name_cons](Int): [ty_s]; function [name_val]([ty_s]): Int; axiom_inverse([name_val], [name_cons], Int); + axiom_inverse([name_cons], [name_val], [ty_s]); } }, predicate: mk_simple_predicate(vcx, name_p, name_field), function_unreachable: mk_unreachable(vcx, name_s, ty_s), From fa4aa23074ba6712abfe8bc7f0715fc509be5bc0 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 27 Oct 2023 13:59:21 +0200 Subject: [PATCH 04/18] Cleanup --- prusti-encoder/src/encoders/mir_impure.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 71533edd1b7..5b06f614487 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -826,12 +826,12 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { // TODO: dedup with mir_pure let attrs = self.vcx.tcx.get_attrs_unchecked(*func_def_id); let is_pure = attrs.iter() - .filter(|attr| !attr.is_doc_comment()) - .map(|attr| attr.get_normal_item()).any(|item| - item.path.segments.len() == 2 - && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "pure" - ); + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()).any(|item| + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + ); let dest = self.encode_place(Place::from(*destination)); let call_args = args.iter().map(|op| @@ -842,7 +842,6 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { self.encode_operand(op) } ); - // self.encode_operand(op) if is_pure { let func_args = call_args.collect::>(); @@ -853,7 +852,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { let pure_func_app = self.vcx.mk_func_app(pure_func_name, &func_args); let method_assign = { - //TODO: can we get the method_assign is a better way? Maybe from the MirFunctionEncoder? + //TODO: Can we get `method_assign` in a better way? Maybe from the MirFunctionEncoder? let body = self.vcx.body.borrow_mut().get_impure_fn_body_identity(func_def_id.expect_local()); let return_type = self.deps .require_ref::(body.return_ty()) From 070d50ec92123e34cabd85c2b5574cb0dda8d76f Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Thu, 2 Nov 2023 12:22:53 +0100 Subject: [PATCH 05/18] Run rustfmt --- mir-ssa-analysis/src/lib.rs | 64 +- mir-state-analysis/src/free_pcs/impl/local.rs | 7 +- .../src/free_pcs/impl/triple.rs | 15 +- mir-state-analysis/src/utils/repacker.rs | 3 +- prusti-contracts/prusti-contracts/src/lib.rs | 2 +- prusti-contracts/prusti-specs/src/lib.rs | 16 +- prusti-encoder/src/encoders/generic.rs | 85 +- prusti-encoder/src/encoders/local_def.rs | 65 +- prusti-encoder/src/encoders/mir_builtin.rs | 334 ++++---- prusti-encoder/src/encoders/mir_impure.rs | 585 +++++++------ prusti-encoder/src/encoders/mir_pure.rs | 799 ++++++++++-------- .../src/encoders/mir_pure_function.rs | 20 +- prusti-encoder/src/encoders/mod.rs | 32 +- prusti-encoder/src/encoders/pure/spec.rs | 127 +-- prusti-encoder/src/encoders/spec.rs | 34 +- prusti-encoder/src/encoders/typ.rs | 709 ++++++++-------- prusti-encoder/src/encoders/viper_tuple.rs | 127 +-- prusti-encoder/src/lib.rs | 53 +- prusti-interface/src/environment/body.rs | 24 +- prusti-interface/src/environment/query.rs | 121 +-- prusti-interface/src/specs/encoder.rs | 5 +- prusti-interface/src/specs/external.rs | 5 +- prusti-interface/src/utils.rs | 9 +- prusti/src/callbacks.rs | 10 +- prusti/src/driver.rs | 26 +- prusti/src/verifier.rs | 58 +- task-encoder/src/lib.rs | 398 ++++----- test-crates/src/main.rs | 62 +- tracing/proc-macro-tracing/src/lib.rs | 13 +- tracing/src/lib.rs | 2 +- vir-proc-macro/src/lib.rs | 36 +- vir/src/context.rs | 47 +- vir/src/data.rs | 2 +- vir/src/debug.rs | 91 +- vir/src/gendata.rs | 89 +- vir/src/genrefs.rs | 15 +- vir/src/macros.rs | 26 +- vir/src/reify.rs | 48 +- 38 files changed, 2305 insertions(+), 1859 deletions(-) diff --git a/mir-ssa-analysis/src/lib.rs b/mir-ssa-analysis/src/lib.rs index 0671abbe7e3..0214618ba9b 100644 --- a/mir-ssa-analysis/src/lib.rs +++ b/mir-ssa-analysis/src/lib.rs @@ -1,9 +1,6 @@ #![feature(rustc_private)] -use prusti_rustc_interface::{ - index::IndexVec, - middle::mir, -}; +use prusti_rustc_interface::{index::IndexVec, middle::mir}; use std::collections::HashMap; pub type SsaVersion = usize; @@ -34,9 +31,15 @@ impl SsaAnalysis { let local_count = body.local_decls.len(); let block_count = body.basic_blocks.reverse_postorder().len(); - let block_started = std::iter::repeat(false).take(block_count).collect::>(); - let initial_version_in_block = std::iter::repeat(None).take(block_count).collect::>(); - let final_version_in_block = std::iter::repeat(None).take(block_count).collect::>(); + let block_started = std::iter::repeat(false) + .take(block_count) + .collect::>(); + let initial_version_in_block = std::iter::repeat(None) + .take(block_count) + .collect::>(); + let final_version_in_block = std::iter::repeat(None) + .take(block_count) + .collect::>(); let mut ssa_visitor = SsaVisitor { version_counter: IndexVec::from_raw(vec![0; local_count]), @@ -69,11 +72,7 @@ struct SsaVisitor { } impl<'tcx> SsaVisitor { - fn walk_block<'a>( - &mut self, - body: &'a mir::Body<'tcx>, - block: mir::BasicBlock, - ) { + fn walk_block<'a>(&mut self, body: &'a mir::Body<'tcx>, block: mir::BasicBlock) { if self.final_version_in_block[block.as_usize()].is_some() { return; } @@ -95,7 +94,9 @@ impl<'tcx> SsaVisitor { // TODO: cfg cycles prev_versions.push(( *pred, - self.final_version_in_block[pred.as_usize()].as_ref().unwrap()[local.into()], + self.final_version_in_block[pred.as_usize()] + .as_ref() + .unwrap()[local.into()], )); } if prev_versions.is_empty() { @@ -117,7 +118,9 @@ impl<'tcx> SsaVisitor { assert!(self.analysis.phi.insert(block, phis).is_none()); } - assert!(self.initial_version_in_block[block.as_usize()].replace(initial_versions.clone()).is_none()); + assert!(self.initial_version_in_block[block.as_usize()] + .replace(initial_versions.clone()) + .is_none()); use mir::visit::Visitor; self.last_version = initial_versions; @@ -127,12 +130,14 @@ impl<'tcx> SsaVisitor { .map(|local| self.last_version[local.into()]) .collect::>(); for local in 0..self.local_count { - self.analysis.version.insert(( - body.terminator_loc(block), - local.into(), - ), final_versions[local.into()]); + self.analysis.version.insert( + (body.terminator_loc(block), local.into()), + final_versions[local.into()], + ); } - assert!(self.final_version_in_block[block.as_usize()].replace(final_versions).is_none()); + assert!(self.final_version_in_block[block.as_usize()] + .replace(final_versions) + .is_none()); use prusti_rustc_interface::data_structures::graph::WithSuccessors; for succ in body.basic_blocks.successors(block) { @@ -152,7 +157,9 @@ impl<'tcx> mir::visit::Visitor<'tcx> for SsaVisitor { ) { let local = place.local; - assert!(self.analysis.version + assert!(self + .analysis + .version .insert((location, local), self.last_version[local]) .is_none()); @@ -161,12 +168,17 @@ impl<'tcx> mir::visit::Visitor<'tcx> for SsaVisitor { let new_version = self.version_counter[local] + 1; self.version_counter[local] = new_version; self.last_version[local] = new_version; - assert!(self.analysis.updates - .insert(location, SsaUpdate { - local, - old_version, - new_version, - }) + assert!(self + .analysis + .updates + .insert( + location, + SsaUpdate { + local, + old_version, + new_version, + } + ) .is_none()); } } diff --git a/mir-state-analysis/src/free_pcs/impl/local.rs b/mir-state-analysis/src/free_pcs/impl/local.rs index 9067adb669c..e0a09437d86 100644 --- a/mir-state-analysis/src/free_pcs/impl/local.rs +++ b/mir-state-analysis/src/free_pcs/impl/local.rs @@ -42,7 +42,9 @@ impl Default for CapabilityLocal<'_> { impl<'tcx> CapabilityLocal<'tcx> { pub fn get_allocated_mut(&mut self) -> &mut CapabilityProjections<'tcx> { - let Self::Allocated(cps) = self else { panic!("Expected allocated local") }; + let Self::Allocated(cps) = self else { + panic!("Expected allocated local") + }; cps } pub fn new(local: Local, perm: CapabilityKind) -> Self { @@ -198,8 +200,7 @@ impl<'tcx> CapabilityProjections<'tcx> { } let mut ops = Vec::new(); for (to, from, _) in collapsed { - let removed_perms: Vec<_> = - old_caps.extract_if(|old, _| to.is_prefix(*old)).collect(); + let removed_perms: Vec<_> = old_caps.extract_if(|old, _| to.is_prefix(*old)).collect(); let perm = removed_perms .iter() .fold(CapabilityKind::Exclusive, |acc, (_, p)| { diff --git a/mir-state-analysis/src/free_pcs/impl/triple.rs b/mir-state-analysis/src/free_pcs/impl/triple.rs index 91370d382be..da6ba700a6c 100644 --- a/mir-state-analysis/src/free_pcs/impl/triple.rs +++ b/mir-state-analysis/src/free_pcs/impl/triple.rs @@ -58,7 +58,8 @@ impl<'tcx> Visitor<'tcx> for Fpcs<'_, 'tcx> { self.ensures_unalloc(local); } &Retag(_, box place) => self.requires_exclusive(place), - AscribeUserType(..) | PlaceMention(..) | Coverage(..) | Intrinsic(..) | ConstEvalCounter | Nop => (), + AscribeUserType(..) | PlaceMention(..) | Coverage(..) | Intrinsic(..) + | ConstEvalCounter | Nop => (), }; } @@ -88,11 +89,19 @@ impl<'tcx> Visitor<'tcx> for Fpcs<'_, 'tcx> { } } } - &Drop { place, replace: false, .. } => { + &Drop { + place, + replace: false, + .. + } => { self.requires_write(place); self.ensures_write(place); } - &Drop { place, replace: true, .. } => { + &Drop { + place, + replace: true, + .. + } => { self.requires_write(place); self.ensures_exclusive(place); } diff --git a/mir-state-analysis/src/utils/repacker.rs b/mir-state-analysis/src/utils/repacker.rs index 21a88b281d8..0a2cd4c317d 100644 --- a/mir-state-analysis/src/utils/repacker.rs +++ b/mir-state-analysis/src/utils/repacker.rs @@ -11,8 +11,7 @@ use prusti_rustc_interface::{ index::bit_set::BitSet, middle::{ mir::{ - tcx::PlaceTy, Body, HasLocalDecls, Local, Mutability, Place as MirPlace, - ProjectionElem, + tcx::PlaceTy, Body, HasLocalDecls, Local, Mutability, Place as MirPlace, ProjectionElem, }, ty::{TyCtxt, TyKind}, }, diff --git a/prusti-contracts/prusti-contracts/src/lib.rs b/prusti-contracts/prusti-contracts/src/lib.rs index e36319181df..eced6374b07 100644 --- a/prusti-contracts/prusti-contracts/src/lib.rs +++ b/prusti-contracts/prusti-contracts/src/lib.rs @@ -342,7 +342,7 @@ pub fn old(arg: T) -> T { /// Universal quantifier. /// /// This is a Prusti-internal representation of the `forall` syntax. -#[prusti::builtin="forall"] +#[prusti::builtin = "forall"] pub fn forall(_trigger_set: T, _closure: F) -> bool { true } diff --git a/prusti-contracts/prusti-specs/src/lib.rs b/prusti-contracts/prusti-specs/src/lib.rs index 3403260c345..a16b21cd4dc 100644 --- a/prusti-contracts/prusti-specs/src/lib.rs +++ b/prusti-contracts/prusti-specs/src/lib.rs @@ -77,7 +77,9 @@ fn extract_prusti_attributes( // tokens identical to the ones passed by the native procedural // macro call. let mut iter = attr.tokens.into_iter(); - let TokenTree::Group(group) = iter.next().unwrap() else { unreachable!() }; + let TokenTree::Group(group) = iter.next().unwrap() else { + unreachable!() + }; assert!(iter.next().is_none(), "Unexpected shape of an attribute."); group.stream() } @@ -596,10 +598,14 @@ pub fn refine_trait_spec(_attr: TokenStream, tokens: TokenStream) -> TokenStream let parsed_predicate = handle_result!(predicate::parse_predicate_in_impl(makro.mac.tokens.clone())); - let ParsedPredicate::Impl(predicate) = parsed_predicate else { unreachable!() }; + let ParsedPredicate::Impl(predicate) = parsed_predicate else { + unreachable!() + }; // Patch spec function: Rewrite self with _self: - let syn::Item::Fn(spec_function) = predicate.spec_function else { unreachable!() }; + let syn::Item::Fn(spec_function) = predicate.spec_function else { + unreachable!() + }; generated_spec_items.push(spec_function); // Add patched predicate function to new items @@ -872,7 +878,9 @@ fn extract_prusti_attributes_for_types( // tokens identical to the ones passed by the native procedural // macro call. let mut iter = attr.tokens.into_iter(); - let TokenTree::Group(group) = iter.next().unwrap() else { unreachable!() }; + let TokenTree::Group(group) = iter.next().unwrap() else { + unreachable!() + }; group.stream() } }; diff --git a/prusti-encoder/src/encoders/generic.rs b/prusti-encoder/src/encoders/generic.rs index dff23c44cb2..bfa8aac6502 100644 --- a/prusti-encoder/src/encoders/generic.rs +++ b/prusti-encoder/src/encoders/generic.rs @@ -1,7 +1,4 @@ -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct GenericEncoder; @@ -39,7 +36,8 @@ impl TaskEncoder for GenericEncoder { type EncodingError = GenericEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, GenericEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, GenericEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -57,38 +55,49 @@ impl TaskEncoder for GenericEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(*task_key, GenericEncoderOutputRef { - snapshot_param_name: "s_Param", - predicate_param_name: "p_Param", - domain_type_name: "s_Type", - }); - vir::with_vcx(|vcx| Ok((GenericEncoderOutput { - snapshot_param: vir::vir_domain! { vcx; domain s_Param {} }, - predicate_param: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref/*, self_s: s_Param*/) }, - domain_type: vir::vir_domain! { vcx; domain s_Type { - // TODO: only emit these when the types are actually used? - // emit instead from type encoder, to be consistent with the ADT case? - unique function s_Type_Bool(): s_Type; - unique function s_Type_Int_isize(): s_Type; - unique function s_Type_Int_i8(): s_Type; - unique function s_Type_Int_i16(): s_Type; - unique function s_Type_Int_i32(): s_Type; - unique function s_Type_Int_i64(): s_Type; - unique function s_Type_Int_i128(): s_Type; - unique function s_Type_Uint_usize(): s_Type; - unique function s_Type_Uint_u8(): s_Type; - unique function s_Type_Uint_u16(): s_Type; - unique function s_Type_Uint_u32(): s_Type; - unique function s_Type_Uint_u64(): s_Type; - unique function s_Type_Uint_u128(): s_Type; - } }, - }, ()))) + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { + deps.emit_output_ref::( + *task_key, + GenericEncoderOutputRef { + snapshot_param_name: "s_Param", + predicate_param_name: "p_Param", + domain_type_name: "s_Type", + }, + ); + vir::with_vcx(|vcx| { + Ok(( + GenericEncoderOutput { + snapshot_param: vir::vir_domain! { vcx; domain s_Param {} }, + predicate_param: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref/*, self_s: s_Param*/) }, + domain_type: vir::vir_domain! { vcx; domain s_Type { + // TODO: only emit these when the types are actually used? + // emit instead from type encoder, to be consistent with the ADT case? + unique function s_Type_Bool(): s_Type; + unique function s_Type_Int_isize(): s_Type; + unique function s_Type_Int_i8(): s_Type; + unique function s_Type_Int_i16(): s_Type; + unique function s_Type_Int_i32(): s_Type; + unique function s_Type_Int_i64(): s_Type; + unique function s_Type_Int_i128(): s_Type; + unique function s_Type_Uint_usize(): s_Type; + unique function s_Type_Uint_u8(): s_Type; + unique function s_Type_Uint_u16(): s_Type; + unique function s_Type_Uint_u32(): s_Type; + unique function s_Type_Uint_u64(): s_Type; + unique function s_Type_Uint_u128(): s_Type; + } }, + }, + (), + )) + }) } } diff --git a/prusti-encoder/src/encoders/local_def.rs b/prusti-encoder/src/encoders/local_def.rs index d6720a3d0c9..3b385fb8278 100644 --- a/prusti-encoder/src/encoders/local_def.rs +++ b/prusti-encoder/src/encoders/local_def.rs @@ -1,11 +1,7 @@ -use prusti_rustc_interface::{ - index::IndexVec, - middle::mir, - span::def_id::DefId -}; +use prusti_rustc_interface::{index::IndexVec, middle::mir, span::def_id::DefId}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirLocalDefEncoder; #[derive(Clone, Copy)] @@ -70,32 +66,37 @@ impl TaskEncoder for MirLocalDefEncoder { deps.emit_output_ref::(def_id, ()); vir::with_vcx(|vcx| { - let local_def_id = def_id.expect_local(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); - let locals = IndexVec::from_fn_n(|arg: mir::Local| { - let local = vcx.mk_local(vir::vir_format!(vcx, "_{}p", arg.index())); - let ty = deps.require_ref::( - body.local_decls[arg].ty, - ).unwrap(); - let snapshot = ty.snapshot; - let local_ex = vcx.mk_local_ex_local(local); - let impure_snap = vcx.mk_func_app( - ty.function_snap, - &[local_ex], - ); - let impure_pred = vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { - target: ty.predicate_name, - args: vcx.alloc_slice(&[local_ex]), - }))); - LocalDef { - local, - snapshot, - local_ex, - impure_snap, - impure_pred, - ty: vcx.alloc(ty), - } - }, body.local_decls.len()); + let local_def_id = def_id.expect_local(); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); + let locals = IndexVec::from_fn_n( + |arg: mir::Local| { + let local = vcx.mk_local(vir::vir_format!(vcx, "_{}p", arg.index())); + let ty = deps + .require_ref::(body.local_decls[arg].ty) + .unwrap(); + let snapshot = ty.snapshot; + let local_ex = vcx.mk_local_ex_local(local); + let impure_snap = vcx.mk_func_app(ty.function_snap, &[local_ex]); + let impure_pred = vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc( + vir::PredicateAppData { + target: ty.predicate_name, + args: vcx.alloc_slice(&[local_ex]), + }, + ))); + LocalDef { + local, + snapshot, + local_ex, + impure_snap, + impure_pred, + ty: vcx.alloc(ty), + } + }, + body.local_decls.len(), + ); let data = MirLocalDefEncoderOutput { locals: vcx.alloc(locals), arg_count: body.arg_count, diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index 69902ee0094..a3654207246 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -1,11 +1,5 @@ -use prusti_rustc_interface::{ - middle::ty, - middle::mir, -}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use prusti_rustc_interface::middle::{mir, ty}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirBuiltinEncoder; @@ -47,7 +41,8 @@ impl TaskEncoder for MirBuiltinEncoder { type EncodingError = MirBuiltinEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirBuiltinEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirBuiltinEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -65,13 +60,16 @@ impl TaskEncoder for MirBuiltinEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { // TODO: this function is also useful for the type encoder, extract? fn int_name<'tcx>(ty: ty::Ty<'tcx>) -> &'static str { match ty.kind() { @@ -87,172 +85,202 @@ impl TaskEncoder for MirBuiltinEncoder { assert_eq!(*ty.kind(), ty::TyKind::Bool); let name = "mir_unop_not"; - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); /* function mir_unop_not(arg: s_Bool): s_Bool { s_Bool$cons(!s_Bool$val(val)) } */ - let ty_s = deps.require_ref::( - *ty, - ).unwrap().snapshot; - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_s)]), - ret: ty_s, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - "s_Bool_cons", - &[vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { - kind: vir::UnOpKind::Not, - expr: vcx.mk_func_app( - "s_Bool_val", - &[vcx.mk_local_ex("arg")], - ), - })))], - )), - }), - }, ())) + let ty_s = deps + .require_ref::(*ty) + .unwrap() + .snapshot; + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_s)]), + ret: ty_s, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + "s_Bool_cons", + &[vcx.alloc(vir::ExprData::UnOp( + vcx.alloc(vir::UnOpData { + kind: vir::UnOpKind::Not, + expr: vcx.mk_func_app( + "s_Bool_val", + &[vcx.mk_local_ex("arg")], + ), + }), + ))], + )), + }), + }, + (), + )) } MirBuiltinEncoderTask::UnOp(mir::UnOp::Neg, ty) => { let name = vir::vir_format!(vcx, "mir_unop_neg_{}", int_name(*ty)); - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); /* function mir_unop_neg(arg: s_I32): s_I32 { cons(-val(arg)) } */ - let ty_out = deps.require_ref::( - *ty, - ).unwrap(); - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_out.snapshot)]), - ret: ty_out.snapshot, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - ty_out.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { - kind: vir::UnOpKind::Neg, - expr: vcx.mk_func_app( - ty_out.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg")], - ), - })))], - )), - }), - }, ())) + let ty_out = deps + .require_ref::(*ty) + .unwrap(); + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[vcx.mk_local_decl("arg", ty_out.snapshot)]), + ret: ty_out.snapshot, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + ty_out.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::UnOp(vcx.alloc(vir::UnOpData { + kind: vir::UnOpKind::Neg, + expr: vcx.mk_func_app( + ty_out.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg")], + ), + })))], + )), + }), + }, + (), + )) } // TODO: should these be two different functions? precondition? MirBuiltinEncoderTask::BinOp(mir::BinOp::Add | mir::BinOp::AddUnchecked, ty) => { let name = vir::vir_format!(vcx, "mir_binop_add_{}", int_name(*ty)); - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); - let ty_out = deps.require_ref::( - *ty, - ).unwrap(); - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("arg1", ty_out.snapshot), - vcx.mk_local_decl("arg2", ty_out.snapshot), - ]), - ret: ty_out.snapshot, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - ty_out.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::Add, - lhs: vcx.mk_func_app( - ty_out.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg1")], - ), - rhs: vcx.mk_func_app( - ty_out.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg2")], - ), - })))], - )), - }), - }, ())) + let ty_out = deps + .require_ref::(*ty) + .unwrap(); + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[ + vcx.mk_local_decl("arg1", ty_out.snapshot), + vcx.mk_local_decl("arg2", ty_out.snapshot), + ]), + ret: ty_out.snapshot, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + ty_out.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::Add, + lhs: vcx.mk_func_app( + ty_out.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg1")], + ), + rhs: vcx.mk_func_app( + ty_out.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg2")], + ), + })))], + )), + }), + }, + (), + )) } - MirBuiltinEncoderTask::CheckedBinOp(op @ (mir::BinOp::Add | mir::BinOp::AddUnchecked | mir::BinOp::Sub | mir::BinOp::SubUnchecked), ty) => { - + MirBuiltinEncoderTask::CheckedBinOp( + op @ (mir::BinOp::Add + | mir::BinOp::AddUnchecked + | mir::BinOp::Sub + | mir::BinOp::SubUnchecked), + ty, + ) => { let (op_name, vir_op) = match op { - mir::BinOp::Add | mir::BinOp::AddUnchecked => ("add", vir::BinOpKind::Add), - mir::BinOp::Sub | mir::BinOp::SubUnchecked => ("sub", vir::BinOpKind::Sub), + mir::BinOp::Add | mir::BinOp::AddUnchecked => ("add", vir::BinOpKind::Add), + mir::BinOp::Sub | mir::BinOp::SubUnchecked => ("sub", vir::BinOpKind::Sub), - _ => unreachable!() + _ => unreachable!(), }; - let name = vir::vir_format!(vcx, "mir_checkedbinop_{}_{}", op_name, int_name(*ty)); + let name = + vir::vir_format!(vcx, "mir_checkedbinop_{}_{}", op_name, int_name(*ty)); - deps.emit_output_ref::(task_key.clone(), MirBuiltinEncoderOutputRef { - name, - }); + deps.emit_output_ref::( + task_key.clone(), + MirBuiltinEncoderOutputRef { name }, + ); - let ty_in = deps.require_ref::( - *ty, - ).unwrap(); - let ty_out = deps.require_ref::( - vcx.tcx.mk_ty_from_kind(ty::TyKind::Tuple(vcx.tcx.mk_type_list(&[ - *ty, - vcx.tcx.mk_ty_from_kind(ty::TyKind::Bool), - ]))), - ).unwrap(); - Ok((MirBuiltinEncoderOutput { - function: vcx.alloc(vir::FunctionData { - name, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("arg1", ty_in.snapshot), - vcx.mk_local_decl("arg2", ty_in.snapshot), - ]), - ret: ty_out.snapshot, - pres: &[], - posts: &[], - expr: Some(vcx.mk_func_app( - vir::vir_format!(vcx, "{}_cons", ty_out.snapshot_name), - &[ - vcx.mk_func_app( - ty_in.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir_op, - lhs: vcx.mk_func_app( - ty_in.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg1")], - ), - rhs: vcx.mk_func_app( - ty_in.to_primitive.unwrap(), - &[vcx.mk_local_ex("arg2")], - ), - })))], - ), - // TODO: overflow condition! - vcx.mk_func_app( - "s_Bool_cons", - &[&vir::ExprData::Const(&vir::ConstData::Bool(false))], - ), - ], - )), - }), - }, ())) + let ty_in = deps + .require_ref::(*ty) + .unwrap(); + let ty_out = deps + .require_ref::(vcx.tcx.mk_ty_from_kind( + ty::TyKind::Tuple( + vcx.tcx.mk_type_list(&[ + *ty, + vcx.tcx.mk_ty_from_kind(ty::TyKind::Bool), + ]), + ), + )) + .unwrap(); + Ok(( + MirBuiltinEncoderOutput { + function: vcx.alloc(vir::FunctionData { + name, + args: vcx.alloc_slice(&[ + vcx.mk_local_decl("arg1", ty_in.snapshot), + vcx.mk_local_decl("arg2", ty_in.snapshot), + ]), + ret: ty_out.snapshot, + pres: &[], + posts: &[], + expr: Some(vcx.mk_func_app( + vir::vir_format!(vcx, "{}_cons", ty_out.snapshot_name), + &[ + vcx.mk_func_app( + ty_in.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::BinOp(vcx.alloc( + vir::BinOpData { + kind: vir_op, + lhs: vcx.mk_func_app( + ty_in.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg1")], + ), + rhs: vcx.mk_func_app( + ty_in.to_primitive.unwrap(), + &[vcx.mk_local_ex("arg2")], + ), + }, + )))], + ), + // TODO: overflow condition! + vcx.mk_func_app( + "s_Bool_cons", + &[&vir::ExprData::Const(&vir::ConstData::Bool(false))], + ), + ], + )), + }), + }, + (), + )) } other => todo!("{other:?}"), } - }) } } diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 5b06f614487..91bffdf6765 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -1,17 +1,15 @@ +use mir_state_analysis::{ + free_pcs::{CapabilityKind, FreePcsAnalysis, FreePcsBasicBlock, RepackOp}, + utils::Place, +}; use prusti_rustc_interface::{ middle::{mir, ty}, span::def_id::DefId, }; -use mir_state_analysis::{ - free_pcs::{FreePcsAnalysis, FreePcsBasicBlock, RepackOp, CapabilityKind}, utils::Place, -}; //use mir_ssa_analysis::{ // SsaAnalysis, //}; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirImpureEncoder; @@ -48,7 +46,8 @@ impl TaskEncoder for MirImpureEncoder { type EncodingError = MirImpureEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirImpureEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirImpureEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -66,31 +65,35 @@ impl TaskEncoder for MirImpureEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { use mir::visit::Visitor; vir::with_vcx(|vcx| { let def_id = *task_key; let method_name = vir::vir_format!(vcx, "m_{}", vcx.tcx.item_name(def_id)); - deps.emit_output_ref::(def_id, MirImpureEncoderOutputRef { - method_name, - }); - - let local_defs = deps.require_local::( - def_id, - ).unwrap(); - let spec = deps.require_local::( - (def_id, false) - ).unwrap(); + deps.emit_output_ref::(def_id, MirImpureEncoderOutputRef { method_name }); + + let local_defs = deps + .require_local::(def_id) + .unwrap(); + let spec = deps + .require_local::((def_id, false)) + .unwrap(); let (spec_pres, spec_posts) = (spec.pres, spec.posts); - let local_def_id = def_id.expect_local(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); + let local_def_id = def_id.expect_local(); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); // let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); //let ssa_analysis = SsaAnalysis::analyse(&body); @@ -129,14 +132,13 @@ impl TaskEncoder for MirImpureEncoder { ))) } if ENCODE_REACH_BB { - start_stmts.extend((0..block_count) - .map(|block| { - let name = vir::vir_format!(vcx, "_reach_bb{block}"); - vcx.alloc(vir::StmtData::LocalDecl( - vir::vir_local_decl! { vcx; [name] : Bool }, - Some(vcx.alloc(vir::ExprData::Todo("false"))), - )) - })); + start_stmts.extend((0..block_count).map(|block| { + let name = vir::vir_format!(vcx, "_reach_bb{block}"); + vcx.alloc(vir::StmtData::LocalDecl( + vir::vir_local_decl! { vcx; [name] : Bool }, + Some(vcx.alloc(vir::ExprData::Todo("false"))), + )) + })); } encoded_blocks.push(vcx.alloc(vir::CfgBlockData { label: vcx.alloc(vir::CfgBlockLabelData::Start), @@ -161,23 +163,21 @@ impl TaskEncoder for MirImpureEncoder { posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred); posts.extend(spec_posts); - - // TODO: dedup with mir_pure - let attrs = vcx.tcx.get_attrs_unchecked(def_id); - let is_trusted = attrs - .iter() - .filter(|attr| !attr.is_doc_comment()) - .map(|attr| attr.get_normal_item()) - .any(|item| { - item.path.segments.len() == 2 - && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "trusted" - }); + // TODO: dedup with mir_pure + let attrs = vcx.tcx.get_attrs_unchecked(def_id); + let is_trusted = attrs + .iter() + .filter(|attr| !attr.is_doc_comment()) + .map(|attr| attr.get_normal_item()) + .any(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "trusted" + }); let blocks = if is_trusted { None - } - else { + } else { let mut visitor = EncoderVisitor { vcx, deps, @@ -205,22 +205,26 @@ impl TaskEncoder for MirImpureEncoder { Some(vcx.alloc_slice(&visitor.encoded_blocks)) }; - Ok((MirImpureEncoderOutput { - method: vcx.alloc(vir::MethodData { - name: method_name, - args: vcx.alloc_slice(&args), - rets: &[], - pres: vcx.alloc_slice(&pres), - posts: vcx.alloc_slice(&posts), - blocks, - }), - }, ())) + Ok(( + MirImpureEncoderOutput { + method: vcx.alloc(vir::MethodData { + name: method_name, + args: vcx.alloc_slice(&args), + rets: &[], + pres: vcx.alloc_slice(&pres), + posts: vcx.alloc_slice(&posts), + blocks, + }), + }, + (), + )) }) } } struct EncoderVisitor<'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { vcx: &'vir vir::VirCtxt<'vir>, deps: &'enc mut TaskEncoderDependencies<'vir>, @@ -257,13 +261,15 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { match projection { mir::ProjectionElem::Field(f, ty) => { let ty_out_struct = ty_out.expect_structlike(); - let field_ty_out = self.deps.require_ref::( - ty, - ).unwrap(); - (self.vcx.mk_func_app( - ty_out_struct.field_projection_p[f.as_usize()], - &[base], - ), field_ty_out) + let field_ty_out = self + .deps + .require_ref::(ty) + .unwrap(); + ( + self.vcx + .mk_func_app(ty_out_struct.field_projection_p[f.as_usize()], &[base]), + field_ty_out, + ) } _ => panic!("unsupported projection"), } @@ -275,8 +281,11 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { ty_out: crate::encoders::TypeEncoderOutputRef<'vir>, projection: &'vir [mir::PlaceElem<'vir>], ) -> (vir::Expr<'vir>, crate::encoders::TypeEncoderOutputRef<'vir>) { - projection.iter() - .fold((base, ty_out), |(base, ty_out), proj| self.project_one(base, ty_out, *proj)) + projection + .iter() + .fold((base, ty_out), |(base, ty_out), proj| { + self.project_one(base, ty_out, *proj) + }) } /* @@ -321,11 +330,10 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { } */ - fn fpcs_location( - &mut self, - location: mir::Location, - ) { - let repacks = self.current_fpcs.as_ref().unwrap().statements[location.statement_index].repacks.clone(); + fn fpcs_location(&mut self, location: mir::Location) { + let repacks = self.current_fpcs.as_ref().unwrap().statements[location.statement_index] + .repacks + .clone(); for repack_op in repacks { match repack_op { RepackOp::Expand(place, _target, capability_kind) @@ -339,24 +347,26 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { let place_ty = place.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); - let place_ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let place_ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let ref_p = self.encode_place(place); - if matches!(repack_op, mir_state_analysis::free_pcs::RepackOp::Expand(..)) { - self.stmt(vir::StmtData::Unfold(self.vcx.alloc(vir::PredicateAppData { - target: place_ty_out.predicate_name, - args: self.vcx.alloc_slice(&[ - ref_p, - ]), - }))); + if matches!( + repack_op, + mir_state_analysis::free_pcs::RepackOp::Expand(..) + ) { + self.stmt(vir::StmtData::Unfold(self.vcx.alloc( + vir::PredicateAppData { + target: place_ty_out.predicate_name, + args: self.vcx.alloc_slice(&[ref_p]), + }, + ))); } else { self.stmt(vir::StmtData::Fold(self.vcx.alloc(vir::PredicateAppData { target: place_ty_out.predicate_name, - args: self.vcx.alloc_slice(&[ - ref_p, - ]), + args: self.vcx.alloc_slice(&[ref_p]), }))); } } @@ -364,133 +374,148 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { let place_ty = place.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); - let place_ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let place_ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let ref_p = self.encode_place(place); - self.stmt(vir::StmtData::Exhale(self.vcx.alloc(vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { - target: place_ty_out.predicate_name, - args: self.vcx.alloc_slice(&[ - ref_p, - ]), - }))))); + self.stmt(vir::StmtData::Exhale(self.vcx.alloc( + vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { + target: place_ty_out.predicate_name, + args: self.vcx.alloc_slice(&[ref_p]), + })), + ))); } unsupported_op => panic!("unsupported repack op: {unsupported_op:?}"), } } } - fn encode_operand_snap( - &mut self, - operand: &mir::Operand<'vir>, - ) -> vir::Expr<'vir> { + fn encode_operand_snap(&mut self, operand: &mir::Operand<'vir>) -> vir::Expr<'vir> { match operand { &mir::Operand::Move(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); let place_exp = self.encode_place(Place::from(source)); let snap_val = self.vcx.mk_func_app(ty_out.function_snap, &[place_exp]); let tmp_exp = self.new_tmp(ty_out.snapshot).1; - self.stmt(vir::StmtData::PureAssign(self.vcx.alloc(vir::PureAssignData { - lhs: tmp_exp, - rhs: snap_val, - }))); - self.stmt(vir::StmtData::Exhale(self.vcx.alloc(vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { - target: ty_out.predicate_name, - args: self.vcx.alloc_slice(&[place_exp]), - }))))); + self.stmt(vir::StmtData::PureAssign(self.vcx.alloc( + vir::PureAssignData { + lhs: tmp_exp, + rhs: snap_val, + }, + ))); + self.stmt(vir::StmtData::Exhale(self.vcx.alloc( + vir::ExprData::PredicateApp(self.vcx.alloc(vir::PredicateAppData { + target: ty_out.predicate_name, + args: self.vcx.alloc_slice(&[place_exp]), + })), + ))); tmp_exp } &mir::Operand::Copy(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); - self.vcx.mk_func_app(ty_out.function_snap, &[self.encode_place(Place::from(source))]) + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + self.vcx.mk_func_app( + ty_out.function_snap, + &[self.encode_place(Place::from(source))], + ) } mir::Operand::Constant(box constant) => self.encode_constant(constant), } } - fn encode_operand( - &mut self, - operand: &mir::Operand<'vir>, - ) -> vir::Expr<'vir> { + fn encode_operand(&mut self, operand: &mir::Operand<'vir>) -> vir::Expr<'vir> { let (snap_val, ty_out) = match operand { &mir::Operand::Move(source) => return self.encode_place(Place::from(source)), &mir::Operand::Copy(source) => { let place_ty = source.ty(self.local_decls, self.vcx.tcx); assert!(place_ty.variant_index.is_none()); // TODO - let ty_out = self.deps.require_ref::( - place_ty.ty, - ).unwrap(); - (self.vcx.mk_func_app(ty_out.function_snap, &[self.encode_place(Place::from(source))]), ty_out) + let ty_out = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + ( + self.vcx.mk_func_app( + ty_out.function_snap, + &[self.encode_place(Place::from(source))], + ), + ty_out, + ) } mir::Operand::Constant(box constant) => { - let ty_out = self.deps.require_ref::( - constant.ty(), - ).unwrap(); + let ty_out = self + .deps + .require_ref::(constant.ty()) + .unwrap(); (self.encode_constant(constant), ty_out) } }; let tmp_exp = self.new_tmp(&vir::TypeData::Ref).1; - self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { - targets: &[], - method: ty_out.method_assign, - args: self.vcx.alloc_slice(&[tmp_exp, snap_val]), - }))); + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc( + vir::MethodCallData { + targets: &[], + method: ty_out.method_assign, + args: self.vcx.alloc_slice(&[tmp_exp, snap_val]), + }, + ))); tmp_exp } - fn encode_place( - &mut self, - place: Place<'vir>, - ) -> vir::Expr<'vir> { + fn encode_place(&mut self, place: Place<'vir>) -> vir::Expr<'vir> { //assert!(place.projection.is_empty()); //self.vcx.mk_local_ex(vir::vir_format!(self.vcx, "_{}p", place.local.index())) self.project( self.local_defs.locals[place.local].local_ex, self.local_defs.locals[place.local].ty.clone(), place.projection, - ).0 + ) + .0 } // TODO: this will not work for unevaluated constants (which needs const // MIR evaluation, more like pure fn body encoding) - fn encode_constant( - &self, - constant: &mir::Constant<'vir>, - ) -> vir::Expr<'vir> { + fn encode_constant(&self, constant: &mir::Constant<'vir>) -> vir::Expr<'vir> { match constant.literal { - mir::ConstantKind::Val(const_val, const_ty) => { - match const_ty.kind() { - ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), - ty::TyKind::Int(int_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Int_{}_cons({})", int_ty.name_str(), scalar_val.try_to_int(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Uint(uint_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Uint_{}_cons({})", uint_ty.name_str(), scalar_val.try_to_uint(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Bool => self.vcx.alloc(vir::ExprData::Todo( - vir::vir_format!(self.vcx, "s_Bool_cons({})", const_val.try_to_bool().unwrap()), - )), - unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + mir::ConstantKind::Val(const_val, const_ty) => match const_ty.kind() { + ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(vir::ExprData::Todo( + vir::vir_format!(self.vcx, "s_Tuple0_cons()"), + )), + ty::TyKind::Int(int_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Int_{}_cons({})", + int_ty.name_str(), + scalar_val.try_to_int(scalar_val.size()).unwrap() + ))) } - } + ty::TyKind::Uint(uint_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Uint_{}_cons({})", + uint_ty.name_str(), + scalar_val.try_to_uint(scalar_val.size()).unwrap() + ))) + } + ty::TyKind::Bool => self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!( + self.vcx, + "s_Bool_cons({})", + const_val.try_to_bool().unwrap() + ))), + unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + }, unsupported_literal => todo!("unsupported constant literal: {unsupported_literal:?}"), } } @@ -511,21 +536,23 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { // fn visit_body(&mut self, body: &mir::Body<'tcx>) { // println!("visiting body!"); // } - fn visit_basic_block_data( - &mut self, - block: mir::BasicBlock, - data: &mir::BasicBlockData<'vir>, - ) { + fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'vir>) { self.current_fpcs = Some(self.fpcs_analysis.get_all_for_bb(block)); self.current_stmts = Some(Vec::with_capacity( data.statements.len(), // TODO: not exact? )); if ENCODE_REACH_BB { - self.stmt(vir::StmtData::PureAssign(self.vcx.alloc(vir::PureAssignData { - lhs: self.vcx.mk_local_ex(vir::vir_format!(self.vcx, "_reach_bb{}", block.as_usize())), - rhs: self.vcx.mk_true(), - }))); + self.stmt(vir::StmtData::PureAssign(self.vcx.alloc( + vir::PureAssignData { + lhs: self.vcx.mk_local_ex(vir::vir_format!( + self.vcx, + "_reach_bb{}", + block.as_usize() + )), + rhs: self.vcx.mk_true(), + }, + ))); } /* @@ -559,26 +586,25 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { self.super_basic_block_data(block, data); let stmts = self.current_stmts.take().unwrap(); let terminator = self.current_terminator.take().unwrap(); - self.encoded_blocks.push(self.vcx.alloc(vir::CfgBlockData { - label: self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), - stmts: self.vcx.alloc_slice(&stmts), - terminator, - })); + self.encoded_blocks.push( + self.vcx.alloc(vir::CfgBlockData { + label: self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(block.as_usize())), + stmts: self.vcx.alloc_slice(&stmts), + terminator, + }), + ); } - fn visit_statement( - &mut self, - statement: &mir::Statement<'vir>, - location: mir::Location, - ) { + fn visit_statement(&mut self, statement: &mir::Statement<'vir>, location: mir::Location) { // TODO: proper flag // This clears up the noise a bit, making sure StorageLive and other // kinds do not show up in the comments. let IGNORE_NOP_STMTS = true; if IGNORE_NOP_STMTS { match &statement.kind { - mir::StatementKind::StorageLive(..) - | mir::StatementKind::StorageDead(..) => { + mir::StatementKind::StorageLive(..) | mir::StatementKind::StorageDead(..) => { return; } _ => {} @@ -720,10 +746,10 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { method: dest_ty_out.method_assign, args: self.vcx.alloc_slice(&[proj_ref, cons]), }))); - + for field in fields { if let mir::Operand::Move(source) = field { - + } } None @@ -766,42 +792,53 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { } } - fn visit_terminator( - &mut self, - terminator: &mir::Terminator<'vir>, - location: mir::Location, - ) { + fn visit_terminator(&mut self, terminator: &mir::Terminator<'vir>, location: mir::Location) { self.fpcs_location(location); let terminator = match &terminator.kind { mir::TerminatorKind::Goto { target } - | mir::TerminatorKind::FalseUnwind { real_target: target, .. } => - self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )), + | mir::TerminatorKind::FalseUnwind { + real_target: target, + .. + } => self.vcx.alloc(vir::TerminatorStmtData::Goto( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + )), mir::TerminatorKind::SwitchInt { discr, targets } => { //let discr_version = self.ssa_analysis.version.get(&(location, discr.local)).unwrap(); //let discr_name = vir::vir_format!(self.vcx, "_{}s_{}", discr.local.index(), discr_version); - let ty_out = self.deps.require_ref::( - discr.ty(self.local_decls, self.vcx.tcx), - ).unwrap(); - - let goto_targets = self.vcx.alloc_slice(&targets.iter() - .map(|(value, target)| ( - ty_out.expr_from_u128(value), - //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )) - .collect::>()); + let ty_out = self + .deps + .require_ref::( + discr.ty(self.local_decls, self.vcx.tcx), + ) + .unwrap(); + + let goto_targets = self.vcx.alloc_slice( + &targets + .iter() + .map(|(value, target)| { + ( + ty_out.expr_from_u128(value), + //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + ) + }) + .collect::>(), + ); let goto_otherwise = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock( targets.otherwise().as_usize(), )); let discr_ex = self.encode_operand_snap(discr); - self.vcx.alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc(vir::GotoIfData { - value: discr_ex, // self.vcx.mk_local_ex(discr_name), - targets: goto_targets, - otherwise: goto_otherwise, - }))) + self.vcx + .alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc( + vir::GotoIfData { + value: discr_ex, // self.vcx.mk_local_ex(discr_name), + targets: goto_targets, + otherwise: goto_otherwise, + }, + ))) } mir::TerminatorKind::Return => self.vcx.alloc(vir::TerminatorStmtData::Goto( self.vcx.alloc(vir::CfgBlockLabelData::End), @@ -825,88 +862,116 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { // TODO: dedup with mir_pure let attrs = self.vcx.tcx.get_attrs_unchecked(*func_def_id); - let is_pure = attrs.iter() + let is_pure = attrs + .iter() .filter(|attr| !attr.is_doc_comment()) - .map(|attr| attr.get_normal_item()).any(|item| + .map(|attr| attr.get_normal_item()) + .any(|item| { item.path.segments.len() == 2 - && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "pure" - ); + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + }); let dest = self.encode_place(Place::from(*destination)); - let call_args = args.iter().map(|op| + let call_args = args.iter().map(|op| { if is_pure { self.encode_operand_snap(op) - } - else { + } else { self.encode_operand(op) } - ); + }); if is_pure { let func_args = call_args.collect::>(); - let pure_func_name = self.deps.require_ref::( - *func_def_id, - ).unwrap().function_name; + let pure_func_name = self + .deps + .require_ref::(*func_def_id) + .unwrap() + .function_name; let pure_func_app = self.vcx.mk_func_app(pure_func_name, &func_args); let method_assign = { //TODO: Can we get `method_assign` in a better way? Maybe from the MirFunctionEncoder? - let body = self.vcx.body.borrow_mut().get_impure_fn_body_identity(func_def_id.expect_local()); - let return_type = self.deps + let body = self + .vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(func_def_id.expect_local()); + let return_type = self + .deps .require_ref::(body.return_ty()) .unwrap(); return_type.method_assign }; - self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { - targets: &[], - method: method_assign, - args: self.vcx.alloc_slice(&[dest, pure_func_app]), - }))); - } - else { - let meth_args = std::iter::once(dest) - .chain(call_args) - .collect::>(); - let func_out = self.deps.require_ref::( - *func_def_id, - ).unwrap(); - - self.stmt(vir::StmtData::MethodCall(self.vcx.alloc(vir::MethodCallData { - targets: &[], - method: func_out.method_name, - args: self.vcx.alloc_slice(&meth_args), - }))); + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc( + vir::MethodCallData { + targets: &[], + method: method_assign, + args: self.vcx.alloc_slice(&[dest, pure_func_app]), + }, + ))); + } else { + let meth_args = std::iter::once(dest).chain(call_args).collect::>(); + let func_out = self + .deps + .require_ref::(*func_def_id) + .unwrap(); + + self.stmt(vir::StmtData::MethodCall(self.vcx.alloc( + vir::MethodCallData { + targets: &[], + method: func_out.method_name, + args: self.vcx.alloc_slice(&meth_args), + }, + ))); } - self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize())), - )) + self.vcx.alloc(vir::TerminatorStmtData::Goto(self.vcx.alloc( + vir::CfgBlockLabelData::BasicBlock(target.unwrap().as_usize()), + ))) } - mir::TerminatorKind::Assert { cond, expected, msg, target, unwind } => { - + mir::TerminatorKind::Assert { + cond, + expected, + msg, + target, + unwind, + } => { let otherwise = match unwind { mir::UnwindAction::Cleanup(bb) => bb, - _ => todo!() + _ => todo!(), }; let enc = self.encode_operand_snap(cond); let enc = self.vcx.mk_func_app("s_Bool_val", &[enc]); - let target_bb = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())); - - self.vcx.alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc(vir::GotoIfData { - value: enc, - targets: self.vcx.alloc_slice(&[(self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Bool(*expected)))) - , &target_bb)]), - otherwise: self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock(otherwise.as_usize())), - }))) + let target_bb = self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())); + + self.vcx.alloc(vir::TerminatorStmtData::GotoIf( + self.vcx.alloc(vir::GotoIfData { + value: enc, + targets: self.vcx.alloc_slice(&[( + self.vcx.alloc(vir::ExprData::Const( + self.vcx.alloc(vir::ConstData::Bool(*expected)), + )), + &target_bb, + )]), + otherwise: self + .vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(otherwise.as_usize())), + }), + )) } - unsupported_kind => self.vcx.alloc(vir::TerminatorStmtData::Dummy( - vir::vir_format!(self.vcx, "terminator {unsupported_kind:?}"), - )), + unsupported_kind => self + .vcx + .alloc(vir::TerminatorStmtData::Dummy(vir::vir_format!( + self.vcx, + "terminator {unsupported_kind:?}" + ))), }; assert!(self.current_terminator.replace(terminator).is_none()); } diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 369f48edfd2..f26926d107d 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -1,15 +1,12 @@ +use crate::encoders::{TypeEncoder, ViperTupleEncoder}; use prusti_rustc_interface::{ data_structures::graph::dominators::Dominators, middle::{mir, ty}, span::def_id::DefId, type_ir::sty::TyKind, }; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; use std::collections::HashMap; -use crate::encoders::{ViperTupleEncoder, TypeEncoder}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct MirPureEncoder; @@ -31,33 +28,39 @@ pub struct MirPureEncoderOutput<'vir> { } /// Optimize a vir expresison -/// +/// /// This is a temporary fix for the issue where variables that are quantified over and are then stored in a let binding /// cause issues with triggering somehow. -/// +/// /// This was also intended to make debugging easier by making the resulting viper code a bit more readable -/// +/// /// This should be replaced with a proper solution -fn opt<'vir, Cur, Next> (expr: vir::ExprGen<'vir, Cur, Next>, rename: &mut HashMap) -> vir::ExprGen<'vir, Cur, Next>{ +fn opt<'vir, Cur, Next>( + expr: vir::ExprGen<'vir, Cur, Next>, + rename: &mut HashMap, +) -> vir::ExprGen<'vir, Cur, Next> { match expr { - vir::ExprGenData::Local(d) => { - let nam = rename.get(d.name).map(|e| e.as_str()).unwrap_or(d.name).to_owned(); - vir::with_vcx(move |vcx| { - vcx.mk_local_ex(vcx.alloc_str(&nam)) - }) - }, - - vir::ExprGenData::Let(vir::LetGenData{name, val, expr}) => { + let nam = rename + .get(d.name) + .map(|e| e.as_str()) + .unwrap_or(d.name) + .to_owned(); + vir::with_vcx(move |vcx| vcx.mk_local_ex(vcx.alloc_str(&nam))) + } + vir::ExprGenData::Let(vir::LetGenData { name, val, expr }) => { let val = opt(val, rename); match val { // let name = loc.name vir::ExprGenData::Local(loc) => { - let t = rename.get(loc.name).map(|e| e.to_owned()).unwrap_or(loc.name.to_string()); + let t = rename + .get(loc.name) + .map(|e| e.to_owned()) + .unwrap_or(loc.name.to_string()); assert!(rename.insert(name.to_string(), t).is_none()); - return opt(expr, rename) + return opt(expr, rename); } _ => {} } @@ -67,82 +70,96 @@ fn opt<'vir, Cur, Next> (expr: vir::ExprGen<'vir, Cur, Next>, rename: &mut HashM match expr { vir::ExprGenData::Local(inner_local) => { if &inner_local.name == name { - return val + return val; } } _ => {} } vir::with_vcx(move |vcx| { - vcx.alloc(vir::ExprGenData::Let(vcx.alloc(vir::LetGenData{name, val, expr}))) + vcx.alloc(vir::ExprGenData::Let(vcx.alloc(vir::LetGenData { + name, + val, + expr, + }))) }) - }, - vir::ExprGenData::FuncApp( vir::FuncAppGenData{target, args}) => { + } + vir::ExprGenData::FuncApp(vir::FuncAppGenData { target, args }) => { let n_args = args.iter().map(|arg| opt(arg, rename)).collect::>(); - vir::with_vcx(move |vcx| { - vcx.mk_func_app(target, &n_args) - }) + vir::with_vcx(move |vcx| vcx.mk_func_app(target, &n_args)) } - - vir::ExprGenData::PredicateApp( vir::PredicateAppGenData{target, args}) => { + vir::ExprGenData::PredicateApp(vir::PredicateAppGenData { target, args }) => { let n_args = args.iter().map(|arg| opt(arg, rename)).collect::>(); vir::with_vcx(move |vcx| { - vcx.alloc(vir::ExprGenData::PredicateApp(vcx.alloc(vir::PredicateAppGenData { - target, - args: vcx.alloc_slice(&n_args), - }))) + vcx.alloc(vir::ExprGenData::PredicateApp(vcx.alloc( + vir::PredicateAppGenData { + target, + args: vcx.alloc_slice(&n_args), + }, + ))) }) } - - vir::ExprGenData::Forall(vir::ForallGenData{qvars, triggers, body}) => { + + vir::ExprGenData::Forall(vir::ForallGenData { + qvars, + triggers, + body, + }) => { let body = opt(body, rename); vir::with_vcx(move |vcx| { - vcx.alloc(vir::ExprGenData::Forall(vcx.alloc(vir::ForallGenData{qvars, triggers, body}))) + vcx.alloc(vir::ExprGenData::Forall(vcx.alloc(vir::ForallGenData { + qvars, + triggers, + body, + }))) }) } - - vir::ExprGenData::Ternary(vir::TernaryGenData{cond, then, else_}) => { + vir::ExprGenData::Ternary(vir::TernaryGenData { cond, then, else_ }) => { let cond = opt(cond, rename); let then = opt(then, rename); let else_ = opt(else_, rename); vir::with_vcx(move |vcx| { - vcx.alloc(vir::ExprGenData::Ternary(vcx.alloc(vir::TernaryGenData{cond, then, else_}))) + vcx.alloc(vir::ExprGenData::Ternary(vcx.alloc(vir::TernaryGenData { + cond, + then, + else_, + }))) }) } - vir::ExprGenData::BinOp(vir::BinOpGenData {kind, lhs, rhs}) => { + vir::ExprGenData::BinOp(vir::BinOpGenData { kind, lhs, rhs }) => { let lhs = opt(lhs, rename); let rhs = opt(rhs, rename); vir::with_vcx(move |vcx| { - vcx.alloc(vir::ExprGenData::BinOp(vcx.alloc(vir::BinOpGenData{kind: kind.clone(), lhs, rhs}))) + vcx.alloc(vir::ExprGenData::BinOp(vcx.alloc(vir::BinOpGenData { + kind: kind.clone(), + lhs, + rhs, + }))) }) } - - vir::ExprGenData::UnOp(vir::UnOpGenData {kind, expr}) =>{ + vir::ExprGenData::UnOp(vir::UnOpGenData { kind, expr }) => { let expr = opt(expr, rename); vir::with_vcx(move |vcx| { - vcx.alloc(vir::ExprGenData::UnOp(vcx.alloc(vir::UnOpGenData{kind: kind.clone(), expr}))) + vcx.alloc(vir::ExprGenData::UnOp(vcx.alloc(vir::UnOpGenData { + kind: kind.clone(), + expr, + }))) }) } - todo@( - vir::ExprGenData::Unfolding(_) | - vir::ExprGenData::Field(_, _) | - vir::ExprGenData::Old(_) | - vir::ExprGenData::AccField(_)| - vir::ExprGenData::Lazy(_, _)) - => todo, - - other @ ( - vir::ExprGenData::Const(_) | - vir::ExprGenData::Todo(_) ) => other - + todo @ (vir::ExprGenData::Unfolding(_) + | vir::ExprGenData::Field(_, _) + | vir::ExprGenData::Old(_) + | vir::ExprGenData::AccField(_) + | vir::ExprGenData::Lazy(_, _)) => todo, + other @ (vir::ExprGenData::Const(_) | vir::ExprGenData::Todo(_)) => other, } } @@ -156,9 +173,9 @@ pub struct MirPureEncoderTask<'tcx> { // TODO: depth of encoding should be in the lazy context rather than here; // can we integrate the lazy context into the identifier system? pub encoding_depth: usize, - pub parent_def_id: DefId, // ID of the function - pub promoted: Option, // ID of a constant within the function - pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site + pub parent_def_id: DefId, // ID of the function + pub promoted: Option, // ID of a constant within the function + pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site pub substs: ty::GenericArgsRef<'tcx>, // type substitutions at the usage site } @@ -166,8 +183,8 @@ impl TaskEncoder for MirPureEncoder { type TaskDescription<'vir> = MirPureEncoderTask<'vir>; type TaskKey<'vir> = ( - usize, // encoding depth - DefId, // ID of the function + usize, // encoding depth + DefId, // ID of the function Option, // ID of a constant within the function, or `None` if encoding the function itself ty::GenericArgsRef<'vir>, // ? this should be the "signature", after applying the env/substs ); @@ -177,7 +194,8 @@ impl TaskEncoder for MirPureEncoder { type EncodingError = MirPureEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirPureEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, MirPureEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -201,13 +219,16 @@ impl TaskEncoder for MirPureEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { deps.emit_output_ref::(*task_key, ()); let def_id = task_key.1; //.parent_def_id; @@ -216,7 +237,10 @@ impl TaskEncoder for MirPureEncoder { tracing::debug!("encoding {def_id:?}"); let expr = vir::with_vcx(move |vcx| { //let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); let expr_inner = Encoder::new(vcx, task_key.0, &body, deps).encode_body(); @@ -242,8 +266,7 @@ impl TaskEncoder for MirPureEncoder { let opted = opt(expr_inner, &mut rename); tracing::warn!("after opt {opted:?}"); opted - } - else { + } else { expr_inner }; @@ -276,8 +299,16 @@ impl<'vir> Update<'vir> { fn merge(self, newer: Self) -> Self { Self { - binds: self.binds.into_iter().chain(newer.binds.into_iter()).collect(), - versions: self.versions.into_iter().chain(newer.versions.into_iter()).collect(), + binds: self + .binds + .into_iter() + .chain(newer.binds.into_iter()) + .collect(), + versions: self + .versions + .into_iter() + .chain(newer.versions.into_iter()) + .collect(), } } @@ -289,7 +320,8 @@ impl<'vir> Update<'vir> { } struct Encoder<'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { vcx: &'vir vir::VirCtxt<'vir>, encoding_depth: usize, @@ -301,7 +333,8 @@ struct Encoder<'vir, 'enc> } impl<'vir, 'enc> Encoder<'vir, 'enc> - where 'vir: 'enc +where + 'vir: 'enc, { fn new( vcx: &'vir vir::VirCtxt<'vir>, @@ -309,7 +342,10 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> body: &'enc mir::Body<'vir>, deps: &'enc mut TaskEncoderDependencies<'vir>, ) -> Self { - assert!(!body.basic_blocks.is_cfg_cyclic(), "MIR pure encoding does not support loops"); + assert!( + !body.basic_blocks.is_cfg_cyclic(), + "MIR pure encoding does not support loops" + ); Self { vcx, @@ -317,31 +353,28 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> body, deps, visited: Default::default(), - version_ctr: (0..body.local_decls.len()).map(|local| (local.into(), 0)).collect(), + version_ctr: (0..body.local_decls.len()) + .map(|local| (local.into(), 0)) + .collect(), phi_ctr: 0, } } - fn mk_local( - &self, - local: mir::Local, - version: usize, - ) -> &'vir str { - vir::vir_format!(self.vcx, "_{}_{}s_{}", self.encoding_depth, local.as_usize(), version) + fn mk_local(&self, local: mir::Local, version: usize) -> &'vir str { + vir::vir_format!( + self.vcx, + "_{}_{}s_{}", + self.encoding_depth, + local.as_usize(), + version + ) } - fn mk_local_ex( - &self, - local: mir::Local, - version: usize, - ) -> ExprRet<'vir> { + fn mk_local_ex(&self, local: mir::Local, version: usize) -> ExprRet<'vir> { self.vcx.mk_local_ex(self.mk_local(local, version)) } - fn mk_phi( - &self, - idx: usize, - ) -> &'vir str { + fn mk_phi(&self, idx: usize) -> &'vir str { vir::vir_format!(self.vcx, "_{}_phi_{}", self.encoding_depth, idx) } @@ -354,36 +387,34 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> tuple_ref.mk_elem(self.vcx, self.vcx.mk_local_ex(self.mk_phi(idx)), elem_idx) } - fn bump_version( - &mut self, - update: &mut Update<'vir>, - local: mir::Local, - expr: ExprRet<'vir>, - ) { + fn bump_version(&mut self, update: &mut Update<'vir>, local: mir::Local, expr: ExprRet<'vir>) { let new_version = self.version_ctr.get(&local).copied().unwrap_or(0usize); self.version_ctr.insert(local, new_version + 1); - update.binds.push(UpdateBind::Local(local, new_version, expr)); + update + .binds + .push(UpdateBind::Local(local, new_version, expr)); update.versions.insert(local, new_version); } - fn reify_binds( - &self, - update: Update<'vir>, - expr: ExprRet<'vir>, - ) -> ExprRet<'vir> { - update.binds.iter() - .rfold(expr, |expr, bind| match bind { - UpdateBind::Local(local, ver, val) => self.vcx.alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { - name: self.mk_local(*local, *ver), - val, - expr, - }))), - UpdateBind::Phi(idx, val) => self.vcx.alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { - name: self.mk_phi(*idx), - val, - expr, - }))), - }) + fn reify_binds(&self, update: Update<'vir>, expr: ExprRet<'vir>) -> ExprRet<'vir> { + update.binds.iter().rfold(expr, |expr, bind| match bind { + UpdateBind::Local(local, ver, val) => { + self.vcx + .alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { + name: self.mk_local(*local, *ver), + val, + expr, + }))) + } + UpdateBind::Phi(idx, val) => { + self.vcx + .alloc(ExprRetData::Let(self.vcx.alloc(vir::LetGenData { + name: self.mk_phi(*idx), + val, + expr, + }))) + } + }) } fn reify_branch( @@ -393,30 +424,40 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> curr_ver: &HashMap, update: Update<'vir>, ) -> ExprRet<'vir> { - let tuple_args = mod_locals.iter().map(|local| self.mk_local_ex( - *local, - update.versions.get(local).copied().unwrap_or_else(|| { - // TODO: remove (debug) - if !curr_ver.contains_key(&local) { - tracing::error!("unknown version of local! {}", local.as_usize()); - return 0xff - } - curr_ver[local] - }), - )).collect::>(); - self.reify_binds( - update, - tuple_ref.mk_cons(self.vcx, &tuple_args), - ) + let tuple_args = mod_locals + .iter() + .map(|local| { + self.mk_local_ex( + *local, + update.versions.get(local).copied().unwrap_or_else(|| { + // TODO: remove (debug) + if !curr_ver.contains_key(&local) { + tracing::error!("unknown version of local! {}", local.as_usize()); + return 0xff; + } + curr_ver[local] + }), + ) + }) + .collect::>(); + self.reify_binds(update, tuple_ref.mk_cons(self.vcx, &tuple_args)) } fn encode_body(&mut self) -> ExprRet<'vir> { - let end_blocks = self.body.basic_blocks.reverse_postorder() + let end_blocks = self + .body + .basic_blocks + .reverse_postorder() .iter() - .filter(|bb| matches!( - self.body[**bb].terminator, - Some(mir::Terminator { kind: mir::TerminatorKind::Return, .. }), - )) + .filter(|bb| { + matches!( + self.body[**bb].terminator, + Some(mir::Terminator { + kind: mir::TerminatorKind::Return, + .. + }), + ) + }) .collect::>(); assert!(end_blocks.len() > 0, "no Return block found"); assert!(end_blocks.len() < 2, "multiple Return blocks found"); @@ -429,19 +470,16 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> vir::vir_format!(self.vcx, "pure in _{local}"), Box::new(move |_vcx, lctx: ExprInput<'vir>| lctx.1[local - 1]), )); - init.binds.push(UpdateBind::Local(local.into(), 0, local_ex)); + init.binds + .push(UpdateBind::Local(local.into(), 0, local_ex)); init.versions.insert(local.into(), 0); } - let update = self.encode_cfg( - &init.versions, - mir::START_BLOCK, - *end_block, - ); + let update = self.encode_cfg(&init.versions, mir::START_BLOCK, *end_block); let res = init.merge(update); let ret_version = res.versions.get(&mir::RETURN_PLACE).copied().unwrap_or(0); - + self.reify_binds(res, self.mk_local_ex(mir::RETURN_PLACE, ret_version)) } @@ -451,7 +489,8 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> start: mir::BasicBlock, end: mir::BasicBlock, ) -> mir::BasicBlock { - dominators.dominators(end) + dominators + .dominators(end) .take_while(|bb| *bb != start) .last() .unwrap() @@ -467,7 +506,9 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> // walk block statements first let mut new_curr_ver = curr_ver.clone(); - let stmt_update = self.body[start].statements.iter() + let stmt_update = self.body[start] + .statements + .iter() .fold(Update::new(), |update, stmt| { let newer = self.encode_stmt(&new_curr_ver, stmt); newer.add_to_map(&mut new_curr_ver); @@ -490,22 +531,26 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> mir::TerminatorKind::SwitchInt { discr, targets } => { // encode the discriminant operand let discr_expr = self.encode_operand(&new_curr_ver, discr); - let discr_ty_out = self.deps.require_ref::( - discr.ty(self.body, self.vcx.tcx), - ).unwrap(); + let discr_ty_out = self + .deps + .require_ref::(discr.ty(self.body, self.vcx.tcx)) + .unwrap(); // find earliest join point `join` let join = self.find_join_point(dominators, start, end); // walk `start` -> `targets[i]` -> `join` for each target // TODO: indexvec? - let mut updates = targets.all_targets().iter() + let mut updates = targets + .all_targets() + .iter() .map(|target| self.encode_cfg(&new_curr_ver, *target, join)) .collect::>(); // find locals updated in any of the results, which were also // defined before the branch - let mut mod_locals = updates.iter() + let mut mod_locals = updates + .iter() .map(|update| update.versions.keys()) .flatten() .filter(|local| new_curr_ver.contains_key(&local)) @@ -515,24 +560,33 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> mod_locals.dedup(); // for each branch, create a Viper tuple of the updated locals - let tuple_ref = self.deps.require_ref::( - mod_locals.len(), - ).unwrap(); + let tuple_ref = self + .deps + .require_ref::(mod_locals.len()) + .unwrap(); let otherwise_update = updates.pop().unwrap(); - let phi_expr = targets.iter() - .zip(updates.into_iter()) - .fold( - self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, otherwise_update), - |expr, ((cond_val, target), branch_update)| self.vcx.alloc(ExprRetData::Ternary(self.vcx.alloc(vir::TernaryGenData { - cond: self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: vir::BinOpKind::CmpEq, - lhs: discr_expr, - rhs: discr_ty_out.expr_from_u128(cond_val).lift(), - }))), - then: self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, branch_update), - else_: expr, - }))), - ); + let phi_expr = targets.iter().zip(updates.into_iter()).fold( + self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, otherwise_update), + |expr, ((cond_val, target), branch_update)| { + self.vcx + .alloc(ExprRetData::Ternary(self.vcx.alloc(vir::TernaryGenData { + cond: self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc( + vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs: discr_expr, + rhs: discr_ty_out.expr_from_u128(cond_val).lift(), + }, + ))), + then: self.reify_branch( + &tuple_ref, + &mod_locals, + &new_curr_ver, + branch_update, + ), + else_: expr, + }))) + }, + ); // assign tuple into a `phi` variable let phi_idx = self.phi_ctr; @@ -578,84 +632,101 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> TyKind::FnDef(def_id, arg_tys) => { // TODO: this attribute extraction should be done elsewhere? let attrs = self.vcx.tcx.get_attrs_unchecked(*def_id); - let normal_attrs = attrs.iter() + let normal_attrs = attrs + .iter() .filter(|attr| !attr.is_doc_comment()) - .map(|attr| attr.get_normal_item()).collect::>(); - normal_attrs.iter() - .filter(|item| item.path.segments.len() == 2 - && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "builtin") + .map(|attr| attr.get_normal_item()) + .collect::>(); + normal_attrs + .iter() + .filter(|item| { + item.path.segments.len() == 2 + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "builtin" + }) .for_each(|attr| match &attr.args { prusti_rustc_interface::ast::AttrArgs::Eq( _, prusti_rustc_interface::ast::AttrArgsEq::Hir(lit), ) => { assert!(builtin.is_none(), "multiple prusti::builtin"); - builtin = Some((match lit.symbol.as_str() { - "forall" => PrustiBuiltin::Forall, - _ => panic!("illegal prusti::builtin"), - }, arg_tys)); + builtin = Some(( + match lit.symbol.as_str() { + "forall" => PrustiBuiltin::Forall, + _ => panic!("illegal prusti::builtin"), + }, + arg_tys, + )); } _ => panic!("illegal prusti::builtin"), }); - - let is_pure = normal_attrs.iter().any(|item| + let is_pure = normal_attrs.iter().any(|item| { item.path.segments.len() == 2 - && item.path.segments[0].ident.as_str() == "prusti" - && item.path.segments[1].ident.as_str() == "pure" - ); + && item.path.segments[0].ident.as_str() == "prusti" + && item.path.segments[1].ident.as_str() == "pure" + }); - // TODO: detect snapshot_equality properly - let is_snapshot_eq = self.vcx.tcx.opt_item_name(*def_id).map(|e| e.as_str() == "snapshot_equality") == Some(true) + // TODO: detect snapshot_equality properly + let is_snapshot_eq = self + .vcx + .tcx + .opt_item_name(*def_id) + .map(|e| e.as_str() == "snapshot_equality") + == Some(true) && self.vcx.tcx.crate_name(def_id.krate).as_str() == "prusti_contracts"; let func_call = if is_pure { assert!(builtin.is_none(), "Function is pure and builtin?"); - let pure_func = self.deps.require_ref::(*def_id).unwrap().function_name; - - let encoded_args = args.iter() + let pure_func = self + .deps + .require_ref::(*def_id) + .unwrap() + .function_name; + + let encoded_args = args + .iter() .map(|oper| self.encode_operand(&new_curr_ver, oper)) .collect::>(); - - let func_call = self.vcx.mk_func_app(pure_func, &encoded_args); - Some(func_call) - } + let func_call = self.vcx.mk_func_app(pure_func, &encoded_args); - else if is_snapshot_eq { - assert!(builtin.is_none(), "Function is snapshot_equality and builtin?"); - let encoded_args = args.iter() + Some(func_call) + } else if is_snapshot_eq { + assert!( + builtin.is_none(), + "Function is snapshot_equality and builtin?" + ); + let encoded_args = args + .iter() .map(|oper| self.encode_operand(&new_curr_ver, oper)) .collect::>(); assert_eq!(encoded_args.len(), 2); - - let eq_expr = self.vcx.alloc(vir::ExprGenData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: vir::BinOpKind::CmpEq, - lhs: encoded_args[0], - rhs: encoded_args[1], - }))); - + let eq_expr = self.vcx.alloc(vir::ExprGenData::BinOp(self.vcx.alloc( + vir::BinOpGenData { + kind: vir::BinOpKind::CmpEq, + lhs: encoded_args[0], + rhs: encoded_args[1], + }, + ))); // TODO: type encoder Some(self.vcx.mk_func_app("s_Bool_cons", &[eq_expr])) - } - else { + } else { None }; - if let Some(func_call) = func_call { let mut term_update = Update::new(); assert!(destination.projection.is_empty()); self.bump_version(&mut term_update, destination.local, func_call); term_update.add_to_map(&mut new_curr_ver); - + // walk rest of CFG let end_update = self.encode_cfg(&new_curr_ver, target.unwrap(), end); - + return stmt_update.merge(term_update).merge(end_update); } } @@ -677,18 +748,26 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> _ => panic!("illegal prusti::forall"), }; - let qvars = self.vcx.alloc_slice(&qvar_tys.iter() - .enumerate() - .map(|(idx, qvar_ty)| { - let ty_out = self.deps.require_ref::( - qvar_ty, - ).unwrap(); - self.vcx.mk_local_decl( - vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), - ty_out.snapshot, - ) - }) - .collect::>()); + let qvars = self.vcx.alloc_slice( + &qvar_tys + .iter() + .enumerate() + .map(|(idx, qvar_ty)| { + let ty_out = self + .deps + .require_ref::(qvar_ty) + .unwrap(); + self.vcx.mk_local_decl( + vir::vir_format!( + self.vcx, + "qvar_{}_{idx}", + self.encoding_depth + ), + ty_out.snapshot, + ) + }) + .collect::>(), + ); //let qvar_tuple_ref = self.deps.require_ref::( // qvars.len(), //).unwrap(); @@ -708,49 +787,49 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> reify_args.push(unsafe { std::mem::transmute(self.encode_operand(&new_curr_ver, &args[1])) }); - reify_args.extend((0..qvars.len()) - .map(|idx| self.vcx.mk_local_ex( - vir::vir_format!(self.vcx, "qvar_{}_{idx}", self.encoding_depth), - ))); + reify_args.extend((0..qvars.len()).map(|idx| { + self.vcx.mk_local_ex(vir::vir_format!( + self.vcx, + "qvar_{}_{idx}", + self.encoding_depth + )) + })); // TODO: recursively invoke MirPure encoder to encode // the body of the closure; pass the closure as the // variable to use, then closure access = tuple access // (then hope to optimise this away later ...?) use vir::Reify; - let body = self.deps.require_local::( - MirPureEncoderTask { + let body = self + .deps + .require_local::(MirPureEncoderTask { encoding_depth: self.encoding_depth + 1, parent_def_id: cl_def_id, promoted: None, param_env: self.vcx.tcx.param_env(cl_def_id), substs: ty::List::identity_for_item(self.vcx.tcx, cl_def_id), - } - ).unwrap().expr - // arguments to the closure are - // - the closure itself - // - the qvars - .reify(self.vcx, ( - cl_def_id, - self.vcx.alloc_slice(&reify_args), - )) + }) + .unwrap() + .expr + // arguments to the closure are + // - the closure itself + // - the qvars + .reify(self.vcx, (cl_def_id, self.vcx.alloc_slice(&reify_args))) .lift(); // TODO: use type encoder - let body = - self.vcx.mk_func_app( - "s_Bool_val", - &[body], - ); + let body = self.vcx.mk_func_app("s_Bool_val", &[body]); // TODO: use type encoder let forall = self.vcx.mk_func_app( "s_Bool_cons", - &[self.vcx.alloc(ExprRetData::Forall(self.vcx.alloc(vir::ForallGenData { - qvars, - triggers: &[], // TODO - body, - })))], + &[self.vcx.alloc(ExprRetData::Forall(self.vcx.alloc( + vir::ForallGenData { + qvars, + triggers: &[], // TODO + body, + }, + )))], ); let mut term_update = Update::new(); @@ -784,11 +863,11 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> let new_version = self.version_ctr.get(local).copied().unwrap_or(0usize); self.version_ctr.insert(*local, new_version + 1); update.versions.insert(*local, new_version); - }, + } mir::StatementKind::StorageDead(..) | mir::StatementKind::FakeRead(..) - | mir::StatementKind::AscribeUserType(..) - | mir::StatementKind::PlaceMention(..)=> {}, // nop + | mir::StatementKind::AscribeUserType(..) + | mir::StatementKind::PlaceMention(..) => {} // nop mir::StatementKind::Assign(box (dest, rvalue)) => { assert!(dest.projection.is_empty()); let expr = self.encode_rvalue(curr_ver, rvalue); @@ -818,77 +897,105 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> // Len // Cast mir::Rvalue::BinaryOp(op, box (l, r)) => { - let ty_l = self.deps.require_ref::( - l.ty(self.body, self.vcx.tcx), - ).unwrap().to_primitive.unwrap(); - let ty_r = self.deps.require_ref::( - r.ty(self.body, self.vcx.tcx), - ).unwrap().to_primitive.unwrap(); - let ty_rvalue = self.deps.require_ref::( - rvalue.ty(self.body, self.vcx.tcx), - ).unwrap().from_primitive.unwrap(); + let ty_l = self + .deps + .require_ref::(l.ty(self.body, self.vcx.tcx)) + .unwrap() + .to_primitive + .unwrap(); + let ty_r = self + .deps + .require_ref::(r.ty(self.body, self.vcx.tcx)) + .unwrap() + .to_primitive + .unwrap(); + let ty_rvalue = self + .deps + .require_ref::(rvalue.ty(self.body, self.vcx.tcx)) + .unwrap() + .from_primitive + .unwrap(); self.vcx.mk_func_app( ty_rvalue, - &[self.vcx.alloc(ExprRetData::BinOp(self.vcx.alloc(vir::BinOpGenData { - kind: op.into(), - lhs: self.vcx.mk_func_app( - ty_l, - &[self.encode_operand(curr_ver, l)], - ), - rhs: self.vcx.mk_func_app( - ty_r, - &[self.encode_operand(curr_ver, r)], - ), - })))], + &[self.vcx.alloc(ExprRetData::BinOp( + self.vcx.alloc(vir::BinOpGenData { + kind: op.into(), + lhs: self + .vcx + .mk_func_app(ty_l, &[self.encode_operand(curr_ver, l)]), + rhs: self + .vcx + .mk_func_app(ty_r, &[self.encode_operand(curr_ver, r)]), + }), + ))], ) } // CheckedBinaryOp // NullaryOp mir::Rvalue::UnaryOp(op, expr) => { - let ty_expr = self.deps.require_ref::( - expr.ty(self.body, self.vcx.tcx), - ).unwrap().to_primitive.unwrap(); - let ty_rvalue = self.deps.require_ref::( - rvalue.ty(self.body, self.vcx.tcx), - ).unwrap().from_primitive.unwrap(); + let ty_expr = self + .deps + .require_ref::(expr.ty(self.body, self.vcx.tcx)) + .unwrap() + .to_primitive + .unwrap(); + let ty_rvalue = self + .deps + .require_ref::(rvalue.ty(self.body, self.vcx.tcx)) + .unwrap() + .from_primitive + .unwrap(); self.vcx.mk_func_app( ty_rvalue, - &[self.vcx.alloc(ExprRetData::UnOp(self.vcx.alloc(vir::UnOpGenData { - kind: op.into(), - expr: self.vcx.mk_func_app( - ty_expr, - &[self.encode_operand(curr_ver, expr)], - ), - })))], + &[self.vcx.alloc(ExprRetData::UnOp( + self.vcx.alloc(vir::UnOpGenData { + kind: op.into(), + expr: self + .vcx + .mk_func_app(ty_expr, &[self.encode_operand(curr_ver, expr)]), + }), + ))], ) } // Discriminant mir::Rvalue::Aggregate(box kind, fields) => match kind { mir::AggregateKind::Tuple if fields.len() == 0 => - // TODO: why is this not a constant? - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), + // TODO: why is this not a constant? + { + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Tuple0_cons()" + ))) + } mir::AggregateKind::Closure(..) => { // TODO: only when this is a spec closure? - let tuple_ref = self.deps.require_ref::( - fields.len(), - ).unwrap(); - tuple_ref.mk_cons(self.vcx, &fields.iter() - .map(|field| self.encode_operand(curr_ver, field)) - .collect::>()) + let tuple_ref = self + .deps + .require_ref::(fields.len()) + .unwrap(); + tuple_ref.mk_cons( + self.vcx, + &fields + .iter() + .map(|field| self.encode_operand(curr_ver, field)) + .collect::>(), + ) } _ => todo!("Unsupported Rvalue::AggregateKind: {kind:?}"), - } + }, mir::Rvalue::CheckedBinaryOp(binop, box (l, r)) => { - let binop_function = self.deps.require_ref::( - crate::encoders::MirBuiltinEncoderTask::CheckedBinOp( - *binop, - l.ty(self.body, self.vcx.tcx), // TODO: ? - ), - ).unwrap().name; + let binop_function = self + .deps + .require_ref::( + crate::encoders::MirBuiltinEncoderTask::CheckedBinOp( + *binop, + l.ty(self.body, self.vcx.tcx), // TODO: ? + ), + ) + .unwrap() + .name; self.vcx.mk_func_app( binop_function, &[ @@ -912,35 +1019,46 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> operand: &mir::Operand<'vir>, ) -> ExprRet<'vir> { match operand { - mir::Operand::Copy(place) - | mir::Operand::Move(place) => self.encode_place(curr_ver, place), + mir::Operand::Copy(place) | mir::Operand::Move(place) => { + self.encode_place(curr_ver, place) + } mir::Operand::Constant(box constant) => { // TODO: duplicated from mir_impure! match constant.literal { - mir::ConstantKind::Val(const_val, const_ty) => { - match const_ty.kind() { - ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Tuple0_cons()"), - )), - ty::TyKind::Int(int_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Int_{}_cons({})", int_ty.name_str(), scalar_val.try_to_int(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Uint(uint_ty) => { - let scalar_val = const_val.try_to_scalar_int().unwrap(); - self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Uint_{}_cons({})", uint_ty.name_str(), scalar_val.try_to_uint(scalar_val.size()).unwrap()), - )) - } - ty::TyKind::Bool => self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "s_Bool_cons({})", const_val.try_to_bool().unwrap()), - )), - unsupported_ty => todo!("unsupported constant literal type: {unsupported_ty:?}"), + mir::ConstantKind::Val(const_val, const_ty) => match const_ty.kind() { + ty::TyKind::Tuple(tys) if tys.len() == 0 => self.vcx.alloc( + ExprRetData::Todo(vir::vir_format!(self.vcx, "s_Tuple0_cons()")), + ), + ty::TyKind::Int(int_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Int_{}_cons({})", + int_ty.name_str(), + scalar_val.try_to_int(scalar_val.size()).unwrap() + ))) } + ty::TyKind::Uint(uint_ty) => { + let scalar_val = const_val.try_to_scalar_int().unwrap(); + self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Uint_{}_cons({})", + uint_ty.name_str(), + scalar_val.try_to_uint(scalar_val.size()).unwrap() + ))) + } + ty::TyKind::Bool => self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "s_Bool_cons({})", + const_val.try_to_bool().unwrap() + ))), + unsupported_ty => { + todo!("unsupported constant literal type: {unsupported_ty:?}") + } + }, + unsupported_literal => { + todo!("unsupported constant literal: {unsupported_literal:?}") } - unsupported_literal => todo!("unsupported constant literal: {unsupported_literal:?}"), } } } @@ -954,13 +1072,15 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> // TODO: remove (debug) if !curr_ver.contains_key(&place.local) { tracing::error!("unknown version of local! {}", place.local.as_usize()); - return self.vcx.alloc(ExprRetData::Todo( - vir::vir_format!(self.vcx, "unknown_version_{}", place.local.as_usize()), - )); + return self.vcx.alloc(ExprRetData::Todo(vir::vir_format!( + self.vcx, + "unknown_version_{}", + place.local.as_usize() + ))); } let local = self.mk_local_ex(place.local, curr_ver[&place.local]); - let mut partent_ty = self.body.local_decls[place.local].ty; + let mut partent_ty = self.body.local_decls[place.local].ty; let mut expr = local; for elem in place.projection { @@ -970,37 +1090,38 @@ impl<'vir, 'enc> Encoder<'vir, 'enc> expr } - - fn encode_place_element(&mut self, parent_ty: ty::Ty<'vir>, elem: mir::PlaceElem<'vir>, expr: ExprRet<'vir>) -> (ty::Ty<'vir>, ExprRet<'vir>) { + fn encode_place_element( + &mut self, + parent_ty: ty::Ty<'vir>, + elem: mir::PlaceElem<'vir>, + expr: ExprRet<'vir>, + ) -> (ty::Ty<'vir>, ExprRet<'vir>) { let parent_ty = parent_ty.peel_refs(); - match elem { - mir::ProjectionElem::Deref => { - (parent_ty, expr) - } + match elem { + mir::ProjectionElem::Deref => (parent_ty, expr), mir::ProjectionElem::Field(field_idx, field_ty) => { - let field_idx= field_idx.as_usize(); + let field_idx = field_idx.as_usize(); match parent_ty.kind() { TyKind::Closure(_def_id, args) => { let upvars = args.as_closure().upvar_tys().collect::>().len(); - let tuple_ref = self.deps.require_ref::( - upvars, - ).unwrap(); - let tup = tuple_ref.mk_elem(self.vcx, expr, field_idx); + let tuple_ref = self.deps.require_ref::(upvars).unwrap(); + let tup = tuple_ref.mk_elem(self.vcx, expr, field_idx); (field_ty, tup) } _ => { - let local_encoded_ty = self.deps.require_ref::(parent_ty).unwrap(); + let local_encoded_ty = + self.deps.require_ref::(parent_ty).unwrap(); let struct_like = local_encoded_ty.expect_structlike(); let proj = struct_like.field_read[field_idx]; - + let app = self.vcx.mk_func_app(proj, self.vcx.alloc_slice(&[expr])); (field_ty, app) } } - } + } _ => todo!("Unsupported ProjectionElem {:?}", elem), } } diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index d977550fe2b..e73c72894d7 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -1,8 +1,11 @@ -use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; +use prusti_rustc_interface::{ + middle::{mir, ty}, + span::def_id::DefId, +}; +use std::cell::RefCell; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use vir::Reify; -use std::cell::RefCell; use crate::encoders::{ MirPureEncoder, MirPureEncoderTask, SpecEncoder, SpecEncoderTask, TypeEncoder, @@ -76,11 +79,14 @@ impl TaskEncoder for MirFunctionEncoder { deps.emit_output_ref::(*task_key, MirFunctionEncoderOutputRef { function_name }); let local_def_id = def_id.expect_local(); - let body = vcx.body.borrow_mut().get_impure_fn_body_identity(local_def_id); - - let spec = deps.require_local::( - (def_id, true) - ).unwrap(); + let body = vcx + .body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id); + + let spec = deps + .require_local::((def_id, true)) + .unwrap(); let mut func_args = Vec::with_capacity(body.arg_count); diff --git a/prusti-encoder/src/encoders/mod.rs b/prusti-encoder/src/encoders/mod.rs index 0acdcb605aa..6afd9a4fc81 100644 --- a/prusti-encoder/src/encoders/mod.rs +++ b/prusti-encoder/src/encoders/mod.rs @@ -9,33 +9,13 @@ mod mir_pure_function; pub mod pure; pub mod local_def; -pub use generic::{ - GenericEncoder, -}; -pub use mir_builtin::{ - MirBuiltinEncoder, - MirBuiltinEncoderTask, -}; +pub use generic::GenericEncoder; +pub use mir_builtin::{MirBuiltinEncoder, MirBuiltinEncoderTask}; pub use mir_impure::MirImpureEncoder; -pub use mir_pure::{ - MirPureEncoder, - MirPureEncoderTask, -}; -pub use spec::{ - SpecEncoder, - SpecEncoderOutput, - SpecEncoderTask, -}; +pub use mir_pure::{MirPureEncoder, MirPureEncoderTask}; pub(super) use spec::{init_def_spec, with_def_spec}; -pub use typ::{ - TypeEncoder, - TypeEncoderOutputRef, - TypeEncoderOutput, -}; -pub use viper_tuple::{ - ViperTupleEncoder, - ViperTupleEncoderOutputRef, - ViperTupleEncoderOutput, -}; +pub use spec::{SpecEncoder, SpecEncoderOutput, SpecEncoderTask}; +pub use typ::{TypeEncoder, TypeEncoderOutput, TypeEncoderOutputRef}; +pub use viper_tuple::{ViperTupleEncoder, ViperTupleEncoderOutput, ViperTupleEncoderOutputRef}; pub use mir_pure_function::MirFunctionEncoder; diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 8c3a88a77d1..e59500b51cf 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -1,8 +1,11 @@ -use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; +use prusti_rustc_interface::{ + middle::{mir, ty}, + span::def_id::DefId, +}; +use std::cell::RefCell; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use vir::Reify; -use std::cell::RefCell; use crate::encoders::MirPureEncoder; pub struct MirSpecEncoder; @@ -20,8 +23,8 @@ thread_local! { impl TaskEncoder for MirSpecEncoder { type TaskDescription<'vir> = ( - DefId, // The function annotated with specs - bool, // If to encode as pure or not + DefId, // The function annotated with specs + bool, // If to encode as pure or not ); type OutputFullLocal<'vir> = MirSpecEncoderOutput<'vir>; @@ -61,21 +64,21 @@ impl TaskEncoder for MirSpecEncoder { let (def_id, pure) = *task_key; deps.emit_output_ref::(*task_key, ()); - let local_defs = deps.require_local::( - def_id, - ).unwrap(); - let specs = deps.require_local::( - crate::encoders::SpecEncoderTask { + let local_defs = deps + .require_local::(def_id) + .unwrap(); + let specs = deps + .require_local::(crate::encoders::SpecEncoderTask { def_id, - } - ).unwrap(); + }) + .unwrap(); vir::with_vcx(|vcx| { let mut pre_args: Vec<_> = (1..=local_defs.arg_count) .map(mir::Local::from) .map(|local| { if pure { - local_defs.locals[local].local_ex + local_defs.locals[local].local_ex } else { local_defs.locals[local].impure_snap } @@ -90,55 +93,69 @@ impl TaskEncoder for MirSpecEncoder { let pre_args = vcx.alloc_slice(&pre_args); - let to_bool = deps.require_ref::( - vcx.tcx.types.bool, - ).unwrap().to_primitive.unwrap(); - - let pres = specs.pres.iter().map(|spec_def_id| { - let expr = deps.require_local::( - crate::encoders::MirPureEncoderTask { - encoding_depth: 0, - parent_def_id: *spec_def_id, - promoted: None, - param_env: vcx.tcx.param_env(spec_def_id), - substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), - } - ).unwrap().expr; - let expr = vcx.mk_func_app( - to_bool, - &[expr], - ); - expr.reify(vcx, (*spec_def_id, &pre_args.split_last().unwrap().1)) - }).collect::>>(); + let to_bool = deps + .require_ref::(vcx.tcx.types.bool) + .unwrap() + .to_primitive + .unwrap(); + + let pres = specs + .pres + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: *spec_def_id, + promoted: None, + param_env: vcx.tcx.param_env(spec_def_id), + substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), + }, + ) + .unwrap() + .expr; + let expr = vcx.mk_func_app(to_bool, &[expr]); + expr.reify(vcx, (*spec_def_id, &pre_args.split_last().unwrap().1)) + }) + .collect::>>(); let post_args = if pure { pre_args } else { - let post_args: Vec<_> = pre_args.iter().enumerate().map(|(idx, arg)| { - if idx == pre_args.len() - 1 { - arg - } else { - vcx.alloc(vir::ExprData::Old(arg)) - } - }).collect(); + let post_args: Vec<_> = pre_args + .iter() + .enumerate() + .map(|(idx, arg)| { + if idx == pre_args.len() - 1 { + arg + } else { + vcx.alloc(vir::ExprData::Old(arg)) + } + }) + .collect(); vcx.alloc_slice(&post_args) }; - let posts = specs.posts.iter().map(|spec_def_id| { - let expr = deps.require_local::( - crate::encoders::MirPureEncoderTask { - encoding_depth: 0, - parent_def_id: *spec_def_id, - promoted: None, - param_env: vcx.tcx.param_env(spec_def_id), - substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), - } - ).unwrap().expr; - let expr = vcx.mk_func_app( - to_bool, - &[expr], - ); - expr.reify(vcx, (*spec_def_id, post_args)) - }).collect::>>(); + let posts = specs + .posts + .iter() + .map(|spec_def_id| { + let expr = deps + .require_local::( + crate::encoders::MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: *spec_def_id, + promoted: None, + param_env: vcx.tcx.param_env(spec_def_id), + substs: ty::List::identity_for_item(vcx.tcx, *spec_def_id), + }, + ) + .unwrap() + .expr; + let expr = vcx.mk_func_app(to_bool, &[expr]); + expr.reify(vcx, (*spec_def_id, post_args)) + }) + .collect::>>(); let data = MirSpecEncoderOutput { pres, posts, diff --git a/prusti-encoder/src/encoders/spec.rs b/prusti-encoder/src/encoders/spec.rs index c6da14c0db3..a36f8237284 100644 --- a/prusti-encoder/src/encoders/spec.rs +++ b/prusti-encoder/src/encoders/spec.rs @@ -1,12 +1,6 @@ -use prusti_rustc_interface::{ - //middle::{mir, ty}, - span::def_id::DefId, -}; use prusti_interface::specs::typed::DefSpecificationMap; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use prusti_rustc_interface::span::def_id::DefId; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct SpecEncoder; @@ -42,7 +36,7 @@ pub fn init_def_spec(def_spec: DefSpecificationMap) { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct SpecEncoderTask { pub def_id: DefId, // ID of the function - // TODO: substs here? + // TODO: substs here? } impl TaskEncoder for SpecEncoder { @@ -57,7 +51,8 @@ impl TaskEncoder for SpecEncoder { type EncodingError = SpecEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, SpecEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, SpecEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -78,13 +73,16 @@ impl TaskEncoder for SpecEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { deps.emit_output_ref::(task_key.clone(), ()); vir::with_vcx(|vcx| { with_def_spec(|def_spec| { @@ -98,7 +96,7 @@ impl TaskEncoder for SpecEncoder { .and_then(|specs| specs.base_spec.posts.expect_empty_or_inherent()) .map(|specs| vcx.alloc_slice(specs)) .unwrap_or_default(); - Ok((SpecEncoderOutput { pres, posts, }, () )) + Ok((SpecEncoderOutput { pres, posts }, ())) }) }) } diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 48b8a2ba5bd..99423544bf4 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -1,9 +1,6 @@ use prusti_rustc_interface::middle::ty; use rustc_type_ir::sty::TyKind; -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct TypeEncoder; @@ -54,18 +51,22 @@ impl<'vir> TypeEncoderOutputRef<'vir> { // TODO: not great: store the TyKind as well? // or should this be a different task for TypeEncoder? match self.snapshot_name { - "s_Bool" => vir::with_vcx(|vcx| vcx.mk_func_app( + "s_Bool" => vir::with_vcx(|vcx| { + vcx.mk_func_app( self.from_primitive.unwrap(), &[vcx.alloc(vir::ExprData::Const( vcx.alloc(vir::ConstData::Bool(val != 0)), ))], - )), - name if name.starts_with("s_Int_") || name.starts_with("s_Uint_") => vir::with_vcx(|vcx| vcx.mk_func_app( - self.from_primitive.unwrap(), - &[vcx.alloc(vir::ExprData::Const( - vcx.alloc(vir::ConstData::Int(val)), - ))], - )), + ) + }), + name if name.starts_with("s_Int_") || name.starts_with("s_Uint_") => { + vir::with_vcx(|vcx| { + vcx.mk_func_app( + self.from_primitive.unwrap(), + &[vcx.alloc(vir::ExprData::Const(vcx.alloc(vir::ConstData::Int(val))))], + ) + }) + } k => todo!("unsupported type in expr_from_u128 {k:?}"), } } @@ -99,7 +100,8 @@ impl TaskEncoder for TypeEncoder { type EncodingError = TypeEncoderError; fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, TypeEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, TypeEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -146,13 +148,16 @@ impl TaskEncoder for TypeEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { fn mk_unreachable<'vir>( vcx: &'vir vir::VirCtxt, snapshot_name: &'vir str, @@ -173,8 +178,8 @@ impl TaskEncoder for TypeEncoder { field_name: &'vir str, ) -> vir::Predicate<'vir> { let predicate_body = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldData { - recv: vcx.mk_local_ex("self_p"), - field: field_name, + recv: vcx.mk_local_ex("self_p"), + field: field_name, }))); vir::vir_predicate! { vcx; predicate [predicate_name](self_p: Ref) { [predicate_body] } } } @@ -269,27 +274,20 @@ impl TaskEncoder for TypeEncoder { ) -> vir::Function<'vir> { let pred_app = vcx.alloc(vir::PredicateAppData { target: predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("self"), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_ex("self")]), }); vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{predicate_name}_snap"), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), ret: snapshot_ty, - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(pred_app)), - ]), + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: field_name.map(|field_name| vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.alloc(vir::ExprData::Field( - vcx.mk_local_ex("self"), - field_name, - )), - })))), + expr: field_name.map(|field_name| { + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.alloc(vir::ExprData::Field(vcx.mk_local_ex("self"), field_name)), + }))) + }), }) } fn mk_structlike<'vir>( @@ -299,10 +297,13 @@ impl TaskEncoder for TypeEncoder { name_s: &'vir str, name_p: &'vir str, field_ty_out: Vec>, - ) -> Result<::OutputFullLocal<'vir>, ( - ::EncodingError, - Option<::OutputFullDependency<'vir>>, - )> { + ) -> Result< + ::OutputFullLocal<'vir>, + ( + ::EncodingError, + Option<::OutputFullDependency<'vir>>, + ), + > { let mut field_read_names = Vec::new(); let mut field_write_names = Vec::new(); let mut field_projection_p_names = Vec::new(); @@ -315,50 +316,58 @@ impl TaskEncoder for TypeEncoder { let field_write_names = vcx.alloc_slice(&field_write_names); let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: name_s, - predicate_name: name_p, - from_primitive: None, - to_primitive: None, - snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), - function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), - function_snap: vir::vir_format!(vcx, "{name_p}_snap"), - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { - field_read: field_read_names, - field_write: field_write_names, - field_projection_p: field_projection_p_names, - }), - method_assign: vir::vir_format!(vcx, "assign_{name_p}"), - }); + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + from_primitive: None, + to_primitive: None, + snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { + field_read: field_read_names, + field_write: field_write_names, + field_projection_p: field_projection_p_names, + }), + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }, + ); let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); let mut funcs: Vec> = vec![]; let mut axioms: Vec> = vec![]; let cons_name = vir::vir_format!(vcx, "{name_s}_cons"); - funcs.push(vcx.alloc(vir::DomainFunctionData { - unique: false, - name: cons_name, - args: vcx.alloc_slice(&field_ty_out.iter() - .map(|field_ty_out| field_ty_out.snapshot) - .collect::>()), - ret: ty_s, - })); + funcs.push( + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: cons_name, + args: vcx.alloc_slice( + &field_ty_out + .iter() + .map(|field_ty_out| field_ty_out.snapshot) + .collect::>(), + ), + ret: ty_s, + }), + ); let mut field_projection_p = Vec::new(); for (idx, ty_out) in field_ty_out.iter().enumerate() { let name_r = vir::vir_format!(vcx, "{name_s}_read_{idx}"); - funcs.push(vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }); + funcs.push( + vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }, + ); let name_w = vir::vir_format!(vcx, "{name_s}_write_{idx}"); funcs.push(vir::vir_domain_func! { vcx; function [name_w]([ty_s], [ty_out.snapshot]): [ty_s] }); field_projection_p.push(vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_field_{idx}"), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), ret: &vir::TypeData::Ref, pres: &[], posts: &[], @@ -370,35 +379,30 @@ impl TaskEncoder for TypeEncoder { for (write_idx, write_ty_out) in field_ty_out.iter().enumerate() { for (read_idx, _read_ty_out) in field_ty_out.iter().enumerate() { axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!(vcx, "ax_{name_s}_write_{write_idx}_read_{read_idx}"), + name: vir::vir_format!( + vcx, + "ax_{name_s}_write_{write_idx}_read_{read_idx}" + ), expr: if read_idx == write_idx { vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: vcx.alloc_slice(&[ vcx.mk_local_decl("self", ty_s), vcx.mk_local_decl("val", write_ty_out.snapshot), ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], - )], - ), - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), &[vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], )], ), rhs: vcx.mk_local_ex("val"), @@ -410,28 +414,20 @@ impl TaskEncoder for TypeEncoder { vcx.mk_local_decl("self", ty_s), vcx.mk_local_decl("val", write_ty_out.snapshot), ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], - )], - ), - ])]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), &[vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[ - vcx.mk_local_ex("self"), - vcx.mk_local_ex("val"), - ], + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], )], ), rhs: vcx.mk_func_app( @@ -448,18 +444,21 @@ impl TaskEncoder for TypeEncoder { // constructor { let cons_qvars = vcx.alloc_slice( - &field_ty_out.iter() + &field_ty_out + .iter() .enumerate() - .map(|(idx, field_ty_out)| vcx.mk_local_decl( - vir::vir_format!(vcx, "f{idx}"), - field_ty_out.snapshot, - )) - .collect::>()); - let cons_args = field_ty_out.iter() + .map(|(idx, field_ty_out)| { + vcx.mk_local_decl( + vir::vir_format!(vcx, "f{idx}"), + field_ty_out.snapshot, + ) + }) + .collect::>(), + ); + let cons_args = field_ty_out + .iter() .enumerate() - .map(|(idx, _field_ty_out)| vcx.mk_local_ex( - vir::vir_format!(vcx, "f{idx}"), - )) + .map(|(idx, _field_ty_out)| vcx.mk_local_ex(vir::vir_format!(vcx, "f{idx}"))) .collect::>(); let cons_call = vcx.mk_func_app(cons_name, &cons_args); @@ -486,21 +485,19 @@ impl TaskEncoder for TypeEncoder { &field_ty_out .iter() .enumerate() - .map(|(idx, _field_ty_out)| vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{idx}"), - &[vcx.mk_local_ex("self")], - )) + .map(|(idx, _field_ty_out)| { + vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{idx}"), + &[vcx.mk_local_ex("self")], + ) + }) .collect::>(), ); axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[ - cons_call_with_reads, - ])]), + qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]), body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { kind: vir::BinOpKind::CmpEq, lhs: cons_call_with_reads, @@ -512,28 +509,31 @@ impl TaskEncoder for TypeEncoder { // predicate let predicate = { - let expr = field_ty_out.iter() + let expr = field_ty_out + .iter() .enumerate() - .map(|(idx, field_ty_out)| vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { - target: field_ty_out.predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_p}_field_{idx}"), - &[vcx.mk_local_ex("self_p")], - ), - ]), - })))) - .reduce(|base, field_expr| vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::And, - lhs: base, - rhs: field_expr, - })))) + .map(|(idx, field_ty_out)| { + vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc( + vir::PredicateAppData { + target: field_ty_out.predicate_name, + args: vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_p}_field_{idx}"), + &[vcx.mk_local_ex("self_p")], + )]), + }, + ))) + }) + .reduce(|base, field_expr| { + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::And, + lhs: base, + rhs: field_expr, + }))) + }) .unwrap_or_else(|| vcx.mk_true()); vcx.alloc(vir::PredicateData { name: name_p, - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self_p", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), expr: Some(expr), }) }; @@ -549,40 +549,42 @@ impl TaskEncoder for TypeEncoder { function_snap: { let pred_app = vcx.alloc(vir::PredicateAppData { target: name_p, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("self_p"), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), }); vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_snap"), - args: vcx.alloc_slice(&[ - vcx.mk_local_decl("self_p", &vir::TypeData::Ref), - ]), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), ret: ty_s, - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(pred_app)), - ]), + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: Some(vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.mk_func_app( - cons_name, - vcx.alloc_slice(&field_ty_out - .iter() - .enumerate() - .map(|(idx, field_ty_out)| vcx.mk_func_app( - field_ty_out.function_snap, - &[ - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_p}_field_{idx}"), - &[vcx.mk_local_ex("self_p")], - ), - ], - )) - .collect::>(), - ), - ), - })))), + expr: Some( + vcx.alloc(vir::ExprData::Unfolding( + vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.mk_func_app( + cons_name, + vcx.alloc_slice( + &field_ty_out + .iter() + .enumerate() + .map(|(idx, field_ty_out)| { + vcx.mk_func_app( + field_ty_out.function_snap, + &[vcx.mk_func_app( + vir::vir_format!( + vcx, + "{name_p}_field_{idx}" + ), + &[vcx.mk_local_ex("self_p")], + )], + ) + }) + .collect::>(), + ), + ), + }), + )), + ), }) }, //method_refold: mk_refold(vcx, name_p, ty_s), @@ -594,39 +596,44 @@ impl TaskEncoder for TypeEncoder { vir::with_vcx(|vcx| match task_key.kind() { TyKind::Bool => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Bool")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: "s_Bool", - predicate_name: "p_Bool", - to_primitive: Some("s_Bool_val"), - from_primitive: Some("s_Bool_cons"), - snapshot: ty_s, - function_unreachable: "s_Bool_unreachable", - function_snap: "p_Bool_snap", - //method_refold: "refold_p_Bool", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: "assign_p_Bool", - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: "f_Bool", - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Bool { - function s_Bool_cons(Bool): s_Bool; - function s_Bool_val(s_Bool): Bool; - axiom_inverse(s_Bool_val, s_Bool_cons, Bool); - axiom_inverse(s_Bool_cons, s_Bool_val, s_Bool); - } }, - predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), - function_unreachable: mk_unreachable(vcx, "s_Bool", ty_s), - function_snap: mk_snap(vcx, "p_Bool", "s_Bool", Some("f_Bool"), ty_s), - //method_refold: mk_refold(vcx, "p_Bool", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Bool", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: "s_Bool", + predicate_name: "p_Bool", + to_primitive: Some("s_Bool_val"), + from_primitive: Some("s_Bool_cons"), + snapshot: ty_s, + function_unreachable: "s_Bool_unreachable", + function_snap: "p_Bool_snap", + //method_refold: "refold_p_Bool", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: "assign_p_Bool", + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: "f_Bool", + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Bool { + function s_Bool_cons(Bool): s_Bool; + function s_Bool_val(s_Bool): Bool; + axiom_inverse(s_Bool_val, s_Bool_cons, Bool); + axiom_inverse(s_Bool_cons, s_Bool_val, s_Bool); + } }, + predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), + function_unreachable: mk_unreachable(vcx, "s_Bool", ty_s), + function_snap: mk_snap(vcx, "p_Bool", "s_Bool", Some("f_Bool"), ty_s), + //method_refold: mk_refold(vcx, "p_Bool", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Bool", ty_s), + }, + (), + )) } - TyKind::Int(_) | - TyKind::Uint(_) => { + TyKind::Int(_) | TyKind::Uint(_) => { let (sign, name_str) = match task_key.kind() { TyKind::Int(kind) => ("Int", kind.name_str()), TyKind::Uint(kind) => ("Uint", kind.name_str()), @@ -638,89 +645,110 @@ impl TaskEncoder for TypeEncoder { let name_val = vir::vir_format!(vcx, "{name_s}_val"); let name_field = vir::vir_format!(vcx, "f_{name_s}"); let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: name_s, - predicate_name: name_p, - to_primitive: Some(name_val), - from_primitive: Some(name_cons), - snapshot: ty_s, - function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), - function_snap: vir::vir_format!(vcx, "{name_p}_snap"), - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_{name_p}"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: name_field, - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain [name_s] { - function [name_cons](Int): [ty_s]; - function [name_val]([ty_s]): Int; - axiom_inverse([name_val], [name_cons], Int); - axiom_inverse([name_cons], [name_val], [ty_s]); - } }, - predicate: mk_simple_predicate(vcx, name_p, name_field), - function_unreachable: mk_unreachable(vcx, name_s, ty_s), - function_snap: mk_snap(vcx, name_p, name_s, Some(name_field), ty_s), - //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, name_p, ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + to_primitive: Some(name_val), + from_primitive: Some(name_cons), + snapshot: ty_s, + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: name_field, + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain [name_s] { + function [name_cons](Int): [ty_s]; + function [name_val]([ty_s]): Int; + axiom_inverse([name_val], [name_cons], Int); + axiom_inverse([name_cons], [name_val], [ty_s]); + } }, + predicate: mk_simple_predicate(vcx, name_p, name_field), + function_unreachable: mk_unreachable(vcx, name_s, ty_s), + function_snap: mk_snap(vcx, name_p, name_s, Some(name_field), ty_s), + //method_refold: mk_refold(vcx, name_p, ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, name_p, ty_s), + }, + (), + )) } TyKind::Tuple(tys) if tys.len() == 0 => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Tuple0")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: "s_Tuple0", - predicate_name: "p_Tuple0", - to_primitive: None, - from_primitive: None, - snapshot: ty_s, - function_unreachable: "s_Tuple0_unreachable", - function_snap: "p_Tuple0_snap", - //method_refold: "refold_p_Tuple0", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Tuple0"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: vir::vir_format!(vcx, "f_Tuple0"), - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Tuple0 { - function s_Tuple0_cons(): [ty_s]; - } }, - predicate: vir::vir_predicate! { vcx; predicate p_Tuple0(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "s_Tuple0", ty_s), - function_snap: mk_snap(vcx, "p_Tuple0", "s_Tuple0", None, ty_s), - //method_refold: mk_refold(vcx, "p_Tuple0", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Tuple0", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: "s_Tuple0", + predicate_name: "p_Tuple0", + to_primitive: None, + from_primitive: None, + snapshot: ty_s, + function_unreachable: "s_Tuple0_unreachable", + function_snap: "p_Tuple0_snap", + //method_refold: "refold_p_Tuple0", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_p_Tuple0"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: vir::vir_format!(vcx, "f_Tuple0"), + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Tuple0 { + function s_Tuple0_cons(): [ty_s]; + } }, + predicate: vir::vir_predicate! { vcx; predicate p_Tuple0(self_p: Ref) }, + function_unreachable: mk_unreachable(vcx, "s_Tuple0", ty_s), + function_snap: mk_snap(vcx, "p_Tuple0", "s_Tuple0", None, ty_s), + //method_refold: mk_refold(vcx, "p_Tuple0", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Tuple0", ty_s), + }, + (), + )) } TyKind::Tuple(tys) => { let field_ty_out = tys .iter() - .map(|ty| deps.require_ref::(ty).unwrap()) + .map(|ty| { + deps.require_ref::(ty) + .unwrap() + }) .collect::>(); // TODO: Properly name the tuple according to its types, or make generic? // Temporary fix to make it possisble to have multiple tuples of the same size with different element types - let tmp_ty_name = field_ty_out.iter().map(|e| e.snapshot_name).collect::>().join("_"); - - - Ok((mk_structlike( - vcx, - deps, - task_key, - vir::vir_format!(vcx, "s_Tuple{}_{}", tys.len(), tmp_ty_name), - vir::vir_format!(vcx, "p_Tuple{}_{}", tys.len(), tmp_ty_name), - field_ty_out, - )?, ())) + let tmp_ty_name = field_ty_out + .iter() + .map(|e| e.snapshot_name) + .collect::>() + .join("_"); + + Ok(( + mk_structlike( + vcx, + deps, + task_key, + vir::vir_format!(vcx, "s_Tuple{}_{}", tys.len(), tmp_ty_name), + vir::vir_format!(vcx, "p_Tuple{}_{}", tys.len(), tmp_ty_name), + field_ty_out, + )?, + (), + )) /* let ty_len = tys.len(); @@ -757,76 +785,97 @@ impl TaskEncoder for TypeEncoder { } TyKind::Param(_param) => { - let param_out = deps.require_ref::(()).unwrap(); + let param_out = deps + .require_ref::(()) + .unwrap(); let ty_s = vcx.alloc(vir::TypeData::Domain("s_Param")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: param_out.snapshot_param_name, - predicate_name: param_out.predicate_param_name, - to_primitive: None, - from_primitive: None, - snapshot: ty_s, - function_unreachable: "s_Param_unreachable", - function_snap: "p_Param_snap", - //method_refold: "refold_p_Param", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Bool"), - }); - Ok((TypeEncoderOutput { - fields: &[], - snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors - } }, - predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "p_Param", ty_s), - function_snap: mk_snap(vcx, "p_Param", "s_Param", None, ty_s), - //method_refold: mk_refold(vcx, "p_Param", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Param", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: param_out.snapshot_param_name, + predicate_name: param_out.predicate_param_name, + to_primitive: None, + from_primitive: None, + snapshot: ty_s, + function_unreachable: "s_Param_unreachable", + function_snap: "p_Param_snap", + //method_refold: "refold_p_Param", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_p_Bool"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: &[], + snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors + } }, + predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, + function_unreachable: mk_unreachable(vcx, "p_Param", ty_s), + function_snap: mk_snap(vcx, "p_Param", "s_Param", None, ty_s), + //method_refold: mk_refold(vcx, "p_Param", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Param", ty_s), + }, + (), + )) } TyKind::Adt(adt_def, substs) if adt_def.is_struct() => { tracing::debug!("encoding ADT {adt_def:?} with substs {substs:?}"); let substs = ty::List::identity_for_item(vcx.tcx, adt_def.did()); - let field_ty_out = adt_def.all_fields() - .map(|field| deps.require_ref::(field.ty(vcx.tcx, substs)).unwrap()) + let field_ty_out = adt_def + .all_fields() + .map(|field| { + deps.require_ref::(field.ty(vcx.tcx, substs)) + .unwrap() + }) .collect::>(); let did_name = vcx.tcx.item_name(adt_def.did()).to_ident_string(); - Ok((mk_structlike( - vcx, - deps, - task_key, - vir::vir_format!(vcx, "s_Adt_{did_name}"), - vir::vir_format!(vcx, "p_Adt_{did_name}"), - field_ty_out, - )?, ())) + Ok(( + mk_structlike( + vcx, + deps, + task_key, + vir::vir_format!(vcx, "s_Adt_{did_name}"), + vir::vir_format!(vcx, "p_Adt_{did_name}"), + field_ty_out, + )?, + (), + )) } TyKind::Never => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Never")); - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: "s_Never", - predicate_name: "p_Never", - to_primitive: None, - from_primitive: None, - snapshot: ty_s, - function_unreachable: "s_Never_unreachable", - function_snap: "p_Never_snap", - //method_refold: "refold_p_Never", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Never"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: vir::vir_format!(vcx, "f_Never"), - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Never {} }, - predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "s_Never", ty_s), - function_snap: mk_snap(vcx, "p_Never", "s_Never", None, ty_s), - //method_refold: mk_refold(vcx, "p_Never", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Never", ty_s), - }, ())) + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: "s_Never", + predicate_name: "p_Never", + to_primitive: None, + from_primitive: None, + snapshot: ty_s, + function_unreachable: "s_Never_unreachable", + function_snap: "p_Never_snap", + //method_refold: "refold_p_Never", + specifics: TypeEncoderOutputRefSub::Primitive, + method_assign: vir::vir_format!(vcx, "assign_p_Never"), + }, + ); + Ok(( + TypeEncoderOutput { + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + name: vir::vir_format!(vcx, "f_Never"), + ty: ty_s, + })]), + snapshot: vir::vir_domain! { vcx; domain s_Never {} }, + predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, + function_unreachable: mk_unreachable(vcx, "s_Never", ty_s), + function_snap: mk_snap(vcx, "p_Never", "s_Never", None, ty_s), + //method_refold: mk_refold(vcx, "p_Never", ty_s), + field_projection_p: &[], + method_assign: mk_assign(vcx, "p_Never", ty_s), + }, + (), + )) } //_ => Err((TypeEncoderError::UnsupportedType, None)), unsupported_type => todo!("type not supported: {unsupported_type:?}"), diff --git a/prusti-encoder/src/encoders/viper_tuple.rs b/prusti-encoder/src/encoders/viper_tuple.rs index 708665d88a1..174c07a7ade 100644 --- a/prusti-encoder/src/encoders/viper_tuple.rs +++ b/prusti-encoder/src/encoders/viper_tuple.rs @@ -1,8 +1,5 @@ -use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, -}; use std::cell::RefCell; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; pub struct ViperTupleEncoder; @@ -19,7 +16,7 @@ impl<'vir> ViperTupleEncoderOutputRef<'vir> { pub fn mk_cons( &self, vcx: &'vir vir::VirCtxt<'vir>, - elems: &[vir::ExprGen<'vir, Curr, Next>] + elems: &[vir::ExprGen<'vir, Curr, Next>], ) -> vir::ExprGen<'vir, Curr, Next> { if self.elem_count == 1 { return elems[0]; @@ -58,7 +55,8 @@ impl TaskEncoder for ViperTupleEncoder { type EncodingError = (); fn with_cache<'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir task_encoder::CacheRef<'vir, ViperTupleEncoder>) -> R, + where + F: FnOnce(&'vir task_encoder::CacheRef<'vir, ViperTupleEncoder>) -> R, { CACHE.with(|cache| { // SAFETY: the 'vir and 'tcx given to this function will always be @@ -76,69 +74,89 @@ impl TaskEncoder for ViperTupleEncoder { fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + > { vir::with_vcx(|vcx| { let domain_name = vir::vir_format!(vcx, "Tuple_{task_key}"); let cons_name = vir::vir_format!(vcx, "Tuple_{task_key}_cons"); let elem_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "Tuple_{task_key}_elem_{idx}")) .collect::>(); - deps.emit_output_ref::(*task_key, ViperTupleEncoderOutputRef { - elem_count: *task_key, - domain_name, - cons_name, - elem_names: vcx.alloc_slice(&elem_names), - }); + deps.emit_output_ref::( + *task_key, + ViperTupleEncoderOutputRef { + elem_count: *task_key, + domain_name, + cons_name, + elem_names: vcx.alloc_slice(&elem_names), + }, + ); let typaram_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "T{idx}")) .collect::>(); - let typaram_tys = vcx.alloc_slice(&typaram_names.iter() - .map(|name| vcx.alloc(vir::TypeData::Domain(name))) - .collect::>()); + let typaram_tys = vcx.alloc_slice( + &typaram_names + .iter() + .map(|name| vcx.alloc(vir::TypeData::Domain(name))) + .collect::>(), + ); let domain_ty = vcx.alloc(vir::TypeData::DomainParams(domain_name, typaram_tys)); let qvars_names = (0..*task_key) .map(|idx| vir::vir_format!(vcx, "elem{idx}")) .collect::>(); - let qvars_decl = vcx.alloc_slice(&(0..*task_key) - .map(|idx| vcx.mk_local_decl(qvars_names[idx], typaram_tys[idx])) - .collect::>()); + let qvars_decl = vcx.alloc_slice( + &(0..*task_key) + .map(|idx| vcx.mk_local_decl(qvars_names[idx], typaram_tys[idx])) + .collect::>(), + ); let qvars_ex = (0..*task_key) .map(|idx| vcx.mk_local_ex(qvars_names[idx])) .collect::>(); let cons_call = vcx.mk_func_app( cons_name, - &qvars_names.iter() + &qvars_names + .iter() .map(|qvar| vcx.mk_local_ex(qvar)) .collect::>(), ); let axiom = vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_Tuple_{task_key}_elem"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: qvars_decl, - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), - body: vcx.mk_conj(&(0..*task_key) - .map(|idx| vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app(elem_names[idx], &[cons_call]), - rhs: qvars_ex[idx], - })))) - .collect::>()), - }))), + expr: vcx.alloc(vir::ExprData::Forall( + vcx.alloc(vir::ForallData { + qvars: qvars_decl, + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + body: vcx.mk_conj( + &(0..*task_key) + .map(|idx| { + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: vcx.mk_func_app(elem_names[idx], &[cons_call]), + rhs: qvars_ex[idx], + }))) + }) + .collect::>(), + ), + }), + )), }); let elem_args = vcx.alloc_slice(&[domain_ty]); let mut functions = (0..*task_key) - .map(|idx| vcx.alloc(vir::DomainFunctionData { - unique: false, - name: elem_names[idx], - args: elem_args, - ret: typaram_tys[idx], - })) + .map(|idx| { + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: elem_names[idx], + args: elem_args, + ret: typaram_tys[idx], + }) + }) .collect::>(); functions.push(vcx.alloc(vir::DomainFunctionData { unique: false, @@ -146,14 +164,21 @@ impl TaskEncoder for ViperTupleEncoder { args: typaram_tys, ret: domain_ty, })); - Ok((ViperTupleEncoderOutput { - domain: Some(vcx.alloc(vir::DomainData { - name: domain_name, - typarams: vcx.alloc_slice(&typaram_names), - axioms: if task_key == &0 {&[]} else {vcx.alloc_slice(&[axiom])}, - functions: vcx.alloc_slice(&functions), - })), - }, ())) + Ok(( + ViperTupleEncoderOutput { + domain: Some(vcx.alloc(vir::DomainData { + name: domain_name, + typarams: vcx.alloc_slice(&typaram_names), + axioms: if task_key == &0 { + &[] + } else { + vcx.alloc_slice(&[axiom]) + }, + functions: vcx.alloc_slice(&functions), + })), + }, + (), + )) }) } } diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index 6ca82e1d8a9..a91c3cf160c 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -10,10 +10,7 @@ extern crate rustc_type_ir; mod encoders; use prusti_interface::{environment::EnvBody, specs::typed::SpecificationItem}; -use prusti_rustc_interface::{ - middle::ty, - hir, -}; +use prusti_rustc_interface::{hir, middle::ty}; /* struct MirBodyPureEncoder; @@ -100,7 +97,7 @@ impl<'vir, 'tcx> TaskEncoder<'vir, 'tcx> for MirBodyImpureEncoder<'vir, 'tcx> { ); // TaskKey, OutputRef same as above type OutputFull = vir::Method<'vir>; -} +} struct MirTyEncoder<'vir, 'tcx>(PhantomData<&'vir ()>, PhantomData<&'tcx ()>); impl<'vir, 'tcx> TaskEncoder<'vir, 'tcx> for MirTyEncoder<'vir, 'tcx> { @@ -147,27 +144,29 @@ pub fn test_entrypoint<'tcx>( .get_proc_spec(&def_id.to_def_id()) .map(|e| &e.base_spec); - let is_pure = base_spec.and_then(|kind| kind.kind.is_pure().ok()).unwrap_or_default(); - let is_trusted = matches!(base_spec.map(|spec| spec.trusted), Some(SpecificationItem::Inherent( - true, - ))); - (is_pure, is_trusted) - } - ); - - if ! (is_trusted && is_pure) { + let is_pure = base_spec + .and_then(|kind| kind.kind.is_pure().ok()) + .unwrap_or_default(); + let is_trusted = matches!( + base_spec.map(|spec| spec.trusted), + Some(SpecificationItem::Inherent(true,)) + ); + (is_pure, is_trusted) + }); + + if !(is_trusted && is_pure) { let res = crate::encoders::MirImpureEncoder::encode(def_id.to_def_id()); assert!(res.is_ok()); } - if is_pure { - tracing::debug!("Encoding {def_id:?} as a pure function because it is labeled as pure"); + tracing::debug!( + "Encoding {def_id:?} as a pure function because it is labeled as pure" + ); let res = crate::encoders::MirFunctionEncoder::encode(def_id.to_def_id()); assert!(res.is_ok()); } - /* match res { Ok(res) => println!("ok: {:?}", res), @@ -235,20 +234,20 @@ pub fn test_entrypoint<'tcx>( std::fs::write("local-testing/simple.vpr", viper_code).unwrap(); - vir::with_vcx(|vcx| vcx.alloc(vir::ProgramData { - fields: &[], - domains: &[], - predicates: &[], - functions: vcx.alloc_slice(&[ - vcx.alloc(vir::FunctionData { + vir::with_vcx(|vcx| { + vcx.alloc(vir::ProgramData { + fields: &[], + domains: &[], + predicates: &[], + functions: vcx.alloc_slice(&[vcx.alloc(vir::FunctionData { name: "test_function", args: &[], ret: &vir::TypeData::Bool, pres: &[], posts: &[], expr: None, - }), - ]), - methods: &[], - })) + })]), + methods: &[], + }) + }) } diff --git a/prusti-interface/src/environment/body.rs b/prusti-interface/src/environment/body.rs index 90593363e48..84873731aae 100644 --- a/prusti-interface/src/environment/body.rs +++ b/prusti-interface/src/environment/body.rs @@ -136,7 +136,7 @@ impl<'tcx> EnvBody<'tcx> { BodyWithBorrowckFacts { body: MirBody(Rc::new(body_with_facts.body)), - // borrowck_facts: Rc::new(facts), + // borrowck_facts: Rc::new(facts), } } @@ -174,12 +174,18 @@ impl<'tcx> EnvBody<'tcx> { { let monomorphised = if let Some(caller_def_id) = caller_def_id { let param_env = self.tcx.param_env(caller_def_id); - self.tcx - .subst_and_normalize_erasing_regions(substs, param_env, ty::EarlyBinder::bind(body.0)) + self.tcx.subst_and_normalize_erasing_regions( + substs, + param_env, + ty::EarlyBinder::bind(body.0), + ) } else { let param_env = self.tcx.param_env(def_id); - self.tcx - .subst_and_normalize_erasing_regions(substs, param_env, ty::EarlyBinder::bind(body.0)) + self.tcx.subst_and_normalize_erasing_regions( + substs, + param_env, + ty::EarlyBinder::bind(body.0), + ) }; v.insert(MirBody(monomorphised)).clone() } else { @@ -201,7 +207,11 @@ impl<'tcx> EnvBody<'tcx> { /// with the given type substitutions. /// /// FIXME: This function is called only in pure contexts??? - pub fn get_impure_fn_body(&self, def_id: LocalDefId, substs: GenericArgsRef<'tcx>) -> MirBody<'tcx> { + pub fn get_impure_fn_body( + &self, + def_id: LocalDefId, + substs: GenericArgsRef<'tcx>, + ) -> MirBody<'tcx> { if let Some(body) = self.get_monomorphised(def_id.to_def_id(), substs, None) { return body; } @@ -328,7 +338,7 @@ impl<'tcx> EnvBody<'tcx> { pub(crate) fn load_pure_fn_body(&mut self, def_id: LocalDefId) { assert!(!self.pure_fns.local.contains_key(&def_id)); - let body = Self::load_local_mir( self.tcx, def_id); + let body = Self::load_local_mir(self.tcx, def_id); self.pure_fns.local.insert(def_id, body); let bwbf = Self::load_local_mir_with_facts(self.tcx, def_id); // Also add to `impure_fns` since we'll also be encoding this as impure diff --git a/prusti-interface/src/environment/query.rs b/prusti-interface/src/environment/query.rs index 77c0319a7f2..174eb9843a9 100644 --- a/prusti-interface/src/environment/query.rs +++ b/prusti-interface/src/environment/query.rs @@ -7,7 +7,10 @@ use prusti_rustc_interface::{ hir::hir_id::HirId, middle::{ hir::map::Map, - ty::{self, Binder, BoundConstness, GenericArgsRef, ImplPolarity, ParamEnv, TraitPredicate, TyCtxt}, + ty::{ + self, Binder, BoundConstness, GenericArgsRef, ImplPolarity, ParamEnv, TraitPredicate, + TyCtxt, + }, }, span::{ def_id::{DefId, LocalDefId}, @@ -206,66 +209,66 @@ impl<'tcx> EnvQuery<'tcx> { impl_method_substs: GenericArgsRef<'tcx>, // what are the substs on the call? ) -> Option<(ProcedureDefId, GenericArgsRef<'tcx>)> { todo!() /* - let impl_method_def_id = impl_method_def_id.into_param(); - let impl_def_id = self.tcx.impl_of_method(impl_method_def_id)?; - let trait_ref = self.tcx.impl_trait_ref(impl_def_id)?.skip_binder(); - - // At this point, we know that the given method: - // - belongs to an impl block and - // - the impl block implements a trait. - // For the `get_assoc_item` call, we therefore `unwrap`, as not finding - // the associated item would be a (compiler) internal error. - let trait_def_id = trait_ref.def_id; - let trait_method_def_id = self - .get_assoc_item(trait_def_id, impl_method_def_id) - .unwrap() - .def_id; - - // sanity check: have we been given the correct number of substs? - let identity_impl_method = self.identity_substs(impl_method_def_id); - assert_eq!(identity_impl_method.len(), impl_method_substs.len()); - - // Given: - // ``` - // trait Trait { - // fn f(); - // } - // struct Struct { ... } - // impl Trait for Struct { - // fn f() { ... } - // } - // ``` - // - // The various substs look like this: - // - identity for Trait: `[Self, Tp]` - // - identity for Trait::f: `[Self, Tp, Tx, Ty, Tz]` - // - substs of the impl trait ref: `[Struct, A]` - // - identity for the impl: `[A, B, C]` - // - identity for Struct::f: `[A, B, C, X, Y, Z]` - // - // What we need is a substs suitable for a call to Trait::f, which is in - // this case `[Struct, A, X, Y, Z]`. More generally, it is the - // concatenation of the trait ref substs with the identity of the impl - // method after skipping the identity of the impl. - // - // We also need to subst the prefix (`[Struct, A]` in the example - // above) with call substs, so that we get the trait's type parameters - // more precisely. We can do this directly with `impl_method_substs` - // because they contain the substs for the `impl` block as a prefix. - let call_trait_substs = - ty::EarlyBinder(trait_ref.substs).subst(self.tcx, impl_method_substs); - let impl_substs = self.identity_substs(impl_def_id); - let trait_method_substs = self.tcx.mk_substs_from_iter( - call_trait_substs - .iter() - .chain(impl_method_substs.iter().skip(impl_substs.len())), - ); + let impl_method_def_id = impl_method_def_id.into_param(); + let impl_def_id = self.tcx.impl_of_method(impl_method_def_id)?; + let trait_ref = self.tcx.impl_trait_ref(impl_def_id)?.skip_binder(); + + // At this point, we know that the given method: + // - belongs to an impl block and + // - the impl block implements a trait. + // For the `get_assoc_item` call, we therefore `unwrap`, as not finding + // the associated item would be a (compiler) internal error. + let trait_def_id = trait_ref.def_id; + let trait_method_def_id = self + .get_assoc_item(trait_def_id, impl_method_def_id) + .unwrap() + .def_id; + + // sanity check: have we been given the correct number of substs? + let identity_impl_method = self.identity_substs(impl_method_def_id); + assert_eq!(identity_impl_method.len(), impl_method_substs.len()); + + // Given: + // ``` + // trait Trait { + // fn f(); + // } + // struct Struct { ... } + // impl Trait for Struct { + // fn f() { ... } + // } + // ``` + // + // The various substs look like this: + // - identity for Trait: `[Self, Tp]` + // - identity for Trait::f: `[Self, Tp, Tx, Ty, Tz]` + // - substs of the impl trait ref: `[Struct, A]` + // - identity for the impl: `[A, B, C]` + // - identity for Struct::f: `[A, B, C, X, Y, Z]` + // + // What we need is a substs suitable for a call to Trait::f, which is in + // this case `[Struct, A, X, Y, Z]`. More generally, it is the + // concatenation of the trait ref substs with the identity of the impl + // method after skipping the identity of the impl. + // + // We also need to subst the prefix (`[Struct, A]` in the example + // above) with call substs, so that we get the trait's type parameters + // more precisely. We can do this directly with `impl_method_substs` + // because they contain the substs for the `impl` block as a prefix. + let call_trait_substs = + ty::EarlyBinder(trait_ref.substs).subst(self.tcx, impl_method_substs); + let impl_substs = self.identity_substs(impl_def_id); + let trait_method_substs = self.tcx.mk_substs_from_iter( + call_trait_substs + .iter() + .chain(impl_method_substs.iter().skip(impl_substs.len())), + ); - // sanity check: do we now have the correct number of substs? - let identity_trait_method = self.identity_substs(trait_method_def_id); - assert_eq!(trait_method_substs.len(), identity_trait_method.len()); + // sanity check: do we now have the correct number of substs? + let identity_trait_method = self.identity_substs(trait_method_def_id); + assert_eq!(trait_method_substs.len(), identity_trait_method.len()); - Some((trait_method_def_id, trait_method_substs))*/ + Some((trait_method_def_id, trait_method_substs))*/ } /// Given some procedure `proc_def_id` which is called, this method returns the actual method which will be executed when `proc_def_id` is defined on a trait. diff --git a/prusti-interface/src/specs/encoder.rs b/prusti-interface/src/specs/encoder.rs index 21b7b6f4531..ae42d305b6c 100644 --- a/prusti-interface/src/specs/encoder.rs +++ b/prusti-interface/src/specs/encoder.rs @@ -18,10 +18,7 @@ pub struct DefSpecsEncoder<'tcx> { } impl<'tcx> DefSpecsEncoder<'tcx> { - pub fn new( - tcx: TyCtxt<'tcx>, - path: &std::path::PathBuf, - ) -> std::io::Result { + pub fn new(tcx: TyCtxt<'tcx>, path: &std::path::PathBuf) -> std::io::Result { Ok(DefSpecsEncoder { tcx, opaque: opaque::FileEncoder::new(path)?, diff --git a/prusti-interface/src/specs/external.rs b/prusti-interface/src/specs/external.rs index 608a5cd8d62..4cfc2f60c73 100644 --- a/prusti-interface/src/specs/external.rs +++ b/prusti-interface/src/specs/external.rs @@ -4,10 +4,7 @@ use prusti_rustc_interface::{ def_id::{DefId, LocalDefId}, intravisit::{self, Visitor}, }, - middle::{ - hir::map::Map, - ty::GenericArgsRef, - }, + middle::{hir::map::Map, ty::GenericArgsRef}, span::Span, }; diff --git a/prusti-interface/src/utils.rs b/prusti-interface/src/utils.rs index c9de49e76ed..c5e8d7f80db 100644 --- a/prusti-interface/src/utils.rs +++ b/prusti-interface/src/utils.rs @@ -10,7 +10,10 @@ use prusti_rustc_interface::{ abi::FieldIdx, ast::ast, data_structures::fx::FxHashSet, - middle::{mir, ty::{self, TyCtxt}}, + middle::{ + mir, + ty::{self, TyCtxt}, + }, }; use std::borrow::Borrow; @@ -95,8 +98,7 @@ pub fn expand_struct_place<'tcx>( for (index, field_def) in variant.fields.iter().enumerate() { if Some(index) != without_field { let field = FieldIdx::from_usize(index); - let field_place = - tcx.mk_place_field(place, field, field_def.ty(tcx, substs)); + let field_place = tcx.mk_place_field(place, field, field_def.ty(tcx, substs)); places.push(field_place); } } @@ -133,7 +135,6 @@ pub fn expand_struct_place<'tcx>( places } - /// Pop the last projection from the place and return the new place with the popped element. pub fn try_pop_one_level<'tcx>( tcx: TyCtxt<'tcx>, diff --git a/prusti/src/callbacks.rs b/prusti/src/callbacks.rs index beb48f290f8..102da95c4fb 100644 --- a/prusti/src/callbacks.rs +++ b/prusti/src/callbacks.rs @@ -9,16 +9,12 @@ use prusti_rustc_interface::{ borrowck::consumers, data_structures::steal::Steal, driver::Compilation, + hir::{def::DefKind, def_id::LocalDefId}, index::IndexVec, interface::{interface::Compiler, Config, Queries}, - hir::{def::DefKind, def_id::LocalDefId}, middle::{ mir, - query::{ - queries::mir_borrowck::ProvidedValue as MirBorrowck, - ExternProviders, - Providers - }, + query::{queries::mir_borrowck::ProvidedValue as MirBorrowck, ExternProviders, Providers}, ty::TyCtxt, }, session::{EarlyErrorHandler, Session}, @@ -174,7 +170,7 @@ impl prusti_rustc_interface::driver::Callbacks for PrustiCompilerCalls { test_free_pcs(&mir, tcx); } } else {*/ - verify(env, def_spec); + verify(env, def_spec); //} } }); diff --git a/prusti/src/driver.rs b/prusti/src/driver.rs index 358dad21e0c..7a58e953631 100644 --- a/prusti/src/driver.rs +++ b/prusti/src/driver.rs @@ -56,18 +56,20 @@ fn report_prusti_ice(info: &panic::PanicInfo<'_>, bug_report_url: &str) { prusti_rustc_interface::driver::DEFAULT_LOCALE_RESOURCES.to_vec(), false, ); - let emitter = Box::new(prusti_rustc_interface::errors::emitter::EmitterWriter::stderr( - prusti_rustc_interface::errors::ColorConfig::Auto, - None, - None, - fallback_bundle, - false, - false, - None, - false, - false, - prusti_rustc_interface::errors::TerminalUrl::Auto, - )); + let emitter = Box::new( + prusti_rustc_interface::errors::emitter::EmitterWriter::stderr( + prusti_rustc_interface::errors::ColorConfig::Auto, + None, + None, + fallback_bundle, + false, + false, + None, + false, + false, + prusti_rustc_interface::errors::TerminalUrl::Auto, + ), + ); let handler = prusti_rustc_interface::errors::Handler::with_emitter(true, None, emitter); // a .span_bug or .bug call has already printed what it wants to print. diff --git a/prusti/src/verifier.rs b/prusti/src/verifier.rs index 842265c59c7..a047c1fb038 100644 --- a/prusti/src/verifier.rs +++ b/prusti/src/verifier.rs @@ -14,40 +14,36 @@ pub fn verify(env: Environment<'_>, def_spec: typed::DefSpecificationMap) { if env.diagnostic.has_errors() { warn!("The compiler reported an error, so the program will not be verified."); } else { - debug!("Prepare verification task...");/* - // TODO: can we replace `get_annotated_procedures` with information - // that is already in `def_spec`? - let (annotated_procedures, types) = env.get_annotated_procedures_and_types(); - let verification_task = VerificationTask { - procedures: annotated_procedures, - types, - }; - debug!("Verification task: {:?}", &verification_task); + debug!("Prepare verification task..."); /* + // TODO: can we replace `get_annotated_procedures` with information + // that is already in `def_spec`? + let (annotated_procedures, types) = env.get_annotated_procedures_and_types(); + let verification_task = VerificationTask { + procedures: annotated_procedures, + types, + }; + debug!("Verification task: {:?}", &verification_task); - user::message(format!( - "Verification of {} items...", - verification_task.procedures.len() - )); + user::message(format!( + "Verification of {} items...", + verification_task.procedures.len() + )); - if config::print_collected_verification_items() { - println!( - "Collected verification items {}:", - verification_task.procedures.len() - ); - for procedure in &verification_task.procedures { - println!( - "procedure: {} at {:?}", - env.name.get_item_def_path(*procedure), - env.query.get_def_span(procedure) - ); - } - }*/ + if config::print_collected_verification_items() { + println!( + "Collected verification items {}:", + verification_task.procedures.len() + ); + for procedure in &verification_task.procedures { + println!( + "procedure: {} at {:?}", + env.name.get_item_def_path(*procedure), + env.query.get_def_span(procedure) + ); + } + }*/ - let program = prusti_encoder::test_entrypoint( - env.tcx(), - env.body, - def_spec, - ); + let program = prusti_encoder::test_entrypoint(env.tcx(), env.body, def_spec); //viper::verify(program); //let verification_result = diff --git a/task-encoder/src/lib.rs b/task-encoder/src/lib.rs index d92cea432de..fa36c724f9d 100644 --- a/task-encoder/src/lib.rs +++ b/task-encoder/src/lib.rs @@ -8,7 +8,6 @@ impl<'vir> OutputRefAny<'vir> for () {} pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { // None, // indicated by absence in the cache - /// Task was enqueued but not yet started. Enqueued, @@ -29,9 +28,7 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { }, /// An error occurred when enqueing the task. - ErrorEnqueue { - error: TaskEncoderError, - }, + ErrorEnqueue { error: TaskEncoderError }, /// An error occurred when encoding the task. The full "local" encoding is /// not available. However, tasks which depend on this task may still @@ -49,16 +46,12 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { /// Cache for a task encoder. See `TaskEncoderCacheState` for a description of /// the possible values in the encoding process. -pub type Cache<'vir, E> = LinkedHashMap< - ::TaskKey<'vir>, - TaskEncoderCacheState<'vir, E>, ->; +pub type Cache<'vir, E> = + LinkedHashMap<::TaskKey<'vir>, TaskEncoderCacheState<'vir, E>>; pub type CacheRef<'vir, E> = RefCell>; -pub type CacheStatic = LinkedHashMap< - ::TaskKey<'static>, - TaskEncoderCacheState<'static, E>, ->; +pub type CacheStatic = + LinkedHashMap<::TaskKey<'static>, TaskEncoderCacheState<'static, E>>; pub type CacheStaticRef = RefCell>; /* pub struct TaskEncoderOutput<'vir, E: TaskEncoder>( @@ -85,7 +78,8 @@ pub enum TaskEncoderError { } impl std::fmt::Debug for TaskEncoderError - where ::EncodingError: std::fmt::Debug +where + ::EncodingError: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut helper = f.debug_struct("TaskEncoderError"); @@ -118,30 +112,21 @@ impl<'vir> TaskEncoderDependencies<'vir> { pub fn require_ref( &mut self, task: ::TaskDescription<'vir>, - ) -> Result< - ::OutputRef<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputRef<'vir>, TaskEncoderError> { E::encode_ref(task) } pub fn require_local( &mut self, task: ::TaskDescription<'vir>, - ) -> Result< - ::OutputFullLocal<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputFullLocal<'vir>, TaskEncoderError> { E::encode(task).map(|(_output_ref, output_local, _output_dep)| output_local) } pub fn require_dep( &mut self, task: ::TaskDescription<'vir>, - ) -> Result< - ::OutputFullDependency<'vir>, - TaskEncoderError, - > { + ) -> Result<::OutputFullDependency<'vir>, TaskEncoderError> { E::encode(task).map(|(_output_ref, _output_local, output_dep)| output_dep) } @@ -150,10 +135,12 @@ impl<'vir> TaskEncoderDependencies<'vir> { task_key: E::TaskKey<'vir>, output_ref: E::OutputRef<'vir>, ) { - assert!(E::with_cache(move |cache| matches!(cache.borrow_mut().insert( - task_key, - TaskEncoderCacheState::Started { output_ref }, - ), Some(TaskEncoderCacheState::Enqueued)))); + assert!(E::with_cache(move |cache| matches!( + cache + .borrow_mut() + .insert(task_key, TaskEncoderCacheState::Started { output_ref },), + Some(TaskEncoderCacheState::Enqueued) + ))); } } @@ -166,7 +153,8 @@ pub trait TaskEncoder { /// for example if the description should be normalised or some non-trivial /// resolution needs to happen. In other words, multiple descriptions may /// lead to the same key and hence the same output. - type TaskKey<'vir>: std::hash::Hash + Eq + Clone + std::fmt::Debug = Self::TaskDescription<'vir>; + type TaskKey<'vir>: std::hash::Hash + Eq + Clone + std::fmt::Debug = + Self::TaskDescription<'vir>; /// A reference to an encoded item. Should be non-unit for tasks which can /// be "referred" to from other parts of a program, as opposed to tasks @@ -190,7 +178,9 @@ pub trait TaskEncoder { /// Enters the given function with a reference to the cache for this /// encoder. fn with_cache<'vir, F, R>(f: F) -> R - where Self: 'vir, F: FnOnce(&'vir CacheRef<'vir, Self>) -> R; + where + Self: 'vir, + F: FnOnce(&'vir CacheRef<'vir, Self>) -> R; //fn get_all_outputs() -> Self::CacheRef<'vir> { // todo!() @@ -199,7 +189,8 @@ pub trait TaskEncoder { //} fn enqueue<'vir>(task: Self::TaskDescription<'vir>) - where Self: 'vir + where + Self: 'vir, { let task_key = Self::task_to_key(&task); let task_key_clone = task_key.clone(); // TODO: remove? @@ -209,28 +200,32 @@ pub trait TaskEncoder { } // enqueue, expecting no entry (we just checked) - assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( - task_key, - TaskEncoderCacheState::Enqueued, - ).is_none())); + assert!(Self::with_cache(move |cache| cache + .borrow_mut() + .insert(task_key, TaskEncoderCacheState::Enqueued,) + .is_none())); } - fn encode_ref<'vir>(task: Self::TaskDescription<'vir>) -> Result< - Self::OutputRef<'vir>, - TaskEncoderError, - > - where Self: 'vir + fn encode_ref<'vir>( + task: Self::TaskDescription<'vir>, + ) -> Result, TaskEncoderError> + where + Self: 'vir, { let task_key = Self::task_to_key(&task); // is there an output ref available already? let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { + if let Some(output_ref) = + Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => { + Some(output_ref.clone()) + } + _ => None, + }) + { return Ok(output_ref); } @@ -252,24 +247,34 @@ pub trait TaskEncoder { Self::encode(task)?; let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { + if let Some(output_ref) = + Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => { + Some(output_ref.clone()) + } + _ => None, + }) + { return Ok(output_ref); } panic!("output ref not found after encoding") // TODO: error? } - fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Result<( - Self::OutputRef<'vir>, - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir + fn encode<'vir>( + task: Self::TaskDescription<'vir>, + ) -> Result< + ( + Self::OutputRef<'vir>, + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + TaskEncoderError, + > + where + Self: 'vir, { let task_key = Self::task_to_key(&task); @@ -290,8 +295,9 @@ pub trait TaskEncoder { output_local.clone(), output_dep.clone(), ))), - TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => - panic!("Encoding already started or enqueued"), + TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => { + panic!("Encoding already started or enqueued") + } }, None => { // enqueue @@ -314,160 +320,172 @@ pub trait TaskEncoder { match encode_result { Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( - output_ref, - output_local, - output_dep, - )) + Self::with_cache(|cache| { + cache.borrow_mut().insert( + task_key, + TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + }, + ) + }); + Ok((output_ref, output_local, output_dep)) } Err((err, maybe_output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { - output_ref: output_ref.clone(), - deps, - error: TaskEncoderError::EncodingError(err.clone()), - output_dep: maybe_output_dep, - })); + Self::with_cache(|cache| { + cache.borrow_mut().insert( + task_key, + TaskEncoderCacheState::ErrorEncode { + output_ref: output_ref.clone(), + deps, + error: TaskEncoderError::EncodingError(err.clone()), + output_dep: maybe_output_dep, + }, + ) + }); Err(TaskEncoderError::EncodingError(err)) } } } /* - /// Given a task description for this encoder, enqueue it and return the - /// reference to the output. If the task is already enqueued, the output - /// reference already exists. - fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Self::OutputRef<'vir> - where Self: 'vir - { - let task_key = Self::task_to_key(&task); - let task_key_clone = task_key.clone(); - if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { - Some(TaskEncoderCacheState::Enqueued { output_ref }) - | Some(TaskEncoderCacheState::Started { output_ref, .. }) - | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) - | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), - _ => None, - }) { - return output_ref; + /// Given a task description for this encoder, enqueue it and return the + /// reference to the output. If the task is already enqueued, the output + /// reference already exists. + fn encode<'vir>(task: Self::TaskDescription<'vir>) -> Self::OutputRef<'vir> + where Self: 'vir + { + let task_key = Self::task_to_key(&task); + let task_key_clone = task_key.clone(); + if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { + Some(TaskEncoderCacheState::Enqueued { output_ref }) + | Some(TaskEncoderCacheState::Started { output_ref, .. }) + | Some(TaskEncoderCacheState::Encoded { output_ref, .. }) + | Some(TaskEncoderCacheState::ErrorEncode { output_ref, .. }) => Some(output_ref.clone()), + _ => None, + }) { + return output_ref; + } + let task_ref = Self::task_to_output_ref(&task); + let task_key_clone = task_key.clone(); + let task_ref_clone = task_ref.clone(); + assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( + task_key_clone, + TaskEncoderCacheState::Enqueued { output_ref: task_ref_clone }, + ).is_none())); + task_ref } - let task_ref = Self::task_to_output_ref(&task); - let task_key_clone = task_key.clone(); - let task_ref_clone = task_ref.clone(); - assert!(Self::with_cache(move |cache| cache.borrow_mut().insert( - task_key_clone, - TaskEncoderCacheState::Enqueued { output_ref: task_ref_clone }, - ).is_none())); - task_ref - } - - // TODO: this function should not be needed - fn encode_eager<'vir>(task: Self::TaskDescription<'vir>) -> Result<( - Self::OutputRef<'vir>, - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir - { - let task_key = Self::task_to_key(&task); - // enqueue - let output_ref = Self::encode(task); - // process - Self::encode_full(task_key) - .map(|(output_full_local, output_full_dep)| (output_ref, output_full_local, output_full_dep)) - } - /// Given a task key, fully encode the given task. If this task was already - /// finished, the encoding is not repeated. If this task was enqueued, but - /// not finished, return a `CyclicError`. - fn encode_full<'vir>(task_key: Self::TaskKey<'vir>) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), TaskEncoderError> - where Self: 'vir - { - let mut output_ref_opt = None; - let ret = Self::with_cache(|cache| { - // should be queued by now - match cache.borrow().get(&task_key).unwrap() { - TaskEncoderCacheState::Enqueued { output_ref } => { - output_ref_opt = Some(output_ref.clone()); - None - } - TaskEncoderCacheState::Started { .. } => Some(Err(TaskEncoderError::CyclicError)), - TaskEncoderCacheState::Encoded { output_local, output_dep, .. } => - Some(Ok(( - output_local.clone(), - output_dep.clone(), - ))), - TaskEncoderCacheState::ErrorEncode { error, .. } => - Some(Err(error.clone())), - } - }); - if let Some(ret) = ret { - return ret; + // TODO: this function should not be needed + fn encode_eager<'vir>(task: Self::TaskDescription<'vir>) -> Result<( + Self::OutputRef<'vir>, + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), TaskEncoderError> + where Self: 'vir + { + let task_key = Self::task_to_key(&task); + // enqueue + let output_ref = Self::encode(task); + // process + Self::encode_full(task_key) + .map(|(output_full_local, output_full_dep)| (output_ref, output_full_local, output_full_dep)) } - let output_ref = output_ref_opt.unwrap(); - let mut deps: TaskEncoderDependencies<'vir> = Default::default(); - match Self::do_encode_full(&task_key, &mut deps) { - Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( - output_local, - output_dep, - )) + /// Given a task key, fully encode the given task. If this task was already + /// finished, the encoding is not repeated. If this task was enqueued, but + /// not finished, return a `CyclicError`. + fn encode_full<'vir>(task_key: Self::TaskKey<'vir>) -> Result<( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), TaskEncoderError> + where Self: 'vir + { + let mut output_ref_opt = None; + let ret = Self::with_cache(|cache| { + // should be queued by now + match cache.borrow().get(&task_key).unwrap() { + TaskEncoderCacheState::Enqueued { output_ref } => { + output_ref_opt = Some(output_ref.clone()); + None + } + TaskEncoderCacheState::Started { .. } => Some(Err(TaskEncoderError::CyclicError)), + TaskEncoderCacheState::Encoded { output_local, output_dep, .. } => + Some(Ok(( + output_local.clone(), + output_dep.clone(), + ))), + TaskEncoderCacheState::ErrorEncode { error, .. } => + Some(Err(error.clone())), + } + }); + if let Some(ret) = ret { + return ret; } - Err((err, maybe_output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { - output_ref: output_ref.clone(), - deps, - error: TaskEncoderError::EncodingError(err.clone()), - output_dep: maybe_output_dep, - })); - Err(TaskEncoderError::EncodingError(err)) + let output_ref = output_ref_opt.unwrap(); + + let mut deps: TaskEncoderDependencies<'vir> = Default::default(); + match Self::do_encode_full(&task_key, &mut deps) { + Ok((output_local, output_dep)) => { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + })); + Ok(( + output_local, + output_dep, + )) + } + Err((err, maybe_output_dep)) => { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { + output_ref: output_ref.clone(), + deps, + error: TaskEncoderError::EncodingError(err.clone()), + output_dep: maybe_output_dep, + })); + Err(TaskEncoderError::EncodingError(err)) + } } } - } -*/ + */ /// Given a task description, create a key for storing it in the cache. - fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> // Result< - Self::TaskKey<'vir>;//, - // Self::EnqueueingError, - //> -/* - /// Given a task description, create a reference to the output. - fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; -*/ + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir>; //, + // Self::EnqueueingError, + //> + /* + /// Given a task description, create a reference to the output. + fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; + */ fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )>; + ) -> Result< + ( + Self::OutputFullLocal<'vir>, + Self::OutputFullDependency<'vir>, + ), + ( + Self::EncodingError, + Option>, + ), + >; fn all_outputs<'vir>() -> Vec> - where Self: 'vir + where + Self: 'vir, { Self::with_cache(|cache| { let mut ret = vec![]; for (_task_key, cache_state) in cache.borrow().iter() { - match cache_state { // TODO: make this into an iterator chain - TaskEncoderCacheState::Encoded { output_local, .. } => ret.push(output_local.clone()), + match cache_state { + // TODO: make this into an iterator chain + TaskEncoderCacheState::Encoded { output_local, .. } => { + ret.push(output_local.clone()) + } _ => {} } } diff --git a/test-crates/src/main.rs b/test-crates/src/main.rs index ec61eddfb13..bb4ebdddc17 100644 --- a/test-crates/src/main.rs +++ b/test-crates/src/main.rs @@ -4,6 +4,10 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use clap::Parser; +use log::{error, info, warn, LevelFilter}; +use rustwide::{cmd, logging, logging::LogStorage, Crate, Toolchain, Workspace, WorkspaceBuilder}; +use serde::Deserialize; use std::{ env, error::Error, @@ -11,10 +15,6 @@ use std::{ path::{Path, PathBuf}, process::Command, }; -use log::{error, info, warn, LevelFilter}; -use rustwide::{cmd, logging, logging::LogStorage, Crate, Toolchain, Workspace, WorkspaceBuilder}; -use serde::Deserialize; -use clap::Parser; /// How a crate should be tested. All tests use `check_panics=false`, `check_overflows=false` and /// `skip_unsupported_features=true`. @@ -145,17 +145,21 @@ struct Args { shard_index: usize, } -fn attempt_fetch(krate: &Crate, workspace: &Workspace, num_retries: u8) -> Result<(), failure::Error> { +fn attempt_fetch( + krate: &Crate, + workspace: &Workspace, + num_retries: u8, +) -> Result<(), failure::Error> { let mut i = 0; while i < num_retries + 1 { if let Err(err) = krate.fetch(workspace) { warn!("Error fetching crate {}: {}", krate, err); if i == num_retries { // Last attempt failed, return the error - return Err(err) + return Err(err); } } else { - return Ok(()) + return Ok(()); } i += 1; } @@ -228,7 +232,12 @@ fn main() -> Result<(), Box> { .collect::, _>>()? .into_iter() .filter(|record| record.name.contains(&args.filter_crate_name)) - .map(|record| (Crate::crates_io(&record.name, &record.version), record.test_kind)) + .map(|record| { + ( + Crate::crates_io(&record.name, &record.version), + record.test_kind, + ) + }) .collect(); info!("There are {} crates in total.", crates_list.len()); @@ -239,8 +248,11 @@ fn main() -> Result<(), Box> { // List of crates on which Prusti succeed. let mut successful_crates = vec![]; - let shard_crates_list: Vec<&(Crate, TestKind)> = crates_list.iter().skip(args.shard_index) - .step_by(args.num_shards).collect(); + let shard_crates_list: Vec<&(Crate, TestKind)> = crates_list + .iter() + .skip(args.shard_index) + .step_by(args.num_shards) + .collect(); info!( "Iterate over the {} crates of the shard {}/{}...", shard_crates_list.len(), @@ -248,7 +260,13 @@ fn main() -> Result<(), Box> { args.num_shards, ); for (index, (krate, test_kind)) in shard_crates_list.iter().enumerate() { - info!("Crate {}/{}: {}, test kind: {:?}", index, shard_crates_list.len(), krate, test_kind); + info!( + "Crate {}/{}: {}, test kind: {:?}", + index, + shard_crates_list.len(), + krate, + test_kind + ); if let TestKind::Skip = test_kind { info!("Skip crate"); @@ -300,11 +318,7 @@ fn main() -> Result<(), Box> { guest_prusti_home, cmd::MountKind::ReadOnly, ) - .mount( - host_viper_home, - guest_viper_home, - cmd::MountKind::ReadOnly, - ) + .mount(host_viper_home, guest_viper_home, cmd::MountKind::ReadOnly) .mount(host_z3_home, guest_z3_home, cmd::MountKind::ReadOnly) .mount(&host_java_home, &guest_java_home, cmd::MountKind::ReadOnly); for java_policy_path in &host_java_policies { @@ -317,7 +331,8 @@ fn main() -> Result<(), Box> { let verification_status = build_dir.build(&toolchain, krate, sandbox).run(|build| { logging::capture(&storage, || { - let mut command = build.cmd(&cargo_prusti) + let mut command = build + .cmd(&cargo_prusti) .env("RUST_BACKTRACE", "1") .env("PRUSTI_ASSERT_TIMEOUT", "60000") .env("PRUSTI_LOG_DIR", "/tmp/prusti_log") @@ -328,13 +343,14 @@ fn main() -> Result<(), Box> { .env("PRUSTI_SKIP_UNSUPPORTED_FEATURES", "true"); match test_kind { TestKind::NoErrorsWithUnreachableUnsupportedCode => { - command = command.env("PRUSTI_ALLOW_UNREACHABLE_UNSUPPORTED_CODE", "true"); + command = + command.env("PRUSTI_ALLOW_UNREACHABLE_UNSUPPORTED_CODE", "true"); } - TestKind::NoErrors => {}, + TestKind::NoErrors => {} TestKind::NoCrash => { // Report internal errors as warnings command = command.env("PRUSTI_INTERNAL_ERRORS_AS_WARNINGS", "true"); - }, + } TestKind::Skip => { unreachable!(); } @@ -383,7 +399,11 @@ fn main() -> Result<(), Box> { } // Panic - assert!(failed_crates.is_empty(), "Failed to verify {} crates", failed_crates.len()); + assert!( + failed_crates.is_empty(), + "Failed to verify {} crates", + failed_crates.len() + ); Ok(()) } diff --git a/tracing/proc-macro-tracing/src/lib.rs b/tracing/proc-macro-tracing/src/lib.rs index 044b9e8b47c..a7687d87613 100644 --- a/tracing/proc-macro-tracing/src/lib.rs +++ b/tracing/proc-macro-tracing/src/lib.rs @@ -9,7 +9,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{ItemFn, ReturnType, token::Paren, Type, TypeParen}; +use syn::{token::Paren, ItemFn, ReturnType, Type, TypeParen}; // Using `tracing::instrument` without this crate (from the vanilla tracing crate) // causes RA to not pick up the return type properly atm - it colours wrong and @@ -21,17 +21,22 @@ pub fn instrument(attr: TokenStream, tokens: TokenStream) -> TokenStream { let (attr, tokens): (TokenStream2, TokenStream2) = (attr.into(), tokens.into()); if let Ok(mut item) = syn::parse2::(tokens.clone()) { if let ReturnType::Type(a, ty) = item.sig.output { - let new_ty = Type::Paren(TypeParen { paren_token: Paren::default(), elem: ty }); + let new_ty = Type::Paren(TypeParen { + paren_token: Paren::default(), + elem: ty, + }); item.sig.output = ReturnType::Type(a, Box::new(new_ty)); } quote! { #[tracing::tracing_instrument(#attr)] #item - }.into() + } + .into() } else { quote! { #[tracing::tracing_instrument(#attr)] #tokens - }.into() + } + .into() } } diff --git a/tracing/src/lib.rs b/tracing/src/lib.rs index e03deef40ed..e9705a3c56b 100644 --- a/tracing/src/lib.rs +++ b/tracing/src/lib.rs @@ -4,5 +4,5 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -pub use tracing::{*, instrument as tracing_instrument}; pub use proc_macro_tracing::instrument; +pub use tracing::{instrument as tracing_instrument, *}; diff --git a/vir-proc-macro/src/lib.rs b/vir-proc-macro/src/lib.rs index b99512fc16e..a072ea905be 100644 --- a/vir-proc-macro/src/lib.rs +++ b/vir-proc-macro/src/lib.rs @@ -3,7 +3,9 @@ use quote::quote; use syn::{parse_macro_input, DeriveInput}; fn is_reify_copy(field: &syn::Field) -> bool { - field.attrs.iter() + field + .attrs + .iter() .filter_map(|attr| match &attr.meta { syn::Meta::Path(p) => Some(&p.segments), _ => None, @@ -29,13 +31,11 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { }; TokenStream::from(match input.data { syn::Data::Struct(syn::DataStruct { - fields: syn::Fields::Named(syn::FieldsNamed { - named, - .. - }), + fields: syn::Fields::Named(syn::FieldsNamed { named, .. }), .. }) => { - let compute_fields = named.iter() + let compute_fields = named + .iter() .filter_map(|field| { let name = field.ident.as_ref().unwrap(); if is_reify_copy(field) { @@ -47,7 +47,8 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { } }) .collect::>(); - let fields = named.iter() + let fields = named + .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); if is_reify_copy(field) { @@ -70,25 +71,21 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { #slice_impl } } - syn::Data::Enum(syn::DataEnum { - variants, - .. - }) => { - let variants = variants.iter() + syn::Data::Enum(syn::DataEnum { variants, .. }) => { + let variants = variants + .iter() .map(|variant| { let variant_name = &variant.ident; match &variant.fields { - syn::Fields::Unnamed(syn::FieldsUnnamed { - unnamed, - .. - }) => { + syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => { let vbinds = (0..unnamed.len()) .map(|idx| quote::format_ident!("v{idx}")) .collect::>(); let obinds = (0..unnamed.len()) .map(|idx| quote::format_ident!("opt{idx}")) .collect::>(); - let compute_fields = unnamed.iter() + let compute_fields = unnamed + .iter() .enumerate() .filter_map(|(idx, field)| { if is_reify_copy(field) { @@ -102,7 +99,8 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { } }) .collect::>(); - let fields = unnamed.iter() + let fields = unnamed + .iter() .enumerate() .map(|(idx, field)| { let vbind = &vbinds[idx]; @@ -120,7 +118,7 @@ pub fn derive_reify(input: TokenStream) -> TokenStream { vcx.alloc(#name::#variant_name(#(#fields),*)) } } - }, + } syn::Fields::Unit => quote! { #name::#variant_name => &#name::#variant_name }, diff --git a/vir/src/context.rs b/vir/src/context.rs index be684a11d76..ef0778c30eb 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -1,11 +1,8 @@ -use std::cell::RefCell; use prusti_interface::environment::EnvBody; use prusti_rustc_interface::middle::ty; +use std::cell::RefCell; -use crate::data::*; -use crate::gendata::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{data::*, gendata::*, genrefs::*, refs::*}; /// The VIR context is a data structure used throughout the encoding process. pub struct VirCtxt<'tcx> { @@ -20,13 +17,11 @@ pub struct VirCtxt<'tcx> { pub span_stack: Vec, // TODO: span stack // TODO: error positions? - /// The compiler's typing context. This allows convenient access to most /// of the compiler's APIs. pub tcx: ty::TyCtxt<'tcx>, pub body: RefCell>, - } impl<'tcx> VirCtxt<'tcx> { @@ -48,25 +43,23 @@ impl<'tcx> VirCtxt<'tcx> { &*self.arena.alloc_str(val) } -/* pub fn alloc_slice<'a, T: Copy>(&'tcx self, val: &'a [T]) -> &'tcx [T] { - &*self.arena.alloc_slice_copy(val) - }*/ + /* pub fn alloc_slice<'a, T: Copy>(&'tcx self, val: &'a [T]) -> &'tcx [T] { + &*self.arena.alloc_slice_copy(val) + }*/ pub fn alloc_slice(&self, val: &[T]) -> &[T] { &*self.arena.alloc_slice_copy(val) } pub fn mk_local<'vir>(&'vir self, name: &'vir str) -> Local<'vir> { - self.arena.alloc(LocalData { - name, - }) + self.arena.alloc(LocalData { name }) } pub fn mk_local_decl(&'tcx self, name: &'tcx str, ty: Type<'tcx>) -> LocalDecl<'tcx> { - self.arena.alloc(LocalDeclData { - name, - ty, - }) + self.arena.alloc(LocalDeclData { name, ty }) } - pub fn mk_local_ex_local(&'tcx self, local: Local<'tcx>) -> ExprGen<'tcx, Curr, Next> { + pub fn mk_local_ex_local( + &'tcx self, + local: Local<'tcx>, + ) -> ExprGen<'tcx, Curr, Next> { self.arena.alloc(ExprGenData::Local(local)) } pub fn mk_local_ex(&'tcx self, name: &'tcx str) -> ExprGen<'tcx, Curr, Next> { @@ -77,16 +70,18 @@ impl<'tcx> VirCtxt<'tcx> { target: &'tcx str, src_args: &[ExprGen<'tcx, Curr, Next>], ) -> ExprGen<'tcx, Curr, Next> { - self.arena.alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { - target, - args: self.alloc_slice(src_args), - }))) + self.arena + .alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_pred_app(&'tcx self, target: &'tcx str, src_args: &[Expr<'tcx>]) -> Expr<'tcx> { - self.arena.alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { - target, - args: self.alloc_slice(src_args), - }))) + self.arena + .alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_true(&'tcx self) -> Expr<'tcx> { diff --git a/vir/src/data.rs b/vir/src/data.rs index d201bb83a40..d9b42b131ef 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; -use prusti_rustc_interface::middle::mir; use crate::refs::*; +use prusti_rustc_interface::middle::mir; pub struct LocalData<'vir> { pub name: &'vir str, // TODO: identifiers diff --git a/vir/src/debug.rs b/vir/src/debug.rs index e475019abf9..8475f38fce6 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -1,13 +1,14 @@ use std::fmt::{Debug, Display, Formatter, Result as FmtResult}; -use crate::data::*; -use crate::gendata::*; +use crate::{data::*, gendata::*}; fn fmt_comma_sep_display(f: &mut Formatter<'_>, els: &[T]) -> FmtResult { els.iter() .enumerate() .map(|(idx, el)| { - if idx > 0 { write!(f, ", ")? } + if idx > 0 { + write!(f, ", ")? + } el.fmt(f) }) .collect::() @@ -16,7 +17,9 @@ fn fmt_comma_sep(f: &mut Formatter<'_>, els: &[T]) -> FmtResult { els.iter() .enumerate() .map(|(idx, el)| { - if idx > 0 { write!(f, ", ")? } + if idx > 0 { + write!(f, ", ")? + } el.fmt(f) }) .collect::() @@ -32,10 +35,7 @@ fn fmt_comma_sep_lines(f: &mut Formatter<'_>, els: &[T]) -> FmtResult Ok(()) } fn indent(s: String) -> String { - s - .split("\n") - .intersperse("\n ") - .collect::() + s.split("\n").intersperse("\n ").collect::() } impl<'vir, Curr, Next> Debug for AccFieldGenData<'vir, Curr, Next> { @@ -48,17 +48,21 @@ impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "(")?; self.lhs.fmt(f)?; - write!(f, ") {} (", match self.kind { - BinOpKind::CmpEq => "==", - BinOpKind::CmpNe => "!=", - BinOpKind::CmpGt => ">", - BinOpKind::CmpGe => ">=", - BinOpKind::CmpLt => "<", - BinOpKind::CmpLe => "<=", - BinOpKind::And => "&&", - BinOpKind::Add => "+", - BinOpKind::Sub => "-", - })?; + write!( + f, + ") {} (", + match self.kind { + BinOpKind::CmpEq => "==", + BinOpKind::CmpNe => "!=", + BinOpKind::CmpGt => ">", + BinOpKind::CmpGe => ">=", + BinOpKind::CmpLt => "<", + BinOpKind::CmpLe => "<=", + BinOpKind::And => "&&", + BinOpKind::Add => "+", + BinOpKind::Sub => "-", + } + )?; self.rhs.fmt(f)?; write!(f, ")") } @@ -93,8 +97,14 @@ impl<'vir, Curr, Next> Debug for DomainGenData<'vir, Curr, Next> { write!(f, "]")?; } writeln!(f, " {{")?; - self.axioms.iter().map(|el| el.fmt(f)).collect::()?; - self.functions.iter().map(|el| el.fmt(f)).collect::()?; + self.axioms + .iter() + .map(|el| el.fmt(f)) + .collect::()?; + self.functions + .iter() + .map(|el| el.fmt(f)) + .collect::()?; writeln!(f, "}}") } } @@ -174,8 +184,14 @@ impl<'vir, Curr, Next> Debug for FunctionGenData<'vir, Curr, Next> { writeln!(f, "function {}(", self.name)?; fmt_comma_sep_lines(f, &self.args)?; writeln!(f, "): {:?}", self.ret)?; - self.pres.iter().map(|el| writeln!(f, " requires {:?}", el)).collect::()?; - self.posts.iter().map(|el| writeln!(f, " ensures {:?}", el)).collect::()?; + self.pres + .iter() + .map(|el| writeln!(f, " requires {:?}", el)) + .collect::()?; + self.posts + .iter() + .map(|el| writeln!(f, " ensures {:?}", el)) + .collect::()?; if let Some(expr) = self.expr { write!(f, "{{\n ")?; expr.fmt(f)?; @@ -222,8 +238,14 @@ impl<'vir, Curr, Next> Debug for MethodGenData<'vir, Curr, Next> { } else { writeln!(f, ")")?; } - self.pres.iter().map(|el| writeln!(f, " requires {:?}", el)).collect::()?; - self.posts.iter().map(|el| writeln!(f, " ensures {:?}", el)).collect::()?; + self.pres + .iter() + .map(|el| writeln!(f, " requires {:?}", el)) + .collect::()?; + self.posts + .iter() + .map(|el| writeln!(f, " ensures {:?}", el)) + .collect::()?; if let Some(blocks) = self.blocks.as_ref() { writeln!(f, "{{")?; for block in blocks.iter() { @@ -302,7 +324,11 @@ impl<'vir, Curr, Next> Debug for TerminatorStmtGenData<'vir, Curr, Next> { write!(f, "goto {:?}", data.otherwise) } else { for target in data.targets { - write!(f, "if ({:?} == {:?}) {{ goto {:?} }}\n else", data.value, target.0, target.1)?; + write!( + f, + "if ({:?} == {:?}) {{ goto {:?} }}\n else", + data.value, target.0, target.1 + )?; } write!(f, " {{ goto {:?} }}", data.otherwise) } @@ -343,10 +369,15 @@ impl<'vir> Debug for TypeData<'vir> { impl<'vir, Curr, Next> Debug for UnOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "{}({:?})", match self.kind { - UnOpKind::Neg => "-", - UnOpKind::Not => "!", - }, self.expr) + write!( + f, + "{}({:?})", + match self.kind { + UnOpKind::Neg => "-", + UnOpKind::Not => "!", + }, + self.expr + ) } } diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index 2893c45992f..e66fe87b499 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -1,20 +1,20 @@ use std::fmt::Debug; -use crate::data::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{data::*, genrefs::*, refs::*}; use vir_proc_macro::*; #[derive(Reify)] pub struct UnOpGenData<'vir, Curr, Next> { - #[reify_copy] pub kind: UnOpKind, + #[reify_copy] + pub kind: UnOpKind, pub expr: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct BinOpGenData<'vir, Curr, Next> { - #[reify_copy] pub kind: BinOpKind, + #[reify_copy] + pub kind: BinOpKind, pub lhs: ExprGen<'vir, Curr, Next>, pub rhs: ExprGen<'vir, Curr, Next>, } @@ -28,20 +28,23 @@ pub struct TernaryGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct ForallGenData<'vir, Curr, Next> { - #[reify_copy] pub qvars: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub qvars: &'vir [LocalDecl<'vir>], pub triggers: &'vir [&'vir [ExprGen<'vir, Curr, Next>]], pub body: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct FuncAppGenData<'vir, Curr, Next> { - #[reify_copy] pub target: &'vir str, // TODO: identifiers + #[reify_copy] + pub target: &'vir str, // TODO: identifiers pub args: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(Reify)] pub struct PredicateAppGenData<'vir, Curr, Next> { - #[reify_copy] pub target: &'vir str, // TODO: identifiers + #[reify_copy] + pub target: &'vir str, // TODO: identifiers pub args: &'vir [ExprGen<'vir, Curr, Next>], } @@ -54,13 +57,15 @@ pub struct UnfoldingGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct AccFieldGenData<'vir, Curr, Next> { pub recv: ExprGen<'vir, Curr, Next>, - #[reify_copy] pub field: &'vir str, // TODO: identifiers - // TODO: permission amount + #[reify_copy] + pub field: &'vir str, // TODO: identifiers + // TODO: permission amount } #[derive(Reify)] pub struct LetGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, + #[reify_copy] + pub name: &'vir str, pub val: ExprGen<'vir, Curr, Next>, pub expr: ExprGen<'vir, Curr, Next>, } @@ -103,8 +108,10 @@ pub enum ExprGenData<'vir, Curr: 'vir, Next: 'vir> { PredicateApp(PredicateAppGen<'vir, Curr, Next>), // TODO: this should not be used instead of acc? // domain func app // inhale/exhale - - Lazy(&'vir str, Box, Curr) -> Next + 'vir>), + Lazy( + &'vir str, + Box, Curr) -> Next + 'vir>, + ), Todo(&'vir str), } @@ -120,30 +127,39 @@ impl<'vir, Curr, Next> ExprGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct DomainAxiomGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // ? or comment, then auto-gen the names? + #[reify_copy] + pub name: &'vir str, // ? or comment, then auto-gen the names? pub expr: ExprGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct DomainGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub typarams: &'vir [&'vir str], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub typarams: &'vir [&'vir str], pub axioms: &'vir [DomainAxiomGen<'vir, Curr, Next>], - #[reify_copy] pub functions: &'vir [DomainFunction<'vir>], + #[reify_copy] + pub functions: &'vir [DomainFunction<'vir>], } #[derive(Reify)] pub struct PredicateGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], pub expr: Option>, } #[derive(Reify)] pub struct FunctionGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], - #[reify_copy] pub ret: Type<'vir>, + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub ret: Type<'vir>, pub pres: &'vir [ExprGen<'vir, Curr, Next>], pub posts: &'vir [ExprGen<'vir, Curr, Next>], pub expr: Option>, @@ -160,14 +176,19 @@ pub struct PureAssignGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct MethodCallGenData<'vir, Curr, Next> { - #[reify_copy] pub targets: &'vir [Local<'vir>], - #[reify_copy] pub method: &'vir str, + #[reify_copy] + pub targets: &'vir [Local<'vir>], + #[reify_copy] + pub method: &'vir str, pub args: &'vir [ExprGen<'vir, Curr, Next>], } #[derive(Reify)] pub enum StmtGenData<'vir, Curr, Next> { - LocalDecl(#[reify_copy] LocalDecl<'vir>, Option>), + LocalDecl( + #[reify_copy] LocalDecl<'vir>, + Option>, + ), PureAssign(PureAssignGen<'vir, Curr, Next>), Inhale(ExprGen<'vir, Curr, Next>), Exhale(ExprGen<'vir, Curr, Next>), @@ -182,7 +203,8 @@ pub enum StmtGenData<'vir, Curr, Next> { pub struct GotoIfGenData<'vir, Curr, Next> { pub value: ExprGen<'vir, Curr, Next>, pub targets: &'vir [(ExprGen<'vir, Curr, Next>, CfgBlockLabel<'vir>)], - #[reify_copy] pub otherwise: CfgBlockLabel<'vir>, + #[reify_copy] + pub otherwise: CfgBlockLabel<'vir>, } #[derive(Reify)] @@ -196,16 +218,20 @@ pub enum TerminatorStmtGenData<'vir, Curr, Next> { #[derive(Debug, Reify)] pub struct CfgBlockGenData<'vir, Curr, Next> { - #[reify_copy] pub label: CfgBlockLabel<'vir>, + #[reify_copy] + pub label: CfgBlockLabel<'vir>, pub stmts: &'vir [StmtGen<'vir, Curr, Next>], pub terminator: TerminatorStmtGen<'vir, Curr, Next>, } #[derive(Reify)] pub struct MethodGenData<'vir, Curr, Next> { - #[reify_copy] pub name: &'vir str, // TODO: identifiers - #[reify_copy] pub args: &'vir [LocalDecl<'vir>], - #[reify_copy] pub rets: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub name: &'vir str, // TODO: identifiers + #[reify_copy] + pub args: &'vir [LocalDecl<'vir>], + #[reify_copy] + pub rets: &'vir [LocalDecl<'vir>], // TODO: pre/post - add a comment variant pub pres: &'vir [ExprGen<'vir, Curr, Next>], pub posts: &'vir [ExprGen<'vir, Curr, Next>], @@ -214,7 +240,8 @@ pub struct MethodGenData<'vir, Curr, Next> { #[derive(Debug, Reify)] pub struct ProgramGenData<'vir, Curr, Next> { - #[reify_copy] pub fields: &'vir [Field<'vir>], + #[reify_copy] + pub fields: &'vir [Field<'vir>], pub domains: &'vir [DomainGen<'vir, Curr, Next>], pub predicates: &'vir [PredicateGen<'vir, Curr, Next>], pub functions: &'vir [FunctionGen<'vir, Curr, Next>], diff --git a/vir/src/genrefs.rs b/vir/src/genrefs.rs index f21ca3c8ef8..f98eb0bef1a 100644 --- a/vir/src/genrefs.rs +++ b/vir/src/genrefs.rs @@ -1,7 +1,8 @@ pub type AccFieldGen<'vir, Curr, Next> = &'vir crate::gendata::AccFieldGenData<'vir, Curr, Next>; pub type BinOpGen<'vir, Curr, Next> = &'vir crate::gendata::BinOpGenData<'vir, Curr, Next>; pub type CfgBlockGen<'vir, Curr, Next> = &'vir crate::gendata::CfgBlockGenData<'vir, Curr, Next>; -pub type DomainAxiomGen<'vir, Curr, Next> = &'vir crate::gendata::DomainAxiomGenData<'vir, Curr, Next>; +pub type DomainAxiomGen<'vir, Curr, Next> = + &'vir crate::gendata::DomainAxiomGenData<'vir, Curr, Next>; pub type DomainGen<'vir, Curr, Next> = &'vir crate::gendata::DomainGenData<'vir, Curr, Next>; pub type ExprGen<'vir, Curr, Next> = &'vir crate::gendata::ExprGenData<'vir, Curr, Next>; pub type ForallGen<'vir, Curr, Next> = &'vir crate::gendata::ForallGenData<'vir, Curr, Next>; @@ -10,13 +11,17 @@ pub type FunctionGen<'vir, Curr, Next> = &'vir crate::gendata::FunctionGenData<' pub type GotoIfGen<'vir, Curr, Next> = &'vir crate::gendata::GotoIfGenData<'vir, Curr, Next>; pub type LetGen<'vir, Curr, Next> = &'vir crate::gendata::LetGenData<'vir, Curr, Next>; pub type MethodGen<'vir, Curr, Next> = &'vir crate::gendata::MethodGenData<'vir, Curr, Next>; -pub type MethodCallGen<'vir, Curr, Next> = &'vir crate::gendata::MethodCallGenData<'vir, Curr, Next>; +pub type MethodCallGen<'vir, Curr, Next> = + &'vir crate::gendata::MethodCallGenData<'vir, Curr, Next>; pub type PredicateGen<'vir, Curr, Next> = &'vir crate::gendata::PredicateGenData<'vir, Curr, Next>; -pub type PredicateAppGen<'vir, Curr, Next> = &'vir crate::gendata::PredicateAppGenData<'vir, Curr, Next>; +pub type PredicateAppGen<'vir, Curr, Next> = + &'vir crate::gendata::PredicateAppGenData<'vir, Curr, Next>; pub type ProgramGen<'vir, Curr, Next> = &'vir crate::gendata::ProgramGenData<'vir, Curr, Next>; -pub type PureAssignGen<'vir, Curr, Next> = &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; +pub type PureAssignGen<'vir, Curr, Next> = + &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; pub type StmtGen<'vir, Curr, Next> = &'vir crate::gendata::StmtGenData<'vir, Curr, Next>; -pub type TerminatorStmtGen<'vir, Curr, Next> = &'vir crate::gendata::TerminatorStmtGenData<'vir, Curr, Next>; +pub type TerminatorStmtGen<'vir, Curr, Next> = + &'vir crate::gendata::TerminatorStmtGenData<'vir, Curr, Next>; pub type TernaryGen<'vir, Curr, Next> = &'vir crate::gendata::TernaryGenData<'vir, Curr, Next>; pub type UnOpGen<'vir, Curr, Next> = &'vir crate::gendata::UnOpGenData<'vir, Curr, Next>; pub type UnfoldingGen<'vir, Curr, Next> = &'vir crate::gendata::UnfoldingGenData<'vir, Curr, Next>; diff --git a/vir/src/macros.rs b/vir/src/macros.rs index d9210baf377..80476b45dbe 100644 --- a/vir/src/macros.rs +++ b/vir/src/macros.rs @@ -1,6 +1,6 @@ //#[macro_export] //macro_rules! vir_expr_nopos { -// +// //} //#[macro_export] @@ -112,8 +112,12 @@ macro_rules! vir_expr { #[macro_export] macro_rules! vir_ident { - ($vcx:expr; [ $name:expr ]) => { $name }; - ($vcx:expr; $name:ident ) => { $vcx.alloc_str(stringify!($name)) }; + ($vcx:expr; [ $name:expr ]) => { + $name + }; + ($vcx:expr; $name:ident ) => { + $vcx.alloc_str(stringify!($name)) + }; } #[macro_export] @@ -123,10 +127,18 @@ macro_rules! vir_format { #[macro_export] macro_rules! vir_type { - ($vcx:expr; Int) => { $vcx.alloc($crate::TypeData::Int) }; - ($vcx:expr; Bool) => { $vcx.alloc($crate::TypeData::Bool) }; - ($vcx:expr; Ref) => { $vcx.alloc($crate::TypeData::Ref) }; - ($vcx:expr; [ $ty:expr ]) => { $ty }; + ($vcx:expr; Int) => { + $vcx.alloc($crate::TypeData::Int) + }; + ($vcx:expr; Bool) => { + $vcx.alloc($crate::TypeData::Bool) + }; + ($vcx:expr; Ref) => { + $vcx.alloc($crate::TypeData::Ref) + }; + ($vcx:expr; [ $ty:expr ]) => { + $ty + }; ($vcx:expr; $name:ident) => { $vcx.alloc($crate::TypeData::Domain($vcx.alloc_str(stringify!($name)))) }; diff --git a/vir/src/reify.rs b/vir/src/reify.rs index 6c1106a2731..6e40b4d69ae 100644 --- a/vir/src/reify.rs +++ b/vir/src/reify.rs @@ -1,18 +1,11 @@ -use crate::VirCtxt; -use crate::gendata::*; -use crate::genrefs::*; -use crate::refs::*; +use crate::{gendata::*, genrefs::*, refs::*, VirCtxt}; pub use vir_proc_macro::*; pub trait Reify<'vir, Curr, NextA, NextB> { type Next: Sized; - fn reify( - &self, - vcx: &'vir VirCtxt<'vir>, - lctx: Curr, - ) -> Self::Next; + fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next; } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> @@ -31,7 +24,9 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> ExprGenData::Forall(v) => vcx.alloc(ExprGenData::Forall(v.reify(vcx, lctx))), ExprGenData::Let(v) => vcx.alloc(ExprGenData::Let(v.reify(vcx, lctx))), ExprGenData::FuncApp(v) => vcx.alloc(ExprGenData::FuncApp(v.reify(vcx, lctx))), - ExprGenData::PredicateApp(v) => vcx.alloc(ExprGenData::PredicateApp(v.reify(vcx, lctx))), + ExprGenData::PredicateApp(v) => { + vcx.alloc(ExprGenData::PredicateApp(v.reify(vcx, lctx))) + } ExprGenData::Local(v) => vcx.alloc(ExprGenData::Local(v)), ExprGenData::Const(v) => vcx.alloc(ExprGenData::Const(v)), @@ -51,9 +46,12 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> { type Next = &'vir [ExprGen<'vir, NextA, NextB>]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|elem| elem.reify(vcx, lctx)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|elem| elem.reify(vcx, lctx)) + .collect::>(), + ) } } @@ -62,20 +60,29 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> { type Next = &'vir [&'vir [ExprGen<'vir, NextA, NextB>]]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|elem| elem.reify(vcx, lctx)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|elem| elem.reify(vcx, lctx)) + .collect::>(), + ) } } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> - for [(ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, CfgBlockLabel<'vir>)] + for [( + ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, + CfgBlockLabel<'vir>, + )] { type Next = &'vir [(ExprGen<'vir, NextA, NextB>, CfgBlockLabel<'vir>)]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { - vcx.alloc_slice(&self.iter() - .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) - .collect::>()) + vcx.alloc_slice( + &self + .iter() + .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) + .collect::>(), + ) } } @@ -97,7 +104,6 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> } } - /* impl< 'vir, From cc6f75cb278ee228649b1a2afbb58ccdef661717 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 3 Nov 2023 13:23:02 +0100 Subject: [PATCH 06/18] WIP on enums --- prusti-encoder/src/encoders/typ.rs | 1087 ++++++++++++++++------------ prusti-encoder/src/lib.rs | 1 + vir/src/context.rs | 9 + 3 files changed, 630 insertions(+), 467 deletions(-) diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 99423544bf4..8f15933a771 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -1,6 +1,8 @@ use prusti_rustc_interface::middle::ty; +use rustc_middle::ty::VariantDef; use rustc_type_ir::sty::TyKind; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::{FunctionGenData, TypeData}; pub struct TypeEncoder; @@ -21,6 +23,7 @@ pub enum TypeEncoderOutputRefSub<'vir> { Primitive, // structs, tuples StructLike(TypeEncoderOutputRefSubStruct<'vir>), + Enum, } // TODO: should output refs actually be references to structs...? @@ -40,6 +43,7 @@ pub struct TypeEncoderOutputRef<'vir> { impl<'vir> task_encoder::OutputRefAny<'vir> for TypeEncoderOutputRef<'vir> {} impl<'vir> TypeEncoderOutputRef<'vir> { + #[track_caller] pub fn expect_structlike(&self) -> &TypeEncoderOutputRefSubStruct<'vir> { match self.specifics { TypeEncoderOutputRefSub::StructLike(ref data) => data, @@ -158,440 +162,8 @@ impl TaskEncoder for TypeEncoder { Option>, ), > { - fn mk_unreachable<'vir>( - vcx: &'vir vir::VirCtxt, - snapshot_name: &'vir str, - snapshot_ty: vir::Type<'vir>, - ) -> vir::Function<'vir> { - vcx.alloc(vir::FunctionData { - name: vir::vir_format!(vcx, "{snapshot_name}_unreachable"), // TODO: pass from outside? - args: &[], - ret: snapshot_ty, - pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::Todo("false"))]), - posts: &[], - expr: None, - }) - } - fn mk_simple_predicate<'vir>( - vcx: &'vir vir::VirCtxt<'vir>, - predicate_name: &'vir str, - field_name: &'vir str, - ) -> vir::Predicate<'vir> { - let predicate_body = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldData { - recv: vcx.mk_local_ex("self_p"), - field: field_name, - }))); - vir::vir_predicate! { vcx; predicate [predicate_name](self_p: Ref) { [predicate_body] } } - } - /* - fn mk_refold<'vir>( - vcx: &'vir vir::VirCtxt<'vir>, - predicate_name: &'vir str, - snapshot_ty: vir::Type<'vir>, - ) -> vir::Method<'vir> { - vcx.alloc(vir::MethodData { - name: vir::vir_format!(vcx, "refold_{predicate_name}"), - args: vcx.alloc_slice(&[ - vcx.alloc(vir::LocalDeclData { - name: "_p", - ty: vcx.alloc(vir::TypeData::Ref), - }), - vcx.alloc(vir::LocalDeclData { - name: "_s_old", - ty: snapshot_ty, - }), - vcx.alloc(vir::LocalDeclData { - name: "_s_new", - ty: snapshot_ty, - }), - ]), - rets: &[], - pres: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { - target: predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("_p"), - vcx.mk_local_ex("_s_old"), - ]), - }))), - ]), - posts: vcx.alloc_slice(&[ - vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { - target: predicate_name, - args: vcx.alloc_slice(&[ - vcx.mk_local_ex("_p"), - vcx.mk_local_ex("_s_new"), - ]), - }))), - ]), - blocks: None, - }) - } - */ - // TODO: there is a lot of duplication here, both in these assign/ - // reassign methods, and in the match cases below - // also: is mk_assign really worth it? (used in constant method - // arguments only) - fn mk_assign<'vir>( - vcx: &'vir vir::VirCtxt<'vir>, - predicate_name: &'vir str, - snapshot_ty: vir::Type<'vir>, - ) -> vir::Method<'vir> { - vcx.alloc(vir::MethodData { - name: vir::vir_format!(vcx, "assign_{predicate_name}"), - args: vcx.alloc_slice(&[ - vcx.alloc(vir::LocalDeclData { - name: "_p", - ty: vcx.alloc(vir::TypeData::Ref), - }), - vcx.alloc(vir::LocalDeclData { - name: "_s_new", - ty: snapshot_ty, - }), - ]), - rets: &[], - pres: &[], - posts: vcx.alloc_slice(&[ - vcx.mk_pred_app(predicate_name, &[vcx.mk_local_ex("_p")]), - vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{predicate_name}_snap"), - &[vcx.mk_local_ex("_p")], - ), - rhs: vcx.mk_local_ex("_s_new"), - }))), - ]), - blocks: None, - }) - } - fn mk_snap<'vir>( - vcx: &'vir vir::VirCtxt<'vir>, - predicate_name: &'vir str, - snapshot_name: &'vir str, - field_name: Option<&'vir str>, - snapshot_ty: vir::Type<'vir>, - ) -> vir::Function<'vir> { - let pred_app = vcx.alloc(vir::PredicateAppData { - target: predicate_name, - args: vcx.alloc_slice(&[vcx.mk_local_ex("self")]), - }); - vcx.alloc(vir::FunctionData { - name: vir::vir_format!(vcx, "{predicate_name}_snap"), - args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), - ret: snapshot_ty, - pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), - posts: &[], - expr: field_name.map(|field_name| { - vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.alloc(vir::ExprData::Field(vcx.mk_local_ex("self"), field_name)), - }))) - }), - }) - } - fn mk_structlike<'vir>( - vcx: &'vir vir::VirCtxt<'vir>, - deps: &mut TaskEncoderDependencies<'vir>, - task_key: &::TaskKey<'vir>, - name_s: &'vir str, - name_p: &'vir str, - field_ty_out: Vec>, - ) -> Result< - ::OutputFullLocal<'vir>, - ( - ::EncodingError, - Option<::OutputFullDependency<'vir>>, - ), - > { - let mut field_read_names = Vec::new(); - let mut field_write_names = Vec::new(); - let mut field_projection_p_names = Vec::new(); - for idx in 0..field_ty_out.len() { - field_read_names.push(vir::vir_format!(vcx, "{name_s}_read_{idx}")); - field_write_names.push(vir::vir_format!(vcx, "{name_s}_write_{idx}")); - field_projection_p_names.push(vir::vir_format!(vcx, "{name_p}_field_{idx}")); - } - let field_read_names = vcx.alloc_slice(&field_read_names); - let field_write_names = vcx.alloc_slice(&field_write_names); - let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); - - deps.emit_output_ref::( - *task_key, - TypeEncoderOutputRef { - snapshot_name: name_s, - predicate_name: name_p, - from_primitive: None, - to_primitive: None, - snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), - function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), - function_snap: vir::vir_format!(vcx, "{name_p}_snap"), - //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { - field_read: field_read_names, - field_write: field_write_names, - field_projection_p: field_projection_p_names, - }), - method_assign: vir::vir_format!(vcx, "assign_{name_p}"), - }, - ); - let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); - - let mut funcs: Vec> = vec![]; - let mut axioms: Vec> = vec![]; - - let cons_name = vir::vir_format!(vcx, "{name_s}_cons"); - funcs.push( - vcx.alloc(vir::DomainFunctionData { - unique: false, - name: cons_name, - args: vcx.alloc_slice( - &field_ty_out - .iter() - .map(|field_ty_out| field_ty_out.snapshot) - .collect::>(), - ), - ret: ty_s, - }), - ); - - let mut field_projection_p = Vec::new(); - for (idx, ty_out) in field_ty_out.iter().enumerate() { - let name_r = vir::vir_format!(vcx, "{name_s}_read_{idx}"); - funcs.push( - vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }, - ); - - let name_w = vir::vir_format!(vcx, "{name_s}_write_{idx}"); - funcs.push(vir::vir_domain_func! { vcx; function [name_w]([ty_s], [ty_out.snapshot]): [ty_s] }); - - field_projection_p.push(vcx.alloc(vir::FunctionData { - name: vir::vir_format!(vcx, "{name_p}_field_{idx}"), - args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), - ret: &vir::TypeData::Ref, - pres: &[], - posts: &[], - expr: None, - })); - } - let field_projection_p = vcx.alloc_slice(&field_projection_p); - - for (write_idx, write_ty_out) in field_ty_out.iter().enumerate() { - for (read_idx, _read_ty_out) in field_ty_out.iter().enumerate() { - axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!( - vcx, - "ax_{name_s}_write_{write_idx}_read_{read_idx}" - ), - expr: if read_idx == write_idx { - vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - vcx.mk_local_decl("val", write_ty_out.snapshot), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - )])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - ), - rhs: vcx.mk_local_ex("val"), - }))), - }))) - } else { - vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - vcx.mk_local_decl("val", write_ty_out.snapshot), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - )])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - ), - rhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_local_ex("self")], - ), - }))), - }))) - }, - })); - } - } - - // constructor - { - let cons_qvars = vcx.alloc_slice( - &field_ty_out - .iter() - .enumerate() - .map(|(idx, field_ty_out)| { - vcx.mk_local_decl( - vir::vir_format!(vcx, "f{idx}"), - field_ty_out.snapshot, - ) - }) - .collect::>(), - ); - let cons_args = field_ty_out - .iter() - .enumerate() - .map(|(idx, _field_ty_out)| vcx.mk_local_ex(vir::vir_format!(vcx, "f{idx}"))) - .collect::>(); - let cons_call = vcx.mk_func_app(cons_name, &cons_args); - - for (read_idx, _) in field_ty_out.iter().enumerate() { - axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: cons_qvars.clone(), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[cons_call], - ), - rhs: cons_args[read_idx], - }))), - }))), - })); - } - - let cons_call_with_reads = vcx.mk_func_app( - cons_name, - &field_ty_out - .iter() - .enumerate() - .map(|(idx, _field_ty_out)| { - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{idx}"), - &[vcx.mk_local_ex("self")], - ) - }) - .collect::>(), - ); - axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!(vcx, "ax_{name_s}_cons"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: cons_call_with_reads, - rhs: vcx.mk_local_ex("self"), - }))), - }))), - })); - } - - // predicate - let predicate = { - let expr = field_ty_out - .iter() - .enumerate() - .map(|(idx, field_ty_out)| { - vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc( - vir::PredicateAppData { - target: field_ty_out.predicate_name, - args: vcx.alloc_slice(&[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_p}_field_{idx}"), - &[vcx.mk_local_ex("self_p")], - )]), - }, - ))) - }) - .reduce(|base, field_expr| { - vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::And, - lhs: base, - rhs: field_expr, - }))) - }) - .unwrap_or_else(|| vcx.mk_true()); - vcx.alloc(vir::PredicateData { - name: name_p, - args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), - expr: Some(expr), - }) - }; - - Ok(TypeEncoderOutput { - fields: &[], - snapshot: vir::vir_domain! { vcx; domain [name_s] { - with_funcs [funcs]; - with_axioms [axioms]; - } }, - predicate, - function_unreachable: mk_unreachable(vcx, name_s, ty_s), - function_snap: { - let pred_app = vcx.alloc(vir::PredicateAppData { - target: name_p, - args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), - }); - vcx.alloc(vir::FunctionData { - name: vir::vir_format!(vcx, "{name_p}_snap"), - args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), - ret: ty_s, - pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), - posts: &[], - expr: Some( - vcx.alloc(vir::ExprData::Unfolding( - vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.mk_func_app( - cons_name, - vcx.alloc_slice( - &field_ty_out - .iter() - .enumerate() - .map(|(idx, field_ty_out)| { - vcx.mk_func_app( - field_ty_out.function_snap, - &[vcx.mk_func_app( - vir::vir_format!( - vcx, - "{name_p}_field_{idx}" - ), - &[vcx.mk_local_ex("self_p")], - )], - ) - }) - .collect::>(), - ), - ), - }), - )), - ), - }) - }, - //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p, - method_assign: mk_assign(vcx, name_p, ty_s), - }) - } + // TODO: remove. This is here to get rust-analyzer to know the type. + let task_key: &ty::Ty = task_key; vir::with_vcx(|vcx| match task_key.kind() { TyKind::Bool => { @@ -749,39 +321,6 @@ impl TaskEncoder for TypeEncoder { )?, (), )) - - /* - let ty_len = tys.len(); - let name_s = vir::vir_format!(vcx, "s_Tuple{ty_len}"); - let name_p = vir::vir_format!(vcx, "p_Tuple{ty_len}"); - let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); - - deps.emit_output_ref::(*task_key, TypeEncoderOutputRef { - snapshot_name: name_s, - predicate_name: name_p, - snapshot: ty_s, - function_unreachable: "s_Tuple0_unreachable", - function_snap: "p_Tuple0_snap", - //method_refold: "refold_p_Tuple0", - specifics: TypeEncoderOutputRefSub::Primitive, - method_assign: vir::vir_format!(vcx, "assign_p_Tuple0"), - }); - Ok((TypeEncoderOutput { - fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { - name: vir::vir_format!(vcx, "f_Tuple0"), - ty: ty_s, - })]), - snapshot: vir::vir_domain! { vcx; domain s_Tuple0 { - function s_Tuple0_cons(): [ty_s]; - } }, - predicate: vir::vir_predicate! { vcx; predicate p_Tuple0(self_p: Ref) }, - function_unreachable: mk_unreachable(vcx, "s_Tuple0", ty_s), - function_snap: mk_snap(vcx, "p_Tuple0", "s_Tuple0", None, ty_s), - //method_refold: mk_refold(vcx, "p_Tuple0", ty_s), - field_projection_p: &[], - method_assign: mk_assign(vcx, "p_Tuple0", ty_s), - }, ())) - */ } TyKind::Param(_param) => { @@ -843,6 +382,7 @@ impl TaskEncoder for TypeEncoder { (), )) } + TyKind::Never => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Never")); deps.emit_output_ref::( @@ -877,8 +417,621 @@ impl TaskEncoder for TypeEncoder { (), )) } + + TyKind::Adt(adt_def, substs) if adt_def.is_enum() => { + tracing::error!("encoding enum {adt_def:#?} with substs {substs:?}"); + tracing::warn!("{:?}", adt_def.all_fields().collect::>()); + tracing::warn!("{:#?}", adt_def.variants()); + + Ok((mk_enum(vcx, deps, adt_def, task_key)?, ())) + } + //_ => Err((TypeEncoderError::UnsupportedType, None)), unsupported_type => todo!("type not supported: {unsupported_type:?}"), }) } } + +fn mk_enum<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir>, + adt: &ty::AdtDef, + task_key: &::TaskKey<'vir>, +) -> Result< + ::OutputFullLocal<'vir>, + ( + ::EncodingError, + Option<::OutputFullDependency<'vir>>, + ), +> { + let did_name = vcx.tcx.item_name(adt.did()).to_ident_string(); + + let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}"); + let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}"); + + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + from_primitive: None, + to_primitive: None, + snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::Enum, + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }, + ); + let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); + + // TODO: discriminant function + + // TODO: discriminant bounds axioms + + let mut funcs: Vec> = vec![]; + let mut axioms: Vec> = vec![]; + let mut field_projection_p = Vec::new(); + + for variant in adt.variants() { + mk_emum_variant( + vcx, + deps, + variant, + ty_s, + name_s, + name_p, + &mut funcs, + &mut field_projection_p, + &mut axioms, + ) + } + + let field_projection_p = vcx.alloc_slice(&field_projection_p); + + + let predicate = vcx.alloc(vir::PredicateData { + name: name_p, + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), + expr: None, + }); //TODO + + + let pred_app = vcx.alloc(vir::PredicateAppData { + target: name_p, + args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), + }); + + let function_snap =vcx.alloc(vir::FunctionData { + name: vir::vir_format!(vcx, "{name_p}_snap"), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), + ret: ty_s, + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), + posts: &[], + expr: None, //TODO + }); + + Ok(TypeEncoderOutput { + fields: &[], + snapshot: vir::vir_domain! { vcx; domain [name_s] { + with_funcs [funcs]; + with_axioms [axioms]; + } }, + predicate, + function_unreachable: mk_unreachable(vcx, name_s, ty_s), + function_snap, + //method_refold: mk_refold(vcx, name_p, ty_s), + field_projection_p, + method_assign: mk_assign(vcx, name_p, ty_s), + }) +} + +fn mk_emum_variant<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir>, + variant: &VariantDef, + ty_s: &'vir TypeData<'vir>, + parent_name_s: &'vir str, + parent_name_p: &'vir str, + funcs: &mut Vec<&'vir vir::DomainFunctionData<'vir>>, + field_projection_p: &mut Vec<&'vir FunctionGenData<'vir, !, !>>, + axioms: &mut Vec>, +) { + let substs = ty::List::identity_for_item(vcx.tcx, variant.def_id); + let fields = variant + .fields + .iter() + .map(|field| { + deps.require_ref::(field.ty(vcx.tcx, substs)) + .unwrap() + }) + .collect::>(); + + let did_name = vcx.tcx.item_name(variant.def_id).to_ident_string(); + let name_s = vir::vir_format!(vcx, "{parent_name_s}_variant_{did_name}"); + let name_p = vir::vir_format!(vcx, "{parent_name_p}_variant_{did_name}"); + + + mk_field_projection_p( + &fields, + vcx, + ty_s, + name_s, + name_p, + funcs, + field_projection_p, + ); + + read_write_axioms(vcx, ty_s, name_s, &fields, axioms); +} + +fn mk_unreachable<'vir>( + vcx: &'vir vir::VirCtxt, + snapshot_name: &'vir str, + snapshot_ty: vir::Type<'vir>, +) -> vir::Function<'vir> { + vcx.alloc(vir::FunctionData { + name: vir::vir_format!(vcx, "{snapshot_name}_unreachable"), // TODO: pass from outside? + args: &[], + ret: snapshot_ty, + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::Todo("false"))]), + posts: &[], + expr: None, + }) +} +fn mk_simple_predicate<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + predicate_name: &'vir str, + field_name: &'vir str, +) -> vir::Predicate<'vir> { + let predicate_body = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldData { + recv: vcx.mk_local_ex("self_p"), + field: field_name, + }))); + vir::vir_predicate! { vcx; predicate [predicate_name](self_p: Ref) { [predicate_body] } } +} +/* +fn mk_refold<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + predicate_name: &'vir str, + snapshot_ty: vir::Type<'vir>, +) -> vir::Method<'vir> { + vcx.alloc(vir::MethodData { + name: vir::vir_format!(vcx, "refold_{predicate_name}"), + args: vcx.alloc_slice(&[ + vcx.alloc(vir::LocalDeclData { + name: "_p", + ty: vcx.alloc(vir::TypeData::Ref), + }), + vcx.alloc(vir::LocalDeclData { + name: "_s_old", + ty: snapshot_ty, + }), + vcx.alloc(vir::LocalDeclData { + name: "_s_new", + ty: snapshot_ty, + }), + ]), + rets: &[], + pres: vcx.alloc_slice(&[ + vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { + target: predicate_name, + args: vcx.alloc_slice(&[ + vcx.mk_local_ex("_p"), + vcx.mk_local_ex("_s_old"), + ]), + }))), + ]), + posts: vcx.alloc_slice(&[ + vcx.alloc(vir::ExprData::PredicateApp(vcx.alloc(vir::PredicateAppData { + target: predicate_name, + args: vcx.alloc_slice(&[ + vcx.mk_local_ex("_p"), + vcx.mk_local_ex("_s_new"), + ]), + }))), + ]), + blocks: None, + }) +} +*/ +// TODO: there is a lot of duplication here, both in these assign/ +// reassign methods, and in the match cases below +// also: is mk_assign really worth it? (used in constant method +// arguments only) +fn mk_assign<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + predicate_name: &'vir str, + snapshot_ty: vir::Type<'vir>, +) -> vir::Method<'vir> { + vcx.alloc(vir::MethodData { + name: vir::vir_format!(vcx, "assign_{predicate_name}"), + args: vcx.alloc_slice(&[ + vcx.alloc(vir::LocalDeclData { + name: "_p", + ty: vcx.alloc(vir::TypeData::Ref), + }), + vcx.alloc(vir::LocalDeclData { + name: "_s_new", + ty: snapshot_ty, + }), + ]), + rets: &[], + pres: &[], + posts: vcx.alloc_slice(&[ + vcx.mk_pred_app(predicate_name, &[vcx.mk_local_ex("_p")]), + vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: vcx.mk_func_app( + vir::vir_format!(vcx, "{predicate_name}_snap"), + &[vcx.mk_local_ex("_p")], + ), + rhs: vcx.mk_local_ex("_s_new"), + }))), + ]), + blocks: None, + }) +} + +fn mk_snap<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + predicate_name: &'vir str, + snapshot_name: &'vir str, + field_name: Option<&'vir str>, + snapshot_ty: vir::Type<'vir>, +) -> vir::Function<'vir> { + let pred_app = vcx.alloc(vir::PredicateAppData { + target: predicate_name, + args: vcx.alloc_slice(&[vcx.mk_local_ex("self")]), + }); + vcx.alloc(vir::FunctionData { + name: vir::vir_format!(vcx, "{predicate_name}_snap"), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), + ret: snapshot_ty, + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), + posts: &[], + expr: field_name.map(|field_name| { + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.alloc(vir::ExprData::Field(vcx.mk_local_ex("self"), field_name)), + }))) + }), + }) +} + +fn mk_struct_snap<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + name_p: &'vir str, + fields: &[TypeEncoderOutputRef<'vir>], + ty_s: &'vir TypeData<'vir>, + cons_name: &'vir str, +) -> vir::Function<'vir> { + let pred_app = vcx.alloc(vir::PredicateAppData { + target: name_p, + args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), + }); + + let cons_args = vcx.alloc_slice( + &fields + .iter() + .enumerate() + .map(|(idx, field_ty_out)| { + vcx.mk_func_app( + field_ty_out.function_snap, + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_p}_field_{idx}"), + &[vcx.mk_local_ex("self_p")], + )], + ) + }) + .collect::>(), + ); + + vcx.alloc(vir::FunctionData { + name: vir::vir_format!(vcx, "{name_p}_snap"), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), + ret: ty_s, + pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), + posts: &[], + expr: Some( + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.mk_func_app(cons_name, cons_args), + }))), + ), + }) +} + +/// for the given fields on the given type create the read_write axioms and push them into the `axioms` Vector +fn read_write_axioms<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + ty_s: &'vir TypeData<'vir>, + name_s: &'vir str, + fields: &[TypeEncoderOutputRef<'vir>], + axioms: &mut Vec>, +) { + for (write_idx, write_ty_out) in fields.iter().enumerate() { + for (read_idx, _read_ty_out) in fields.iter().enumerate() { + axioms.push(vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_write_{write_idx}_read_{read_idx}"), + expr: if read_idx == write_idx { + vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { + qvars: vcx.alloc_slice(&[ + vcx.mk_local_decl("self", ty_s), + vcx.mk_local_decl("val", write_ty_out.snapshot), + ]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]), + body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + ), + rhs: vcx.mk_local_ex("val"), + }))), + }))) + } else { + vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { + qvars: vcx.alloc_slice(&[ + vcx.mk_local_decl("self", ty_s), + vcx.mk_local_decl("val", write_ty_out.snapshot), + ]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]), + body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + ), + rhs: vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_local_ex("self")], + ), + }))), + }))) + }, + })); + } + } +} + +/// Create the `_cons_read_` axioms and push them into the `axioms` vector +fn constructor_axioms<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + ty_s: &'vir TypeData<'vir>, + name_s: &'vir str, + fields: &[TypeEncoderOutputRef<'vir>], + cons_name: &'vir str, + axioms: &mut Vec>, +) { + let cons_qvars = vcx.alloc_slice( + &fields + .iter() + .enumerate() + .map(|(idx, field_ty_out)| { + vcx.mk_local_decl(vir::vir_format!(vcx, "f{idx}"), field_ty_out.snapshot) + }) + .collect::>(), + ); + let cons_args = fields + .iter() + .enumerate() + .map(|(idx, _field_ty_out)| vcx.mk_local_ex(vir::vir_format!(vcx, "f{idx}"))) + .collect::>(); + let cons_call = vcx.mk_func_app(cons_name, &cons_args); + + for (read_idx, _) in fields.iter().enumerate() { + axioms.push(vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), + expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { + qvars: cons_qvars.clone(), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[cons_call], + ), + rhs: cons_args[read_idx], + }))), + }))), + })); + } + + let cons_call_with_reads = vcx.mk_func_app( + cons_name, + &fields + .iter() + .enumerate() + .map(|(idx, _field_ty_out)| { + vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{idx}"), + &[vcx.mk_local_ex("self")], + ) + }) + .collect::>(), + ); + axioms.push(vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_cons"), + expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { + qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]), + body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { + kind: vir::BinOpKind::CmpEq, + lhs: cons_call_with_reads, + rhs: vcx.mk_local_ex("self"), + }))), + }))), + })); +} + +fn mk_structlike<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir>, + task_key: &::TaskKey<'vir>, + name_s: &'vir str, + name_p: &'vir str, + field_ty_out: Vec>, +) -> Result< + ::OutputFullLocal<'vir>, + ( + ::EncodingError, + Option<::OutputFullDependency<'vir>>, + ), +> { + let mut field_read_names = Vec::new(); + let mut field_write_names = Vec::new(); + let mut field_projection_p_names = Vec::new(); + for idx in 0..field_ty_out.len() { + field_read_names.push(vir::vir_format!(vcx, "{name_s}_read_{idx}")); + field_write_names.push(vir::vir_format!(vcx, "{name_s}_write_{idx}")); + field_projection_p_names.push(vir::vir_format!(vcx, "{name_p}_field_{idx}")); + } + let field_read_names = vcx.alloc_slice(&field_read_names); + let field_write_names = vcx.alloc_slice(&field_write_names); + let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); + + deps.emit_output_ref::( + *task_key, + TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + from_primitive: None, + to_primitive: None, + snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { + field_read: field_read_names, + field_write: field_write_names, + field_projection_p: field_projection_p_names, + }), + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }, + ); + let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); + + let mut funcs: Vec> = vec![]; + let mut axioms: Vec> = vec![]; + + let cons_name = vir::vir_format!(vcx, "{name_s}_cons"); + funcs.push( + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: cons_name, + args: vcx.alloc_slice( + &field_ty_out + .iter() + .map(|field_ty_out| field_ty_out.snapshot) + .collect::>(), + ), + ret: ty_s, + }), + ); + + let mut field_projection_p = Vec::new(); + mk_field_projection_p( + &field_ty_out, + vcx, + ty_s, + name_s, + name_p, + &mut funcs, + &mut field_projection_p, + ); + + let field_projection_p = vcx.alloc_slice(&field_projection_p); + + read_write_axioms(vcx, ty_s, name_s, &field_ty_out, &mut axioms); + + constructor_axioms(vcx, ty_s, name_s, &field_ty_out, cons_name, &mut axioms); + + // predicate + let predicate = { + let expr = field_ty_out + .iter() + .enumerate() + .map(|(idx, field_ty_out)| { + vcx.mk_pred_app( + field_ty_out.predicate_name, + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_p}_field_{idx}"), + &[vcx.mk_local_ex("self_p")], + )], + ) + }) + .reduce(|base, field_expr| vcx.mk_and(base, field_expr)) + .unwrap_or_else(|| vcx.mk_true()); + + vcx.alloc(vir::PredicateData { + name: name_p, + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), + expr: Some(expr), + }) + }; + + Ok(TypeEncoderOutput { + fields: &[], + snapshot: vir::vir_domain! { vcx; domain [name_s] { + with_funcs [funcs]; + with_axioms [axioms]; + } }, + predicate, + function_unreachable: mk_unreachable(vcx, name_s, ty_s), + function_snap: mk_struct_snap(vcx, name_p, &field_ty_out, ty_s, cons_name), + //method_refold: mk_refold(vcx, name_p, ty_s), + field_projection_p, + method_assign: mk_assign(vcx, name_p, ty_s), + }) +} + +/// add the field projectsions and add the snapshot version to the funcs vector +fn mk_field_projection_p<'vir>( + fields: &[TypeEncoderOutputRef<'vir>], + vcx: &'vir vir::VirCtxt<'vir>, + ty_s: &'vir TypeData<'vir>, + name_s: &'vir str, + name_p: &'vir str, + funcs: &mut Vec<&vir::DomainFunctionData<'vir>>, + field_projection_p: &mut Vec<&FunctionGenData<'vir, !, !>>, +) { + let mut field_projection_p = Vec::new(); + for (idx, ty_out) in fields.iter().enumerate() { + let name_r = vir::vir_format!(vcx, "{name_s}_read_{idx}"); + funcs.push(vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }); + + let name_w = vir::vir_format!(vcx, "{name_s}_write_{idx}"); + funcs.push( + vir::vir_domain_func! { vcx; function [name_w]([ty_s], [ty_out.snapshot]): [ty_s] }, + ); + + field_projection_p.push(vcx.alloc(vir::FunctionData { + name: vir::vir_format!(vcx, "{name_p}_field_{idx}"), + args: vcx.alloc_slice(&[vcx.mk_local_decl("self", &vir::TypeData::Ref)]), + ret: &vir::TypeData::Ref, + pres: &[], + posts: &[], + expr: None, + })); + } +} diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index a91c3cf160c..f0d81cbd1a2 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -2,6 +2,7 @@ #![feature(associated_type_defaults)] #![feature(box_patterns)] #![feature(local_key_cell_methods)] +#![feature(never_type)] extern crate rustc_middle; extern crate rustc_serialize; diff --git a/vir/src/context.rs b/vir/src/context.rs index ef0778c30eb..5f524c885b7 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -88,6 +88,15 @@ impl<'tcx> VirCtxt<'tcx> { self.alloc(ExprData::Const(self.alloc(ConstData::Bool(true)))) } + + pub fn mk_and(&'tcx self, lhs: ExprGen<'tcx, Curr, Next>, rhs: ExprGen<'tcx, Curr, Next>) -> ExprGen<'tcx, Curr, Next> { + self.alloc(ExprGenData::BinOp(self.alloc(BinOpGenData { + kind: BinOpKind::And, + lhs, + rhs, + }))) + } + pub fn mk_conj(&'tcx self, elems: &[Expr<'tcx>]) -> Expr<'tcx> { if elems.len() == 0 { return self.alloc(ExprData::Const(self.alloc(ConstData::Bool(true)))); From 4d9d0cd7609b05d26237328569081bf039d918d8 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 3 Nov 2023 17:17:22 +0100 Subject: [PATCH 07/18] More WIP --- prusti-encoder/src/encoders/mir_impure.rs | 32 +++++++++++++++-- prusti-encoder/src/encoders/typ.rs | 42 ++++++++++++++++------- vir/src/context.rs | 7 ++-- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 91bffdf6765..2e778a9ba94 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -732,12 +732,18 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { } mir::Rvalue::Aggregate( - box mir::AggregateKind::Adt(..) | box mir::AggregateKind::Tuple, + kind @(box mir::AggregateKind::Adt(..) | box mir::AggregateKind::Tuple), fields, ) => { - let dest_ty_struct = dest_ty_out.expect_structlike(); + //let dest_ty_struct = dest_ty_out.expect_structlike(); + + let cons_name = match kind { + box mir::AggregateKind::Adt(_,vidx,_, _, _) if dest_ty_out.is_enum() => { + vir::vir_format!(self.vcx, "{}_{vidx:?}_cons", dest_ty_out.snapshot_name) + } + _ => vir::vir_format!(self.vcx, "{}_cons", dest_ty_out.snapshot_name) + }; - let cons_name = vir::vir_format!(self.vcx, "{}_cons", dest_ty_out.snapshot_name); let cons_args: Vec<_> = fields.iter().map(|field| self.encode_operand_snap(field)).collect(); let cons = self.vcx.mk_func_app(cons_name, self.vcx.alloc_slice(&cons_args)); @@ -755,6 +761,22 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { None } + mir::Rvalue::Discriminant(place) => { + tracing::warn!("Discrimiant of {dest_ty_out:?}"); + + let place_ty = self.local_defs.locals[place.local].ty.clone(); + let is_enum = matches!(place_ty.specifics, crate::encoders::typ::TypeEncoderOutputRefSub::Enum); + + Some(if is_enum { + self.vcx.mk_func_app( + "discriminant", // TODO: go through type encoder + &[self.encode_place(Place::from(*place))], + ) + } + else { + self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Int(0)))) + }) + } //mir::Rvalue::Discriminant(Place<'tcx>) => {} //mir::Rvalue::ShallowInitBox(Operand<'tcx>, Ty<'tcx>) => {} //mir::Rvalue::CopyForDeref(Place<'tcx>) => {} @@ -799,6 +821,10 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { | mir::TerminatorKind::FalseUnwind { real_target: target, .. + } + | mir::TerminatorKind::FalseEdge { + real_target: target, + .. } => self.vcx.alloc(vir::TerminatorStmtData::Goto( self.vcx .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 8f15933a771..7b2d90e170b 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -51,6 +51,10 @@ impl<'vir> TypeEncoderOutputRef<'vir> { } } + pub fn is_enum(&self) -> bool { + matches!(self.specifics, TypeEncoderOutputRefSub::Enum) + } + pub fn expr_from_u128(&self, val: u128) -> vir::Expr<'vir> { // TODO: not great: store the TyKind as well? // or should this be a different task for TypeEncoder? @@ -474,7 +478,10 @@ fn mk_enum<'vir>( let mut axioms: Vec> = vec![]; let mut field_projection_p = Vec::new(); - for variant in adt.variants() { + for (idx, variant) in adt.variants().iter().enumerate() { + let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); + let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); + mk_emum_variant( vcx, deps, @@ -490,20 +497,18 @@ fn mk_enum<'vir>( let field_projection_p = vcx.alloc_slice(&field_projection_p); - - let predicate = vcx.alloc(vir::PredicateData { + let predicate = vcx.alloc(vir::PredicateData { name: name_p, args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), expr: None, }); //TODO - let pred_app = vcx.alloc(vir::PredicateAppData { target: name_p, args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), }); - let function_snap =vcx.alloc(vir::FunctionData { + let function_snap = vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_snap"), args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), ret: ty_s, @@ -532,8 +537,8 @@ fn mk_emum_variant<'vir>( deps: &mut TaskEncoderDependencies<'vir>, variant: &VariantDef, ty_s: &'vir TypeData<'vir>, - parent_name_s: &'vir str, - parent_name_p: &'vir str, + name_s: &'vir str, + name_p: &'vir str, funcs: &mut Vec<&'vir vir::DomainFunctionData<'vir>>, field_projection_p: &mut Vec<&'vir FunctionGenData<'vir, !, !>>, axioms: &mut Vec>, @@ -548,10 +553,9 @@ fn mk_emum_variant<'vir>( }) .collect::>(); - let did_name = vcx.tcx.item_name(variant.def_id).to_ident_string(); - let name_s = vir::vir_format!(vcx, "{parent_name_s}_variant_{did_name}"); - let name_p = vir::vir_format!(vcx, "{parent_name_p}_variant_{did_name}"); - + // let did_name = vcx.tcx.item_name(variant.def_id).to_ident_string(); + // let name_s = vir::vir_format!(vcx, "{parent_name_s}_variant_{did_name}"); + // let name_p = vir::vir_format!(vcx, "{parent_name_p}_variant_{did_name}"); mk_field_projection_p( &fields, @@ -564,6 +568,21 @@ fn mk_emum_variant<'vir>( ); read_write_axioms(vcx, ty_s, name_s, &fields, axioms); + + let cons_name = vir::vir_format!(vcx, "{name_s}_cons"); + funcs.push( + vcx.alloc(vir::DomainFunctionData { + unique: false, + name: cons_name, + args: vcx.alloc_slice( + &fields + .iter() + .map(|field| field.snapshot) + .collect::>(), + ), + ret: ty_s, + }), + ); } fn mk_unreachable<'vir>( @@ -1015,7 +1034,6 @@ fn mk_field_projection_p<'vir>( funcs: &mut Vec<&vir::DomainFunctionData<'vir>>, field_projection_p: &mut Vec<&FunctionGenData<'vir, !, !>>, ) { - let mut field_projection_p = Vec::new(); for (idx, ty_out) in fields.iter().enumerate() { let name_r = vir::vir_format!(vcx, "{name_s}_read_{idx}"); funcs.push(vir::vir_domain_func! { vcx; function [name_r]([ty_s]): [ty_out.snapshot] }); diff --git a/vir/src/context.rs b/vir/src/context.rs index 5f524c885b7..192b2b9fd10 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -88,8 +88,11 @@ impl<'tcx> VirCtxt<'tcx> { self.alloc(ExprData::Const(self.alloc(ConstData::Bool(true)))) } - - pub fn mk_and(&'tcx self, lhs: ExprGen<'tcx, Curr, Next>, rhs: ExprGen<'tcx, Curr, Next>) -> ExprGen<'tcx, Curr, Next> { + pub fn mk_and( + &'tcx self, + lhs: ExprGen<'tcx, Curr, Next>, + rhs: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { self.alloc(ExprGenData::BinOp(self.alloc(BinOpGenData { kind: BinOpKind::And, lhs, From ef848b6ade1f8ea246606a8801cfa5abbe9ae9ec Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Sat, 4 Nov 2023 16:41:55 +0100 Subject: [PATCH 08/18] WIP --- prusti-encoder/src/encoders/mir_impure.rs | 39 +++++++++++++++++------ prusti-encoder/src/encoders/typ.rs | 20 +++++++++--- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 2e778a9ba94..fc89a78163a 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -10,6 +10,7 @@ use prusti_rustc_interface::{ // SsaAnalysis, //}; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::PredicateAppGenData; pub struct MirImpureEncoder; @@ -765,17 +766,35 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { tracing::warn!("Discrimiant of {dest_ty_out:?}"); let place_ty = self.local_defs.locals[place.local].ty.clone(); - let is_enum = matches!(place_ty.specifics, crate::encoders::typ::TypeEncoderOutputRefSub::Enum); - Some(if is_enum { - self.vcx.mk_func_app( - "discriminant", // TODO: go through type encoder - &[self.encode_place(Place::from(*place))], - ) - } - else { - self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Int(0)))) - }) + + + let p = self.encode_place(Place::from(*place)); + + let discr_as_int = match place_ty.specifics { + crate::encoders::typ::TypeEncoderOutputRefSub::Enum(x) => { + self.vcx.alloc(vir::ExprGenData::Field(p, x.field_discriminant)) + } + + _ => { + // mir::Rvalue::Discriminant documents "Returns zero for types without discriminant" + self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Int(0)))) + + } + }; + + let discr_as_int = self.vcx.alloc(vir::ExprGenData::Unfolding( + self.vcx.alloc(vir::UnfoldingGenData{ + target: self.vcx.alloc(PredicateAppGenData{ target: place_ty.predicate_name, args: self.vcx.alloc_slice(&[p]) }), + expr: discr_as_int}))); + + + let to_prim = dest_ty_out.from_primitive.unwrap(); + + + + + Some(self.vcx.mk_func_app(to_prim, &[discr_as_int])) } //mir::Rvalue::Discriminant(Place<'tcx>) => {} //mir::Rvalue::ShallowInitBox(Operand<'tcx>, Ty<'tcx>) => {} diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 7b2d90e170b..4da024a490b 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -18,12 +18,17 @@ pub struct TypeEncoderOutputRefSubStruct<'vir> { pub field_projection_p: &'vir [&'vir str], } +#[derive(Clone, Debug)] +pub struct TypeEncoderOutputRefSubEnum<'vir> { + pub field_discriminant: &'vir str, +} + #[derive(Clone, Debug)] pub enum TypeEncoderOutputRefSub<'vir> { Primitive, // structs, tuples StructLike(TypeEncoderOutputRefSubStruct<'vir>), - Enum, + Enum(TypeEncoderOutputRefSubEnum<'vir>), } // TODO: should output refs actually be references to structs...? @@ -52,7 +57,7 @@ impl<'vir> TypeEncoderOutputRef<'vir> { } pub fn is_enum(&self) -> bool { - matches!(self.specifics, TypeEncoderOutputRefSub::Enum) + matches!(self.specifics, TypeEncoderOutputRefSub::Enum { .. }) } pub fn expr_from_u128(&self, val: u128) -> vir::Expr<'vir> { @@ -453,6 +458,8 @@ fn mk_enum<'vir>( let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}"); let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}"); + let field_discriminant = vir::vir_format!(vcx, "p_Adt_{did_name}_discriminant"); + deps.emit_output_ref::( *task_key, TypeEncoderOutputRef { @@ -464,7 +471,9 @@ fn mk_enum<'vir>( function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), function_snap: vir::vir_format!(vcx, "{name_p}_snap"), //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::Enum, + specifics: TypeEncoderOutputRefSub::Enum(TypeEncoderOutputRefSubEnum { + field_discriminant, + }), method_assign: vir::vir_format!(vcx, "assign_{name_p}"), }, ); @@ -518,7 +527,10 @@ fn mk_enum<'vir>( }); Ok(TypeEncoderOutput { - fields: &[], + fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { + ty: &TypeData::Int, + name: field_discriminant, + })]), snapshot: vir::vir_domain! { vcx; domain [name_s] { with_funcs [funcs]; with_axioms [axioms]; From ff1c46ea93de103c3fe935c1945f69b970f34c49 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Sun, 5 Nov 2023 12:42:52 +0100 Subject: [PATCH 09/18] WIP --- prusti-encoder/src/encoders/mir_impure.rs | 9 +- prusti-encoder/src/encoders/typ.rs | 140 ++++++++++++++++++---- vir/src/context.rs | 96 +++++++++++---- vir/src/data.rs | 19 +++ vir/src/debug.rs | 1 + 5 files changed, 214 insertions(+), 51 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index fc89a78163a..88977d22cb4 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -777,9 +777,8 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { } _ => { - // mir::Rvalue::Discriminant documents "Returns zero for types without discriminant" - self.vcx.alloc(vir::ExprData::Const(self.vcx.alloc(vir::ConstData::Int(0)))) - + // mir::Rvalue::Discriminant documents "Returns zero for types without discriminant" + self.vcx.mk_const(0u128.into()) } }; @@ -1000,9 +999,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { self.vcx.alloc(vir::GotoIfData { value: enc, targets: self.vcx.alloc_slice(&[( - self.vcx.alloc(vir::ExprData::Const( - self.vcx.alloc(vir::ConstData::Bool(*expected)), - )), + self.vcx.mk_const(vir::ConstData::Bool(*expected)), &target_bb, )]), otherwise: self diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 4da024a490b..d063802124a 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -2,7 +2,7 @@ use prusti_rustc_interface::middle::ty; use rustc_middle::ty::VariantDef; use rustc_type_ir::sty::TyKind; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; -use vir::{FunctionGenData, TypeData}; +use vir::{ExprData, FunctionGenData, TypeData}; pub struct TypeEncoder; @@ -460,6 +460,8 @@ fn mk_enum<'vir>( let field_discriminant = vir::vir_format!(vcx, "p_Adt_{did_name}_discriminant"); + let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); + deps.emit_output_ref::( *task_key, TypeEncoderOutputRef { @@ -467,7 +469,7 @@ fn mk_enum<'vir>( predicate_name: name_p, from_primitive: None, to_primitive: None, - snapshot: vcx.alloc(vir::TypeData::Domain(name_s)), + snapshot: ty_s, function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), function_snap: vir::vir_format!(vcx, "{name_p}_snap"), //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), @@ -477,27 +479,34 @@ fn mk_enum<'vir>( method_assign: vir::vir_format!(vcx, "assign_{name_p}"), }, ); - let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); - // TODO: discriminant function - - // TODO: discriminant bounds axioms + // TODO: discriminant function bounds axioms let mut funcs: Vec> = vec![]; let mut axioms: Vec> = vec![]; let mut field_projection_p = Vec::new(); + let s_discr_func_name = vir::vir_format!(vcx, "{name_s}_discriminant"); + + funcs.push(vcx.alloc(vir::DomainFunctionData { + unique: false, + name: s_discr_func_name, + args: vcx.alloc_slice(&[ty_s]), + ret: &vir::TypeData::Int, + })); for (idx, variant) in adt.variants().iter().enumerate() { let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); - mk_emum_variant( + mk_enum_variant( vcx, deps, variant, + idx, ty_s, name_s, name_p, + s_discr_func_name, &mut funcs, &mut field_projection_p, &mut axioms, @@ -506,24 +515,63 @@ fn mk_enum<'vir>( let field_projection_p = vcx.alloc_slice(&field_projection_p); + let self_local = vcx.mk_local_ex("self_p"); + + let acc = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldGenData { + field: field_discriminant, + recv: self_local, + }))); + + let discr_field_access = vcx.alloc(vir::ExprGenData::Field(self_local, field_discriminant)); + + // TODO: handle the empty enum? i guess the lower and upper bound together for an empty enum are false which is correct? + let disc_lower_bound = vcx.mk_bin_op( + vir::BinOpKind::CmpGe, + discr_field_access, + vcx.mk_const(0usize.into()), + ); + + let disc_upper_bound = vcx.mk_bin_op( + vir::BinOpKind::CmpLt, + discr_field_access, + vcx.mk_const(adt.variants().len().into()), + ); + let predicate = vcx.alloc(vir::PredicateData { name: name_p, args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), - expr: None, - }); //TODO + expr: Some(vcx.mk_conj(&[acc, disc_lower_bound, disc_upper_bound])), + }); let pred_app = vcx.alloc(vir::PredicateAppData { target: name_p, - args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), + args: vcx.alloc_slice(&[self_local]), }); + let snap_body = { + let mut cur = vcx.mk_func_app(vir::vir_format!(vcx, "{name_s}_unreachable"), &[]); + + for (idx, variant) in adt.variants().iter().enumerate() { + let cond = vcx.mk_eq(discr_field_access, vcx.mk_const(idx.into())); + let cons_name = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}_cons"); //TODO get better + + let const_call = vcx.mk_func_app(cons_name, &[]); //TODO: args + cur = vcx.mk_tern(cond, const_call, cur) + } + + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: cur, + }))) + }; + let function_snap = vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_snap"), args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), ret: ty_s, pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: None, //TODO + expr: Some(snap_body), }); Ok(TypeEncoderOutput { @@ -544,13 +592,15 @@ fn mk_enum<'vir>( }) } -fn mk_emum_variant<'vir>( +fn mk_enum_variant<'vir>( vcx: &'vir vir::VirCtxt<'vir>, deps: &mut TaskEncoderDependencies<'vir>, variant: &VariantDef, + variant_idx: usize, ty_s: &'vir TypeData<'vir>, name_s: &'vir str, name_p: &'vir str, + s_discr_func_name: &'vir str, funcs: &mut Vec<&'vir vir::DomainFunctionData<'vir>>, field_projection_p: &mut Vec<&'vir FunctionGenData<'vir, !, !>>, axioms: &mut Vec>, @@ -595,6 +645,24 @@ fn mk_emum_variant<'vir>( ret: ty_s, }), ); + + let const_cond = vcx.mk_eq( + vcx.mk_func_app( + s_discr_func_name, + vcx.alloc_slice(&[vcx.mk_local_ex("self")]), + ), + vcx.mk_const(variant_idx.into()), + ); + constructor_axioms( + vcx, + ty_s, + name_s, + &fields, + cons_name, + axioms, + Some(const_cond), + None, + ); } fn mk_unreachable<'vir>( @@ -853,6 +921,8 @@ fn constructor_axioms<'vir>( fields: &[TypeEncoderOutputRef<'vir>], cons_name: &'vir str, axioms: &mut Vec>, + cond: Option<&'vir ExprData<'vir>>, + extra: Option<&'vir ExprData<'vir>>, ) { let cons_qvars = vcx.alloc_slice( &fields @@ -876,14 +946,13 @@ fn constructor_axioms<'vir>( expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: cons_qvars.clone(), triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( + body: vcx.mk_eq( + vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), &[cons_call], ), - rhs: cons_args[read_idx], - }))), + cons_args[read_idx], + ), }))), })); } @@ -901,16 +970,30 @@ fn constructor_axioms<'vir>( }) .collect::>(), ); + + let body = { + let body = vcx.mk_eq(cons_call_with_reads, vcx.mk_local_ex("self")); + if let Some(cond) = cond { + vcx.mk_impl(cond, body) + } else { + body + } + }; + + let triggers = { + if fields.is_empty() { + vcx.alloc_slice(&[]) + } else { + vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]) + } + }; + axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons"), expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: cons_call_with_reads, - rhs: vcx.mk_local_ex("self"), - }))), + triggers, + body, }))), })); } @@ -995,7 +1078,16 @@ fn mk_structlike<'vir>( read_write_axioms(vcx, ty_s, name_s, &field_ty_out, &mut axioms); - constructor_axioms(vcx, ty_s, name_s, &field_ty_out, cons_name, &mut axioms); + constructor_axioms( + vcx, + ty_s, + name_s, + &field_ty_out, + cons_name, + &mut axioms, + None, + None, + ); // predicate let predicate = { diff --git a/vir/src/context.rs b/vir/src/context.rs index 192b2b9fd10..8fa870cd507 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -51,10 +51,10 @@ impl<'tcx> VirCtxt<'tcx> { } pub fn mk_local<'vir>(&'vir self, name: &'vir str) -> Local<'vir> { - self.arena.alloc(LocalData { name }) + self.alloc(LocalData { name }) } pub fn mk_local_decl(&'tcx self, name: &'tcx str, ty: Type<'tcx>) -> LocalDecl<'tcx> { - self.arena.alloc(LocalDeclData { name, ty }) + self.alloc(LocalDeclData { name, ty }) } pub fn mk_local_ex_local( &'tcx self, @@ -70,47 +70,101 @@ impl<'tcx> VirCtxt<'tcx> { target: &'tcx str, src_args: &[ExprGen<'tcx, Curr, Next>], ) -> ExprGen<'tcx, Curr, Next> { - self.arena - .alloc(ExprGenData::FuncApp(self.arena.alloc(FuncAppGenData { - target, - args: self.alloc_slice(src_args), - }))) + self.alloc(ExprGenData::FuncApp(self.alloc(FuncAppGenData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_pred_app(&'tcx self, target: &'tcx str, src_args: &[Expr<'tcx>]) -> Expr<'tcx> { - self.arena - .alloc(ExprData::PredicateApp(self.arena.alloc(PredicateAppData { - target, - args: self.alloc_slice(src_args), - }))) + self.alloc(ExprData::PredicateApp(self.alloc(PredicateAppData { + target, + args: self.alloc_slice(src_args), + }))) } pub fn mk_true(&'tcx self) -> Expr<'tcx> { - self.alloc(ExprData::Const(self.alloc(ConstData::Bool(true)))) + self.mk_const(true.into()) } - pub fn mk_and( + pub fn mk_tern( + &'tcx self, + cond: ExprGen<'tcx, Curr, Next>, + then: ExprGen<'tcx, Curr, Next>, + else_: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.alloc(ExprGenData::Ternary(self.alloc(TernaryGenData { + cond, + then, + else_, + }))) + } + + pub fn mk_bin_op( &'tcx self, + kind: BinOpKind, lhs: ExprGen<'tcx, Curr, Next>, rhs: ExprGen<'tcx, Curr, Next>, ) -> ExprGen<'tcx, Curr, Next> { self.alloc(ExprGenData::BinOp(self.alloc(BinOpGenData { - kind: BinOpKind::And, + kind, lhs, rhs, }))) } + pub fn mk_not( + &'tcx self, + expr: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.alloc(ExprGenData::UnOp(self.alloc(UnOpGenData { + kind: UnOpKind::Not, + expr, + }))) + } + + pub fn mk_impl( + &'tcx self, + cond: ExprGen<'tcx, Curr, Next>, + rhs: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.mk_or(self.mk_not(cond), rhs) + } + + pub fn mk_and( + &'tcx self, + lhs: ExprGen<'tcx, Curr, Next>, + rhs: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.mk_bin_op(BinOpKind::And, lhs, rhs) + } + + pub fn mk_or( + &'tcx self, + lhs: ExprGen<'tcx, Curr, Next>, + rhs: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.mk_bin_op(BinOpKind::Or, lhs, rhs) + } + + pub fn mk_eq( + &'tcx self, + lhs: ExprGen<'tcx, Curr, Next>, + rhs: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.mk_bin_op(BinOpKind::CmpEq, lhs, rhs) + } + + pub fn mk_const(&'tcx self, cnst: ConstData) -> ExprGen<'tcx, Curr, Next> { + self.alloc(ExprGenData::Const(self.alloc(cnst))) + } + pub fn mk_conj(&'tcx self, elems: &[Expr<'tcx>]) -> Expr<'tcx> { if elems.len() == 0 { - return self.alloc(ExprData::Const(self.alloc(ConstData::Bool(true)))); + return self.mk_true(); } let mut e = elems[0]; for i in 1..elems.len() { - e = self.alloc(ExprData::BinOp(self.alloc(BinOpData { - kind: BinOpKind::And, - lhs: e, - rhs: elems[i], - }))); + e = self.mk_and(e, elems[i]) } e } diff --git a/vir/src/data.rs b/vir/src/data.rs index d9b42b131ef..27cd16aa707 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -40,6 +40,7 @@ pub enum BinOpKind { CmpGe, CmpLe, And, + Or, Add, Sub, // ... @@ -83,6 +84,24 @@ pub enum ConstData { Int(u128), // TODO: what about negative numbers? larger numbers? } +impl From for ConstData { + fn from(value: bool) -> Self { + ConstData::Bool(value) + } +} + +impl From for ConstData { + fn from(value: u128) -> Self { + ConstData::Int(value) + } +} + +impl From for ConstData { + fn from(value: usize) -> Self { + ConstData::Int(value.try_into().unwrap()) + } +} + pub enum TypeData<'vir> { Int, Bool, diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 8475f38fce6..20b94a59a19 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -59,6 +59,7 @@ impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { BinOpKind::CmpLt => "<", BinOpKind::CmpLe => "<=", BinOpKind::And => "&&", + BinOpKind::Or => "||", BinOpKind::Add => "+", BinOpKind::Sub => "-", } From bfd38b6035a13aac24c7b5b7c938c5633a44dcfe Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Mon, 6 Nov 2023 14:00:43 +0100 Subject: [PATCH 10/18] Enum WIP --- prusti-encoder/src/encoders/mir_impure.rs | 10 +- prusti-encoder/src/encoders/typ.rs | 192 +++++++++++++++------- prusti-encoder/src/lib.rs | 6 + vir/src/context.rs | 4 + 4 files changed, 154 insertions(+), 58 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 88977d22cb4..50a0d878bc3 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -272,7 +272,13 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { field_ty_out, ) } - _ => panic!("unsupported projection"), + mir::ProjectionElem::Downcast(name, idx) => { + let enu = ty_out.expect_enum(); + let ty_out_struct = &enu.variants[idx.as_usize()]; + + (base, ty_out_struct.clone()) + } + other => panic!("unsupported projection {other:?}"), } } @@ -346,7 +352,7 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { return; } let place_ty = place.ty(self.local_decls, self.vcx.tcx); - assert!(place_ty.variant_index.is_none()); + //assert!(place_ty.variant_index.is_none()); let place_ty_out = self .deps diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index d063802124a..0cf2d4bf1eb 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -21,6 +21,7 @@ pub struct TypeEncoderOutputRefSubStruct<'vir> { #[derive(Clone, Debug)] pub struct TypeEncoderOutputRefSubEnum<'vir> { pub field_discriminant: &'vir str, + pub variants: &'vir [TypeEncoderOutputRef<'vir>], } #[derive(Clone, Debug)] @@ -56,6 +57,13 @@ impl<'vir> TypeEncoderOutputRef<'vir> { } } + pub fn expect_enum(&self) -> &TypeEncoderOutputRefSubEnum<'vir> { + match self.specifics { + TypeEncoderOutputRefSub::Enum(ref data) => data, + _ => panic!("expected enum type"), + } + } + pub fn is_enum(&self) -> bool { matches!(self.specifics, TypeEncoderOutputRefSub::Enum { .. }) } @@ -90,6 +98,7 @@ pub struct TypeEncoderOutput<'vir> { pub fields: &'vir [vir::Field<'vir>], pub snapshot: vir::Domain<'vir>, pub predicate: vir::Predicate<'vir>, + pub other_predicates: &'vir [vir::Predicate<'vir>], // TODO: these should be generated on demand, put into tiny encoders ? pub function_unreachable: vir::Function<'vir>, pub function_snap: vir::Function<'vir>, @@ -205,6 +214,7 @@ impl TaskEncoder for TypeEncoder { axiom_inverse(s_Bool_cons, s_Bool_val, s_Bool); } }, predicate: mk_simple_predicate(vcx, "p_Bool", "f_Bool"), + other_predicates: &[], function_unreachable: mk_unreachable(vcx, "s_Bool", ty_s), function_snap: mk_snap(vcx, "p_Bool", "s_Bool", Some("f_Bool"), ty_s), //method_refold: mk_refold(vcx, "p_Bool", ty_s), @@ -254,6 +264,7 @@ impl TaskEncoder for TypeEncoder { axiom_inverse([name_cons], [name_val], [ty_s]); } }, predicate: mk_simple_predicate(vcx, name_p, name_field), + other_predicates: &[], function_unreachable: mk_unreachable(vcx, name_s, ty_s), function_snap: mk_snap(vcx, name_p, name_s, Some(name_field), ty_s), //method_refold: mk_refold(vcx, name_p, ty_s), @@ -291,6 +302,7 @@ impl TaskEncoder for TypeEncoder { function s_Tuple0_cons(): [ty_s]; } }, predicate: vir::vir_predicate! { vcx; predicate p_Tuple0(self_p: Ref) }, + other_predicates: &[], function_unreachable: mk_unreachable(vcx, "s_Tuple0", ty_s), function_snap: mk_snap(vcx, "p_Tuple0", "s_Tuple0", None, ty_s), //method_refold: mk_refold(vcx, "p_Tuple0", ty_s), @@ -358,6 +370,7 @@ impl TaskEncoder for TypeEncoder { snapshot: vir::vir_domain! { vcx; domain s_ParamTodo { // TODO: should not be emitted -- make outputs vectors } }, predicate: vir::vir_predicate! { vcx; predicate p_ParamTodo(self_p: Ref) }, + other_predicates: &[], function_unreachable: mk_unreachable(vcx, "p_Param", ty_s), function_snap: mk_snap(vcx, "p_Param", "s_Param", None, ty_s), //method_refold: mk_refold(vcx, "p_Param", ty_s), @@ -417,6 +430,7 @@ impl TaskEncoder for TypeEncoder { })]), snapshot: vir::vir_domain! { vcx; domain s_Never {} }, predicate: vir::vir_predicate! { vcx; predicate p_Never(self_p: Ref) }, + other_predicates: &[], function_unreachable: mk_unreachable(vcx, "s_Never", ty_s), function_snap: mk_snap(vcx, "p_Never", "s_Never", None, ty_s), //method_refold: mk_refold(vcx, "p_Never", ty_s), @@ -462,6 +476,44 @@ fn mk_enum<'vir>( let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); + let mut variants: Vec> = Vec::new(); + + for (idx, variant) in adt.variants().iter().enumerate() { + let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); + let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); + + let mut field_read_names = Vec::new(); + let mut field_write_names = Vec::new(); + let mut field_projection_p_names = Vec::new(); + for idx in 0..variant.fields.len() { + field_read_names.push(vir::vir_format!(vcx, "{name_s}_read_{idx}")); + field_write_names.push(vir::vir_format!(vcx, "{name_s}_write_{idx}")); + field_projection_p_names.push(vir::vir_format!(vcx, "{name_p}_field_{idx}")); + } + let field_read_names = vcx.alloc_slice(&field_read_names); + let field_write_names = vcx.alloc_slice(&field_write_names); + let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); + + let x = TypeEncoderOutputRef { + snapshot_name: name_s, + predicate_name: name_p, + from_primitive: None, + to_primitive: None, + snapshot: ty_s, + function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), + function_snap: vir::vir_format!(vcx, "{name_p}_snap"), + //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), + specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { + field_read: field_read_names, + field_write: field_write_names, + field_projection_p: field_projection_p_names, + }), + method_assign: vir::vir_format!(vcx, "assign_{name_p}"), + }; + + variants.push(x); + } + deps.emit_output_ref::( *task_key, TypeEncoderOutputRef { @@ -475,6 +527,7 @@ fn mk_enum<'vir>( //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), specifics: TypeEncoderOutputRefSub::Enum(TypeEncoderOutputRefSubEnum { field_discriminant, + variants: vcx.alloc(variants), }), method_assign: vir::vir_format!(vcx, "assign_{name_p}"), }, @@ -485,6 +538,7 @@ fn mk_enum<'vir>( let mut funcs: Vec> = vec![]; let mut axioms: Vec> = vec![]; let mut field_projection_p = Vec::new(); + let mut other_predicates = Vec::new(); let s_discr_func_name = vir::vir_format!(vcx, "{name_s}_discriminant"); @@ -494,11 +548,24 @@ fn mk_enum<'vir>( args: vcx.alloc_slice(&[ty_s]), ret: &vir::TypeData::Int, })); + + let mut field_t = vcx.mk_false(); + + let self_local = vcx.mk_local_ex("self_p"); + + let discr_field_access = vcx.alloc(vir::ExprGenData::Field(self_local, field_discriminant)); + let mut snap_cur = vcx.mk_func_app(vir::vir_format!(vcx, "{name_s}_unreachable"), &[]); + + let pred_app = vcx.alloc(vir::PredicateAppData { + target: name_p, + args: vcx.alloc_slice(&[self_local]), + }); + for (idx, variant) in adt.variants().iter().enumerate() { let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); - mk_enum_variant( + let (a, cons_call) = mk_enum_variant( vcx, deps, variant, @@ -510,20 +577,28 @@ fn mk_enum<'vir>( &mut funcs, &mut field_projection_p, &mut axioms, - ) - } + &mut other_predicates, + ); - let field_projection_p = vcx.alloc_slice(&field_projection_p); + let cond = vcx.mk_eq(discr_field_access, vcx.mk_const(idx.into())); + let pred_call = vcx.mk_pred_app(name_p, vcx.alloc_slice(&[self_local])); + field_t = vcx.mk_tern(cond, pred_call, field_t); - let self_local = vcx.mk_local_ex("self_p"); + let snap_cond = vcx.mk_eq(discr_field_access, vcx.mk_const(idx.into())); + + snap_cur = vcx.mk_tern(snap_cond, cons_call, snap_cur) + } + + let snap_body = vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: snap_cur, + }))); let acc = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldGenData { field: field_discriminant, recv: self_local, }))); - let discr_field_access = vcx.alloc(vir::ExprGenData::Field(self_local, field_discriminant)); - // TODO: handle the empty enum? i guess the lower and upper bound together for an empty enum are false which is correct? let disc_lower_bound = vcx.mk_bin_op( vir::BinOpKind::CmpGe, @@ -540,31 +615,9 @@ fn mk_enum<'vir>( let predicate = vcx.alloc(vir::PredicateData { name: name_p, args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), - expr: Some(vcx.mk_conj(&[acc, disc_lower_bound, disc_upper_bound])), - }); - - let pred_app = vcx.alloc(vir::PredicateAppData { - target: name_p, - args: vcx.alloc_slice(&[self_local]), + expr: Some(vcx.mk_conj(&[acc, disc_lower_bound, disc_upper_bound, field_t])), }); - let snap_body = { - let mut cur = vcx.mk_func_app(vir::vir_format!(vcx, "{name_s}_unreachable"), &[]); - - for (idx, variant) in adt.variants().iter().enumerate() { - let cond = vcx.mk_eq(discr_field_access, vcx.mk_const(idx.into())); - let cons_name = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}_cons"); //TODO get better - - let const_call = vcx.mk_func_app(cons_name, &[]); //TODO: args - cur = vcx.mk_tern(cond, const_call, cur) - } - - vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: cur, - }))) - }; - let function_snap = vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_snap"), args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), @@ -584,10 +637,11 @@ fn mk_enum<'vir>( with_axioms [axioms]; } }, predicate, + other_predicates: vcx.alloc_slice(&other_predicates), function_unreachable: mk_unreachable(vcx, name_s, ty_s), function_snap, //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p, + field_projection_p: vcx.alloc_slice(&field_projection_p), method_assign: mk_assign(vcx, name_p, ty_s), }) } @@ -604,7 +658,8 @@ fn mk_enum_variant<'vir>( funcs: &mut Vec<&'vir vir::DomainFunctionData<'vir>>, field_projection_p: &mut Vec<&'vir FunctionGenData<'vir, !, !>>, axioms: &mut Vec>, -) { + predicates: &mut Vec<&'vir vir::PredicateData<'vir>>, +) -> (&'vir vir::PredicateAppData<'vir>, &'vir ExprData<'vir>) { let substs = ty::List::identity_for_item(vcx.tcx, variant.def_id); let fields = variant .fields @@ -663,6 +718,10 @@ fn mk_enum_variant<'vir>( Some(const_cond), None, ); + + predicates.push(mk_struct_predicate(&fields, vcx, name_p)); + + mk_struct_snap_parts(vcx, name_p, &fields, cons_name) } fn mk_unreachable<'vir>( @@ -799,13 +858,12 @@ fn mk_snap<'vir>( }) } -fn mk_struct_snap<'vir>( +fn mk_struct_snap_parts<'vir>( vcx: &'vir vir::VirCtxt<'vir>, name_p: &'vir str, fields: &[TypeEncoderOutputRef<'vir>], - ty_s: &'vir TypeData<'vir>, cons_name: &'vir str, -) -> vir::Function<'vir> { +) -> (&'vir vir::PredicateAppData<'vir>, &'vir ExprData<'vir>) { let pred_app = vcx.alloc(vir::PredicateAppData { target: name_p, args: vcx.alloc_slice(&[vcx.mk_local_ex("self_p")]), @@ -827,18 +885,30 @@ fn mk_struct_snap<'vir>( .collect::>(), ); + let expr = vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: vcx.mk_func_app(cons_name, cons_args), + }))); + + (pred_app, expr) +} + +fn mk_struct_snap<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + name_p: &'vir str, + fields: &[TypeEncoderOutputRef<'vir>], + ty_s: &'vir TypeData<'vir>, + cons_name: &'vir str, +) -> vir::Function<'vir> { + let (pred_app, expr) = mk_struct_snap_parts(vcx, name_p, fields, cons_name); + vcx.alloc(vir::FunctionData { name: vir::vir_format!(vcx, "{name_p}_snap"), args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), ret: ty_s, pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: Some( - vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: vcx.mk_func_app(cons_name, cons_args), - }))), - ), + expr: Some(expr), }) } @@ -1090,8 +1160,31 @@ fn mk_structlike<'vir>( ); // predicate + let predicate = mk_struct_predicate(&field_ty_out, vcx, name_p); + + Ok(TypeEncoderOutput { + fields: &[], + snapshot: vir::vir_domain! { vcx; domain [name_s] { + with_funcs [funcs]; + with_axioms [axioms]; + } }, + predicate, + function_unreachable: mk_unreachable(vcx, name_s, ty_s), + function_snap: mk_struct_snap(vcx, name_p, &field_ty_out, ty_s, cons_name), + //method_refold: mk_refold(vcx, name_p, ty_s), + field_projection_p, + method_assign: mk_assign(vcx, name_p, ty_s), + other_predicates: &[], + }) +} + +fn mk_struct_predicate<'vir>( + fields: &Vec>, + vcx: &'vir vir::VirCtxt<'vir>, + name_p: &'vir str, +) -> &'vir vir::PredicateGenData<'vir, !, !> { let predicate = { - let expr = field_ty_out + let expr = fields .iter() .enumerate() .map(|(idx, field_ty_out)| { @@ -1112,20 +1205,7 @@ fn mk_structlike<'vir>( expr: Some(expr), }) }; - - Ok(TypeEncoderOutput { - fields: &[], - snapshot: vir::vir_domain! { vcx; domain [name_s] { - with_funcs [funcs]; - with_axioms [axioms]; - } }, - predicate, - function_unreachable: mk_unreachable(vcx, name_s, ty_s), - function_snap: mk_struct_snap(vcx, name_p, &field_ty_out, ty_s, cons_name), - //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p, - method_assign: mk_assign(vcx, name_p, ty_s), - }) + predicate } /// add the field projectsions and add the snapshot version to the funcs vector diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index f0d81cbd1a2..e1b875f0da4 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -221,7 +221,13 @@ pub fn test_entrypoint<'tcx>( } viper_code.push_str(&format!("{:?}\n", output.function_unreachable)); viper_code.push_str(&format!("{:?}\n", output.function_snap)); + viper_code.push_str(&format!("{:?}\n", output.predicate)); + + for pred in output.other_predicates { + viper_code.push_str(&format!("{:?}\n", pred)); + } + //viper_code.push_str(&format!("{:?}\n", output.method_refold)); viper_code.push_str(&format!("{:?}\n", output.method_assign)); } diff --git a/vir/src/context.rs b/vir/src/context.rs index 8fa870cd507..fc3f9a4395d 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -86,6 +86,10 @@ impl<'tcx> VirCtxt<'tcx> { self.mk_const(true.into()) } + pub fn mk_false(&'tcx self) -> Expr<'tcx> { + self.mk_const(false.into()) + } + pub fn mk_tern( &'tcx self, cond: ExprGen<'tcx, Curr, Next>, From bf8ec5f969b6f0bbe6177074b1bed411e05df921 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Mon, 6 Nov 2023 17:43:41 +0100 Subject: [PATCH 11/18] WIP --- prusti-encoder/src/encoders/typ.rs | 172 ++++++++++++++++------------- 1 file changed, 98 insertions(+), 74 deletions(-) diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 0cf2d4bf1eb..a83dad70537 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -2,7 +2,7 @@ use prusti_rustc_interface::middle::ty; use rustc_middle::ty::VariantDef; use rustc_type_ir::sty::TyKind; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; -use vir::{ExprData, FunctionGenData, TypeData}; +use vir::{ExprData, ExprGenData, FunctionGenData, TypeData}; pub struct TypeEncoder; @@ -701,24 +701,47 @@ fn mk_enum_variant<'vir>( }), ); - let const_cond = vcx.mk_eq( - vcx.mk_func_app( - s_discr_func_name, - vcx.alloc_slice(&[vcx.mk_local_ex("self")]), - ), - vcx.mk_const(variant_idx.into()), - ); - constructor_axioms( - vcx, - ty_s, - name_s, - &fields, - cons_name, - axioms, - Some(const_cond), - None, - ); + // let const_cond = vcx.mk_eq( + // vcx.mk_func_app( + // s_discr_func_name, + // vcx.alloc_slice(&[vcx.mk_local_ex("self")]), + // ), + // vcx.mk_const(variant_idx.into()), + // ); + + cons_read_axioms(name_s, vcx, &fields, cons_name, axioms); + + if !fields.is_empty() { + axioms.push(cons_axiom(name_s, vcx, cons_name, &fields, ty_s)); + } + + // discriminant of constructor + { + let (cons_qvars, cons_args, cons_call) = cons_read_parts(vcx, &fields, cons_name); + + let body = vcx.mk_eq( + vcx.mk_func_app(s_discr_func_name, &[cons_call]), + vcx.mk_const(variant_idx.into()), + ); + + let ax = if fields.is_empty() { + body + } else { + // only apply the forall if there are fields + vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { + qvars: cons_qvars.clone(), + triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + body, + }))) + }; + + axioms.push(vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_cons_{variant_idx}_discr"), + expr: ax, + })); + } + //TODO: discriminant for write axioms predicates.push(mk_struct_predicate(&fields, vcx, name_p)); mk_struct_snap_parts(vcx, name_p, &fields, cons_name) @@ -983,16 +1006,49 @@ fn read_write_axioms<'vir>( } } -/// Create the `_cons_read_` axioms and push them into the `axioms` vector -fn constructor_axioms<'vir>( +fn cons_axiom<'vir>( + name_s: &'vir str, vcx: &'vir vir::VirCtxt<'vir>, + cons_name: &'vir str, + fields: &[TypeEncoderOutputRef<'vir>], ty_s: &'vir TypeData<'vir>, - name_s: &'vir str, +) -> &'vir vir::DomainAxiomData<'vir> { + let cons_call_with_reads = vcx.mk_func_app( + cons_name, + &fields + .iter() + .enumerate() + .map(|(idx, _field_ty_out)| { + vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{idx}"), + &[vcx.mk_local_ex("self")], + ) + }) + .collect::>(), + ); + + let body = vcx.mk_eq(cons_call_with_reads, vcx.mk_local_ex("self")); + + let triggers = vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]); + + vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_cons"), + expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { + qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + triggers, + body, + }))), + }) +} + +fn cons_read_parts<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, fields: &[TypeEncoderOutputRef<'vir>], cons_name: &'vir str, - axioms: &mut Vec>, - cond: Option<&'vir ExprData<'vir>>, - extra: Option<&'vir ExprData<'vir>>, +) -> ( + &'vir [&'vir vir::LocalDeclData<'vir>], + Vec<&'vir vir::ExprData<'vir>>, + &'vir vir::ExprData<'vir>, ) { let cons_qvars = vcx.alloc_slice( &fields @@ -1008,8 +1064,21 @@ fn constructor_axioms<'vir>( .enumerate() .map(|(idx, _field_ty_out)| vcx.mk_local_ex(vir::vir_format!(vcx, "f{idx}"))) .collect::>(); + let cons_call = vcx.mk_func_app(cons_name, &cons_args); + (cons_qvars, cons_args, cons_call) +} + +fn cons_read_axioms<'vir>( + name_s: &'vir str, + vcx: &'vir vir::VirCtxt<'vir>, + fields: &[TypeEncoderOutputRef<'vir>], + cons_name: &'vir str, + axioms: &mut Vec<&vir::DomainAxiomGenData<'vir, !, !>>, +) { + let (cons_qvars, cons_args, cons_call) = cons_read_parts(vcx, fields, cons_name); + for (read_idx, _) in fields.iter().enumerate() { axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), @@ -1026,46 +1095,6 @@ fn constructor_axioms<'vir>( }))), })); } - - let cons_call_with_reads = vcx.mk_func_app( - cons_name, - &fields - .iter() - .enumerate() - .map(|(idx, _field_ty_out)| { - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{idx}"), - &[vcx.mk_local_ex("self")], - ) - }) - .collect::>(), - ); - - let body = { - let body = vcx.mk_eq(cons_call_with_reads, vcx.mk_local_ex("self")); - if let Some(cond) = cond { - vcx.mk_impl(cond, body) - } else { - body - } - }; - - let triggers = { - if fields.is_empty() { - vcx.alloc_slice(&[]) - } else { - vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call_with_reads])]) - } - }; - - axioms.push(vcx.alloc(vir::DomainAxiomData { - name: vir::vir_format!(vcx, "ax_{name_s}_cons"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), - triggers, - body, - }))), - })); } fn mk_structlike<'vir>( @@ -1148,16 +1177,11 @@ fn mk_structlike<'vir>( read_write_axioms(vcx, ty_s, name_s, &field_ty_out, &mut axioms); - constructor_axioms( - vcx, - ty_s, - name_s, - &field_ty_out, - cons_name, - &mut axioms, - None, - None, - ); + cons_read_axioms(name_s, vcx, &field_ty_out, cons_name, &mut axioms); + + if !field_ty_out.is_empty() { + axioms.push(cons_axiom(name_s, vcx, cons_name, &field_ty_out, ty_s)); + } // predicate let predicate = mk_struct_predicate(&field_ty_out, vcx, name_p); From 85890874ca8d4b561ae8a241914d6749a47e3b8f Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Mon, 6 Nov 2023 18:53:30 +0100 Subject: [PATCH 12/18] Fix unfolds for enums --- prusti-encoder/src/encoders/mir_impure.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 50a0d878bc3..4df23ce1857 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -352,13 +352,19 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { return; } let place_ty = place.ty(self.local_decls, self.vcx.tcx); - //assert!(place_ty.variant_index.is_none()); - let place_ty_out = self + let place_ty_encoded = self .deps .require_ref::(place_ty.ty) .unwrap(); + let pred_name = match place_ty.variant_index { + None => place_ty_encoded.predicate_name, + Some(idx) => { + place_ty_encoded.expect_enum().variants[idx.as_usize()].predicate_name + } + }; + let ref_p = self.encode_place(place); if matches!( repack_op, @@ -366,13 +372,13 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { ) { self.stmt(vir::StmtData::Unfold(self.vcx.alloc( vir::PredicateAppData { - target: place_ty_out.predicate_name, + target: pred_name, args: self.vcx.alloc_slice(&[ref_p]), }, ))); } else { self.stmt(vir::StmtData::Fold(self.vcx.alloc(vir::PredicateAppData { - target: place_ty_out.predicate_name, + target: pred_name, args: self.vcx.alloc_slice(&[ref_p]), }))); } From 4133463a4409e3f4723edcd8aa12d4c3786aa06a Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Mon, 6 Nov 2023 19:23:10 +0100 Subject: [PATCH 13/18] WIP --- prusti-encoder/src/encoders/typ.rs | 253 ++++++++++++++--------------- vir/src/context.rs | 13 ++ 2 files changed, 139 insertions(+), 127 deletions(-) diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index a83dad70537..1273ac9e190 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -533,8 +533,6 @@ fn mk_enum<'vir>( }, ); - // TODO: discriminant function bounds axioms - let mut funcs: Vec> = vec![]; let mut axioms: Vec> = vec![]; let mut field_projection_p = Vec::new(); @@ -549,23 +547,18 @@ fn mk_enum<'vir>( ret: &vir::TypeData::Int, })); - let mut field_t = vcx.mk_false(); + let mut predicate_per_variant_predicates = vcx.mk_false(); let self_local = vcx.mk_local_ex("self_p"); let discr_field_access = vcx.alloc(vir::ExprGenData::Field(self_local, field_discriminant)); let mut snap_cur = vcx.mk_func_app(vir::vir_format!(vcx, "{name_s}_unreachable"), &[]); - let pred_app = vcx.alloc(vir::PredicateAppData { - target: name_p, - args: vcx.alloc_slice(&[self_local]), - }); - for (idx, variant) in adt.variants().iter().enumerate() { let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); - let (a, cons_call) = mk_enum_variant( + let (_, cons_call) = mk_enum_variant( vcx, deps, variant, @@ -581,41 +574,48 @@ fn mk_enum<'vir>( ); let cond = vcx.mk_eq(discr_field_access, vcx.mk_const(idx.into())); - let pred_call = vcx.mk_pred_app(name_p, vcx.alloc_slice(&[self_local])); - field_t = vcx.mk_tern(cond, pred_call, field_t); - let snap_cond = vcx.mk_eq(discr_field_access, vcx.mk_const(idx.into())); + let pred_call = vcx.mk_pred_app(name_p, vcx.alloc_slice(&[self_local])); + predicate_per_variant_predicates = + vcx.mk_tern(cond, pred_call, predicate_per_variant_predicates); - snap_cur = vcx.mk_tern(snap_cond, cons_call, snap_cur) + snap_cur = vcx.mk_tern(cond, cons_call, snap_cur) } - let snap_body = vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { - target: pred_app, - expr: snap_cur, - }))); - - let acc = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldGenData { - field: field_discriminant, - recv: self_local, - }))); + let predicate = { + let acc = vcx.alloc(vir::ExprData::AccField(vcx.alloc(vir::AccFieldGenData { + field: field_discriminant, + recv: self_local, + }))); + + // TODO: handle the empty enum? i guess the lower and upper bound together for an empty enum are false which is correct? + let disc_lower_bound = vcx.mk_bin_op( + vir::BinOpKind::CmpGe, + discr_field_access, + vcx.mk_const(0usize.into()), + ); - // TODO: handle the empty enum? i guess the lower and upper bound together for an empty enum are false which is correct? - let disc_lower_bound = vcx.mk_bin_op( - vir::BinOpKind::CmpGe, - discr_field_access, - vcx.mk_const(0usize.into()), - ); + let disc_upper_bound = vcx.mk_bin_op( + vir::BinOpKind::CmpLt, + discr_field_access, + vcx.mk_const(adt.variants().len().into()), + ); - let disc_upper_bound = vcx.mk_bin_op( - vir::BinOpKind::CmpLt, - discr_field_access, - vcx.mk_const(adt.variants().len().into()), - ); + vcx.alloc(vir::PredicateData { + name: name_p, + args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), + expr: Some(vcx.mk_conj(&[ + acc, + disc_lower_bound, + disc_upper_bound, + predicate_per_variant_predicates, + ])), + }) + }; - let predicate = vcx.alloc(vir::PredicateData { - name: name_p, - args: vcx.alloc_slice(&[vcx.mk_local_decl("self_p", &vir::TypeData::Ref)]), - expr: Some(vcx.mk_conj(&[acc, disc_lower_bound, disc_upper_bound, field_t])), + let pred_app = vcx.alloc(vir::PredicateAppData { + target: name_p, + args: vcx.alloc_slice(&[self_local]), }); let function_snap = vcx.alloc(vir::FunctionData { @@ -624,9 +624,42 @@ fn mk_enum<'vir>( ret: ty_s, pres: vcx.alloc_slice(&[vcx.alloc(vir::ExprData::PredicateApp(pred_app))]), posts: &[], - expr: Some(snap_body), + expr: Some( + vcx.alloc(vir::ExprData::Unfolding(vcx.alloc(vir::UnfoldingData { + target: pred_app, + expr: snap_cur, + }))), + ), }); + // discriminant bounds axiom + { + let self_local = vcx.mk_local_ex("self"); + let discr_func_call = vcx.mk_func_app(s_discr_func_name, vcx.alloc_slice(&[self_local])); + let body1 = vcx.mk_bin_op( + vir::BinOpKind::CmpGe, + discr_func_call, + vcx.mk_const(0usize.into()), + ); + + let body2 = vcx.mk_bin_op( + vir::BinOpKind::CmpLt, + discr_func_call, + vcx.mk_const(adt.variants().len().into()), + ); + + let body = vcx.mk_and(body1, body2); + + axioms.push(vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_discriminant_bounds"), + expr: vcx.mk_forall( + vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + &[], + body, + ), + })); + } + Ok(TypeEncoderOutput { fields: vcx.alloc_slice(&[vcx.alloc(vir::FieldData { ty: &TypeData::Int, @@ -684,8 +717,6 @@ fn mk_enum_variant<'vir>( field_projection_p, ); - read_write_axioms(vcx, ty_s, name_s, &fields, axioms); - let cons_name = vir::vir_format!(vcx, "{name_s}_cons"); funcs.push( vcx.alloc(vir::DomainFunctionData { @@ -701,14 +732,6 @@ fn mk_enum_variant<'vir>( }), ); - // let const_cond = vcx.mk_eq( - // vcx.mk_func_app( - // s_discr_func_name, - // vcx.alloc_slice(&[vcx.mk_local_ex("self")]), - // ), - // vcx.mk_const(variant_idx.into()), - // ); - cons_read_axioms(name_s, vcx, &fields, cons_name, axioms); if !fields.is_empty() { @@ -728,11 +751,11 @@ fn mk_enum_variant<'vir>( body } else { // only apply the forall if there are fields - vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: cons_qvars.clone(), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + vcx.mk_forall( + cons_qvars.clone(), + vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), body, - }))) + ) }; axioms.push(vcx.alloc(vir::DomainAxiomData { @@ -742,6 +765,8 @@ fn mk_enum_variant<'vir>( } //TODO: discriminant for write axioms + read_write_axioms(vcx, ty_s, name_s, &fields, axioms); + predicates.push(mk_struct_predicate(&fields, vcx, name_p)); mk_struct_snap_parts(vcx, name_p, &fields, cons_name) @@ -945,62 +970,39 @@ fn read_write_axioms<'vir>( ) { for (write_idx, write_ty_out) in fields.iter().enumerate() { for (read_idx, _read_ty_out) in fields.iter().enumerate() { + let qvars = vcx.alloc_slice(&[ + vcx.mk_local_decl("self", ty_s), + vcx.mk_local_decl("val", write_ty_out.snapshot), + ]); + + let triggers = vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + )])]); + + let lhs = vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + )], + ); + + let rhs = if read_idx == write_idx { + vcx.mk_local_ex("val") + } else { + vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[vcx.mk_local_ex("self")], + ) + }; + axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_write_{write_idx}_read_{read_idx}"), - expr: if read_idx == write_idx { - vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - vcx.mk_local_decl("val", write_ty_out.snapshot), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - )])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - ), - rhs: vcx.mk_local_ex("val"), - }))), - }))) - } else { - vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[ - vcx.mk_local_decl("self", ty_s), - vcx.mk_local_decl("val", write_ty_out.snapshot), - ]), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - )])]), - body: vcx.alloc(vir::ExprData::BinOp(vcx.alloc(vir::BinOpData { - kind: vir::BinOpKind::CmpEq, - lhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - ), - rhs: vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_local_ex("self")], - ), - }))), - }))) - }, + expr: vcx.mk_forall(qvars, triggers, vcx.mk_eq(lhs, rhs)), })); } } @@ -1033,11 +1035,11 @@ fn cons_axiom<'vir>( vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), + expr: vcx.mk_forall( + vcx.alloc_slice(&[vcx.mk_local_decl("self", ty_s)]), triggers, body, - }))), + ), }) } @@ -1080,19 +1082,21 @@ fn cons_read_axioms<'vir>( let (cons_qvars, cons_args, cons_call) = cons_read_parts(vcx, fields, cons_name); for (read_idx, _) in fields.iter().enumerate() { + let forall_body = vcx.mk_eq( + vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), + &[cons_call], + ), + cons_args[read_idx], + ); + axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_cons_read_{read_idx}"), - expr: vcx.alloc(vir::ExprData::Forall(vcx.alloc(vir::ForallData { - qvars: cons_qvars.clone(), - triggers: vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), - body: vcx.mk_eq( - vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[cons_call], - ), - cons_args[read_idx], - ), - }))), + expr: vcx.mk_forall( + cons_qvars.clone(), + vcx.alloc_slice(&[vcx.alloc_slice(&[cons_call])]), + forall_body, + ), })); } } @@ -1173,8 +1177,6 @@ fn mk_structlike<'vir>( &mut field_projection_p, ); - let field_projection_p = vcx.alloc_slice(&field_projection_p); - read_write_axioms(vcx, ty_s, name_s, &field_ty_out, &mut axioms); cons_read_axioms(name_s, vcx, &field_ty_out, cons_name, &mut axioms); @@ -1183,20 +1185,17 @@ fn mk_structlike<'vir>( axioms.push(cons_axiom(name_s, vcx, cons_name, &field_ty_out, ty_s)); } - // predicate - let predicate = mk_struct_predicate(&field_ty_out, vcx, name_p); - Ok(TypeEncoderOutput { fields: &[], snapshot: vir::vir_domain! { vcx; domain [name_s] { with_funcs [funcs]; with_axioms [axioms]; } }, - predicate, + predicate: mk_struct_predicate(&field_ty_out, vcx, name_p), function_unreachable: mk_unreachable(vcx, name_s, ty_s), function_snap: mk_struct_snap(vcx, name_p, &field_ty_out, ty_s, cons_name), //method_refold: mk_refold(vcx, name_p, ty_s), - field_projection_p, + field_projection_p: vcx.alloc_slice(&field_projection_p), method_assign: mk_assign(vcx, name_p, ty_s), other_predicates: &[], }) diff --git a/vir/src/context.rs b/vir/src/context.rs index fc3f9a4395d..9a5f07976f1 100644 --- a/vir/src/context.rs +++ b/vir/src/context.rs @@ -158,6 +158,19 @@ impl<'tcx> VirCtxt<'tcx> { self.mk_bin_op(BinOpKind::CmpEq, lhs, rhs) } + pub fn mk_forall( + &'tcx self, + qvars: &'tcx [&'tcx LocalDeclData<'tcx>], + triggers: &'tcx [&'tcx [ExprGen<'tcx, Curr, Next>]], + body: ExprGen<'tcx, Curr, Next>, + ) -> ExprGen<'tcx, Curr, Next> { + self.alloc(ExprGenData::Forall(self.alloc(ForallGenData { + qvars, + triggers, + body, + }))) + } + pub fn mk_const(&'tcx self, cnst: ConstData) -> ExprGen<'tcx, Curr, Next> { self.alloc(ExprGenData::Const(self.alloc(cnst))) } From a0c636628a31bf897ace5a06bab4809a235e1533 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Tue, 7 Nov 2023 11:38:58 +0100 Subject: [PATCH 14/18] WIP --- prusti-encoder/src/encoders/mir_pure.rs | 56 +++++++++- prusti-encoder/src/encoders/typ.rs | 140 +++++++++++++----------- prusti-encoder/src/lib.rs | 1 - 3 files changed, 129 insertions(+), 68 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index f26926d107d..28651625089 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -518,14 +518,18 @@ where // then walk terminator let term = self.body[start].terminator.as_ref().unwrap(); match &term.kind { - mir::TerminatorKind::Goto { target } => { + mir::TerminatorKind::FalseEdge { + real_target: target, + .. + } + | mir::TerminatorKind::Goto { target } => { if *target == end { // We are done with the current fragment of the CFG, the // rest is handled in a parent call. return stmt_update; } - todo!() + todo!("Goto for target != end") } mir::TerminatorKind::SwitchInt { discr, targets } => { @@ -983,6 +987,36 @@ where .collect::>(), ) } + mir::AggregateKind::Adt(def_id, variant_idx, args, _, union_field) => { + assert!(union_field.is_none(), "unions not yet implemented"); + let ty = self.vcx.tcx.type_of(def_id).skip_binder(); + + let adt_typ = self + .deps + .require_ref::(ty) + .unwrap(); + + let snapshot_name = match adt_typ.specifics { + crate::encoders::typ::TypeEncoderOutputRefSub::Enum(e) => { + let variant = &e.variants[variant_idx.as_usize()]; + variant.snapshot_name + } + crate::encoders::typ::TypeEncoderOutputRefSub::StructLike(_) => { + assert_eq!(variant_idx.as_u32(), 0); + adt_typ.snapshot_name + } + crate::encoders::typ::TypeEncoderOutputRefSub::Primitive => todo!(), + }; + + let cons = vir::vir_format!(self.vcx, "{snapshot_name}_cons"); // TODO: get better + + let args = fields + .iter() + .map(|field| self.encode_operand(curr_ver, field)) + .collect::>(); + + self.vcx.mk_func_app(cons, self.vcx.alloc_slice(&args)) + } _ => todo!("Unsupported Rvalue::AggregateKind: {kind:?}"), }, mir::Rvalue::CheckedBinaryOp(binop, box (l, r)) => { @@ -1004,6 +1038,21 @@ where ], ) } + mir::Rvalue::Discriminant(place) => { + let discriminent_func = self + .deps + .require_ref::( + place.ty(&self.body.local_decls, self.vcx.tcx).ty, + ) + .unwrap() + .expect_enum() + .func_discriminant; + + self.vcx.mk_func_app( + discriminent_func, + self.vcx.alloc_slice(&[self.encode_place(curr_ver, place)]), + ) + } // ShallowInitBox // CopyForDeref k => { @@ -1056,6 +1105,9 @@ where todo!("unsupported constant literal type: {unsupported_ty:?}") } }, + mir::ConstantKind::Unevaluated(a, b) => { + todo!() + } unsupported_literal => { todo!("unsupported constant literal: {unsupported_literal:?}") } diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index 1273ac9e190..e47b862e79e 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -21,6 +21,7 @@ pub struct TypeEncoderOutputRefSubStruct<'vir> { #[derive(Clone, Debug)] pub struct TypeEncoderOutputRefSubEnum<'vir> { pub field_discriminant: &'vir str, + pub func_discriminant: &'vir str, pub variants: &'vir [TypeEncoderOutputRef<'vir>], } @@ -404,6 +405,9 @@ impl TaskEncoder for TypeEncoder { (), )) } + TyKind::Adt(adt_def, substs) if adt_def.is_enum() => { + Ok((mk_enum(vcx, deps, adt_def, task_key)?, ())) + } TyKind::Never => { let ty_s = vcx.alloc(vir::TypeData::Domain("s_Never")); @@ -441,14 +445,6 @@ impl TaskEncoder for TypeEncoder { )) } - TyKind::Adt(adt_def, substs) if adt_def.is_enum() => { - tracing::error!("encoding enum {adt_def:#?} with substs {substs:?}"); - tracing::warn!("{:?}", adt_def.all_fields().collect::>()); - tracing::warn!("{:#?}", adt_def.variants()); - - Ok((mk_enum(vcx, deps, adt_def, task_key)?, ())) - } - //_ => Err((TypeEncoderError::UnsupportedType, None)), unsupported_type => todo!("type not supported: {unsupported_type:?}"), }) @@ -476,23 +472,15 @@ fn mk_enum<'vir>( let ty_s = vcx.alloc(vir::TypeData::Domain(name_s)); + let s_discr_func_name = vir::vir_format!(vcx, "{name_s}_discriminant"); + let mut variants: Vec> = Vec::new(); for (idx, variant) in adt.variants().iter().enumerate() { let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); - let mut field_read_names = Vec::new(); - let mut field_write_names = Vec::new(); - let mut field_projection_p_names = Vec::new(); - for idx in 0..variant.fields.len() { - field_read_names.push(vir::vir_format!(vcx, "{name_s}_read_{idx}")); - field_write_names.push(vir::vir_format!(vcx, "{name_s}_write_{idx}")); - field_projection_p_names.push(vir::vir_format!(vcx, "{name_p}_field_{idx}")); - } - let field_read_names = vcx.alloc_slice(&field_read_names); - let field_write_names = vcx.alloc_slice(&field_write_names); - let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); + let ref_sub_struct = mk_output_ref_sub_struct(name_p, name_s, variant.fields.len(), vcx); let x = TypeEncoderOutputRef { snapshot_name: name_s, @@ -503,11 +491,7 @@ fn mk_enum<'vir>( function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), function_snap: vir::vir_format!(vcx, "{name_p}_snap"), //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { - field_read: field_read_names, - field_write: field_write_names, - field_projection_p: field_projection_p_names, - }), + specifics: TypeEncoderOutputRefSub::StructLike(ref_sub_struct), method_assign: vir::vir_format!(vcx, "assign_{name_p}"), }; @@ -527,6 +511,7 @@ fn mk_enum<'vir>( //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), specifics: TypeEncoderOutputRefSub::Enum(TypeEncoderOutputRefSubEnum { field_discriminant, + func_discriminant: s_discr_func_name, variants: vcx.alloc(variants), }), method_assign: vir::vir_format!(vcx, "assign_{name_p}"), @@ -538,8 +523,6 @@ fn mk_enum<'vir>( let mut field_projection_p = Vec::new(); let mut other_predicates = Vec::new(); - let s_discr_func_name = vir::vir_format!(vcx, "{name_s}_discriminant"); - funcs.push(vcx.alloc(vir::DomainFunctionData { unique: false, name: s_discr_func_name, @@ -689,7 +672,7 @@ fn mk_enum_variant<'vir>( name_p: &'vir str, s_discr_func_name: &'vir str, funcs: &mut Vec<&'vir vir::DomainFunctionData<'vir>>, - field_projection_p: &mut Vec<&'vir FunctionGenData<'vir, !, !>>, + field_projection_p: &mut Vec<&'vir vir::FunctionData<'vir>>, axioms: &mut Vec>, predicates: &mut Vec<&'vir vir::PredicateData<'vir>>, ) -> (&'vir vir::PredicateAppData<'vir>, &'vir ExprData<'vir>) { @@ -703,10 +686,6 @@ fn mk_enum_variant<'vir>( }) .collect::>(); - // let did_name = vcx.tcx.item_name(variant.def_id).to_ident_string(); - // let name_s = vir::vir_format!(vcx, "{parent_name_s}_variant_{did_name}"); - // let name_p = vir::vir_format!(vcx, "{parent_name_p}_variant_{did_name}"); - mk_field_projection_p( &fields, vcx, @@ -738,7 +717,7 @@ fn mk_enum_variant<'vir>( axioms.push(cons_axiom(name_s, vcx, cons_name, &fields, ty_s)); } - // discriminant of constructor + // discriminant of constructor matches the variant idx { let (cons_qvars, cons_args, cons_call) = cons_read_parts(vcx, &fields, cons_name); @@ -764,9 +743,32 @@ fn mk_enum_variant<'vir>( })); } - //TODO: discriminant for write axioms read_write_axioms(vcx, ty_s, name_s, &fields, axioms); + // discriminant of write call stays the same + for (write_idx, write_ty_out) in fields.iter().enumerate() { + let qvars = vcx.alloc_slice(&[ + vcx.mk_local_decl("self", ty_s), + vcx.mk_local_decl("val", write_ty_out.snapshot), + ]); + + let write_call = vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + ); + + let discriminant_of_write = vcx.mk_func_app(s_discr_func_name, &[write_call]); + + axioms.push(vcx.alloc(vir::DomainAxiomData { + name: vir::vir_format!(vcx, "ax_{name_s}_discriminant_write_{write_idx}"), + expr: vcx.mk_forall( + qvars, + vcx.alloc_slice(&[vcx.alloc_slice(&[discriminant_of_write])]), + vcx.mk_eq(discriminant_of_write, vcx.mk_const(variant_idx.into())), + ), + })); + } + predicates.push(mk_struct_predicate(&fields, vcx, name_p)); mk_struct_snap_parts(vcx, name_p, &fields, cons_name) @@ -975,20 +977,14 @@ fn read_write_axioms<'vir>( vcx.mk_local_decl("val", write_ty_out.snapshot), ]); - let triggers = vcx.alloc_slice(&[vcx.alloc_slice(&[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], - )])]); + let write_call = vcx.mk_func_app( + vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), + &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], + ); - let lhs = vcx.mk_func_app( + let read_of_write = vcx.mk_func_app( vir::vir_format!(vcx, "{name_s}_read_{read_idx}"), - &[vcx.mk_func_app( - vir::vir_format!(vcx, "{name_s}_write_{write_idx}"), - &[vcx.mk_local_ex("self"), vcx.mk_local_ex("val")], - )], + &[write_call], ); let rhs = if read_idx == write_idx { @@ -1002,7 +998,11 @@ fn read_write_axioms<'vir>( axioms.push(vcx.alloc(vir::DomainAxiomData { name: vir::vir_format!(vcx, "ax_{name_s}_write_{write_idx}_read_{read_idx}"), - expr: vcx.mk_forall(qvars, triggers, vcx.mk_eq(lhs, rhs)), + expr: vcx.mk_forall( + qvars, + vcx.alloc_slice(&[vcx.alloc_slice(&[read_of_write])]), + vcx.mk_eq(read_of_write, rhs), + ), })); } } @@ -1077,7 +1077,7 @@ fn cons_read_axioms<'vir>( vcx: &'vir vir::VirCtxt<'vir>, fields: &[TypeEncoderOutputRef<'vir>], cons_name: &'vir str, - axioms: &mut Vec<&vir::DomainAxiomGenData<'vir, !, !>>, + axioms: &mut Vec<&vir::DomainAxiomData<'vir>>, ) { let (cons_qvars, cons_args, cons_call) = cons_read_parts(vcx, fields, cons_name); @@ -1115,17 +1115,7 @@ fn mk_structlike<'vir>( Option<::OutputFullDependency<'vir>>, ), > { - let mut field_read_names = Vec::new(); - let mut field_write_names = Vec::new(); - let mut field_projection_p_names = Vec::new(); - for idx in 0..field_ty_out.len() { - field_read_names.push(vir::vir_format!(vcx, "{name_s}_read_{idx}")); - field_write_names.push(vir::vir_format!(vcx, "{name_s}_write_{idx}")); - field_projection_p_names.push(vir::vir_format!(vcx, "{name_p}_field_{idx}")); - } - let field_read_names = vcx.alloc_slice(&field_read_names); - let field_write_names = vcx.alloc_slice(&field_write_names); - let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); + let ref_sub_struct = mk_output_ref_sub_struct(name_p, name_s, field_ty_out.len(), vcx); deps.emit_output_ref::( *task_key, @@ -1138,11 +1128,7 @@ fn mk_structlike<'vir>( function_unreachable: vir::vir_format!(vcx, "{name_s}_unreachable"), function_snap: vir::vir_format!(vcx, "{name_p}_snap"), //method_refold: vir::vir_format!(vcx, "refold_{name_p}"), - specifics: TypeEncoderOutputRefSub::StructLike(TypeEncoderOutputRefSubStruct { - field_read: field_read_names, - field_write: field_write_names, - field_projection_p: field_projection_p_names, - }), + specifics: TypeEncoderOutputRefSub::StructLike(ref_sub_struct), method_assign: vir::vir_format!(vcx, "assign_{name_p}"), }, ); @@ -1201,11 +1187,35 @@ fn mk_structlike<'vir>( }) } +fn mk_output_ref_sub_struct<'vir>( + name_p: &'vir str, + name_s: &'vir str, + field_count: usize, + vcx: &'vir vir::VirCtxt<'vir>, +) -> TypeEncoderOutputRefSubStruct<'vir> { + let mut field_read_names = Vec::new(); + let mut field_write_names = Vec::new(); + let mut field_projection_p_names = Vec::new(); + for idx in 0..field_count { + field_read_names.push(vir::vir_format!(vcx, "{name_s}_read_{idx}")); + field_write_names.push(vir::vir_format!(vcx, "{name_s}_write_{idx}")); + field_projection_p_names.push(vir::vir_format!(vcx, "{name_p}_field_{idx}")); + } + let field_read_names = vcx.alloc_slice(&field_read_names); + let field_write_names = vcx.alloc_slice(&field_write_names); + let field_projection_p_names = vcx.alloc_slice(&field_projection_p_names); + TypeEncoderOutputRefSubStruct { + field_read: field_read_names, + field_write: field_write_names, + field_projection_p: field_projection_p_names, + } +} + fn mk_struct_predicate<'vir>( fields: &Vec>, vcx: &'vir vir::VirCtxt<'vir>, name_p: &'vir str, -) -> &'vir vir::PredicateGenData<'vir, !, !> { +) -> &'vir vir::PredicateData<'vir> { let predicate = { let expr = fields .iter() @@ -1239,7 +1249,7 @@ fn mk_field_projection_p<'vir>( name_s: &'vir str, name_p: &'vir str, funcs: &mut Vec<&vir::DomainFunctionData<'vir>>, - field_projection_p: &mut Vec<&FunctionGenData<'vir, !, !>>, + field_projection_p: &mut Vec<&vir::FunctionData<'vir>>, ) { for (idx, ty_out) in fields.iter().enumerate() { let name_r = vir::vir_format!(vcx, "{name_s}_read_{idx}"); diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index e1b875f0da4..74d30ecb233 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -2,7 +2,6 @@ #![feature(associated_type_defaults)] #![feature(box_patterns)] #![feature(local_key_cell_methods)] -#![feature(never_type)] extern crate rustc_middle; extern crate rustc_serialize; From fda1064e98c957b81a4fd7753be45e4886b687fe Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Tue, 7 Nov 2023 17:35:20 +0100 Subject: [PATCH 15/18] WIP --- prusti-encoder/src/encoders/mir_pure.rs | 45 +++++++++++++++++++----- prusti-interface/src/environment/body.rs | 4 +++ task-encoder/src/lib.rs | 2 +- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 28651625089..2350bba5091 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -7,6 +7,7 @@ use prusti_rustc_interface::{ }; use std::collections::HashMap; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use vir::Reify; pub struct MirPureEncoder; @@ -211,7 +212,7 @@ impl TaskEncoder for MirPureEncoder { // TODO task.encoding_depth, task.parent_def_id, - None, + task.promoted, task.substs, ) } @@ -234,13 +235,22 @@ impl TaskEncoder for MirPureEncoder { let def_id = task_key.1; //.parent_def_id; let local_def_id = def_id.expect_local(); - tracing::debug!("encoding {def_id:?}"); + tracing::debug!("encoding {task_key:?}"); let expr = vir::with_vcx(move |vcx| { //let body = vcx.tcx.mir_promoted(local_def_id).0.borrow(); - let body = vcx - .body - .borrow_mut() - .get_impure_fn_body_identity(local_def_id); + + let body = if let Some(promoted) = task_key.2 { + // TODO: is this optimal? + prusti_interface::environment::body::MirBody::new( + vcx.tcx.promoted_mir(def_id)[promoted].clone(), + ) + } else { + vcx.body + .borrow_mut() + .get_impure_fn_body_identity(local_def_id) + }; + + tracing::warn!("body {:?}", body.body()); let expr_inner = Encoder::new(vcx, task_key.0, &body, deps).encode_body(); @@ -1105,8 +1115,27 @@ where todo!("unsupported constant literal type: {unsupported_ty:?}") } }, - mir::ConstantKind::Unevaluated(a, b) => { - todo!() + e @ mir::ConstantKind::Unevaluated(uneval, b) => { + let expr = self + .deps + .require_local::(MirPureEncoderTask { + encoding_depth: 0, + parent_def_id: uneval.def, + promoted: Some(uneval.promoted.unwrap()), + param_env: self.vcx.tcx.param_env(uneval.def), + substs: ty::List::identity_for_item(self.vcx.tcx, uneval.def), + }) + .unwrap() + .expr; + + tracing::warn!("{e:?} became {expr:?}"); + + self.vcx.alloc(vir::ExprGenData::Lazy( + vir::vir_format!(self.vcx, "unevaluated const {:?}", uneval.def), + Box::new(move |vcx, lctx: ExprInput<'_>| { + expr.reify(vcx, (lctx.0, &[])) //TODO: no idea why i do this so probably wrong + }), + )) } unsupported_literal => { todo!("unsupported constant literal: {unsupported_literal:?}") diff --git a/prusti-interface/src/environment/body.rs b/prusti-interface/src/environment/body.rs index 84873731aae..21509672917 100644 --- a/prusti-interface/src/environment/body.rs +++ b/prusti-interface/src/environment/body.rs @@ -19,6 +19,10 @@ impl<'tcx> MirBody<'tcx> { pub fn body(&self) -> Rc> { self.0.clone() } + + pub fn new(body: mir::Body<'tcx>) -> Self { + MirBody(Rc::new(body)) + } } impl<'tcx> std::ops::Deref for MirBody<'tcx> { type Target = mir::Body<'tcx>; diff --git a/task-encoder/src/lib.rs b/task-encoder/src/lib.rs index fa36c724f9d..eb018dfaf33 100644 --- a/task-encoder/src/lib.rs +++ b/task-encoder/src/lib.rs @@ -296,7 +296,7 @@ pub trait TaskEncoder { output_dep.clone(), ))), TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => { - panic!("Encoding already started or enqueued") + panic!("Encoding already started or enqueued for {task_key:?}") } }, None => { From b80dececc5ba86199af35c528c186b2b2c8e6b77 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Wed, 8 Nov 2023 11:19:00 +0100 Subject: [PATCH 16/18] start work on terminator fpcs --- prusti-encoder/src/encoders/mir_impure.rs | 32 +++++++++++++++++++++-- prusti-encoder/src/encoders/mir_pure.rs | 2 -- vir/src/debug.rs | 12 +++++---- vir/src/gendata.rs | 6 ++++- vir/src/reify.rs | 11 ++++++-- 5 files changed, 51 insertions(+), 12 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 4df23ce1857..38fa83250a5 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -10,7 +10,7 @@ use prusti_rustc_interface::{ // SsaAnalysis, //}; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; -use vir::PredicateAppGenData; +use vir::{ExprData, ExprGenData, PredicateAppGenData}; pub struct MirImpureEncoder; @@ -552,6 +552,12 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'vir>) { self.current_fpcs = Some(self.fpcs_analysis.get_all_for_bb(block)); + + tracing::warn!( + "fpcs for {block:?} is {:#?}", + self.current_fpcs.as_ref().map(|e| &e.terminator) + ); + self.current_stmts = Some(Vec::with_capacity( data.statements.len(), // TODO: not exact? )); @@ -862,6 +868,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { mir::TerminatorKind::SwitchInt { discr, targets } => { //let discr_version = self.ssa_analysis.version.get(&(location, discr.local)).unwrap(); //let discr_name = vir::vir_format!(self.vcx, "_{}s_{}", discr.local.index(), discr_version); + let ty_out = self .deps .require_ref::( @@ -869,15 +876,35 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { ) .unwrap(); + tracing::warn!( + "terminator fpcs {:?}", + self.current_fpcs.as_ref().unwrap().terminator + ); + let goto_targets = self.vcx.alloc_slice( &targets .iter() - .map(|(value, target)| { + .enumerate() + .map(|(idx, (value, target))| { + let terminator_fpcs = + &self.current_fpcs.as_ref().unwrap().terminator.succs[idx]; + + let extra_exprs = terminator_fpcs + .repacks + .iter() + .map(|repack| { + self.vcx.alloc(ExprGenData::Todo( + self.vcx.alloc_str(&format!("{repack:?}")), + )) + }) + .collect::>(); + ( ty_out.expr_from_u128(value), //self.vcx.alloc(vir::ExprData::Todo(vir::vir_format!(self.vcx, "constant({value})"))), self.vcx .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + self.vcx.alloc_slice(&extra_exprs), ) }) .collect::>(), @@ -1013,6 +1040,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { targets: self.vcx.alloc_slice(&[( self.vcx.mk_const(vir::ConstData::Bool(*expected)), &target_bb, + &[], )]), otherwise: self .vcx diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 2350bba5091..6c10428905a 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -250,8 +250,6 @@ impl TaskEncoder for MirPureEncoder { .get_impure_fn_body_identity(local_def_id) }; - tracing::warn!("body {:?}", body.body()); - let expr_inner = Encoder::new(vcx, task_key.0, &body, deps).encode_body(); // We wrap the expression with an additional lazy that will perform diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 20b94a59a19..1dd54b0300f 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -325,11 +325,13 @@ impl<'vir, Curr, Next> Debug for TerminatorStmtGenData<'vir, Curr, Next> { write!(f, "goto {:?}", data.otherwise) } else { for target in data.targets { - write!( - f, - "if ({:?} == {:?}) {{ goto {:?} }}\n else", - data.value, target.0, target.1 - )?; + write!(f, "if ({:?} == {:?}) {{", data.value, target.0)?; + + for extra in target.2 { + write!(f, "{extra:?}")?; + } + + write!(f, " goto {:?} }}\n else", target.1)?; } write!(f, " {{ goto {:?} }}", data.otherwise) } diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index e66fe87b499..c66648b3a8f 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -202,7 +202,11 @@ pub enum StmtGenData<'vir, Curr, Next> { #[derive(Reify)] pub struct GotoIfGenData<'vir, Curr, Next> { pub value: ExprGen<'vir, Curr, Next>, - pub targets: &'vir [(ExprGen<'vir, Curr, Next>, CfgBlockLabel<'vir>)], + pub targets: &'vir [( + ExprGen<'vir, Curr, Next>, + CfgBlockLabel<'vir>, + &'vir [ExprGen<'vir, Curr, Next>], + )], #[reify_copy] pub otherwise: CfgBlockLabel<'vir>, } diff --git a/vir/src/reify.rs b/vir/src/reify.rs index 6e40b4d69ae..f5267310bc3 100644 --- a/vir/src/reify.rs +++ b/vir/src/reify.rs @@ -73,14 +73,21 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> for [( ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, CfgBlockLabel<'vir>, + &'vir [ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>], )] { - type Next = &'vir [(ExprGen<'vir, NextA, NextB>, CfgBlockLabel<'vir>)]; + type Next = &'vir [( + ExprGen<'vir, NextA, NextB>, + CfgBlockLabel<'vir>, + &'vir [ExprGen<'vir, NextA, NextB>], + )]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { vcx.alloc_slice( &self .iter() - .map(|(elem, label)| (elem.reify(vcx, lctx), *label)) + .map(|(elem, label, extra_exprs)| { + (elem.reify(vcx, lctx), *label, extra_exprs.reify(vcx, lctx)) + }) .collect::>(), ) } From 275cf98b0a051f042b7b87e32312cec7e0979d5f Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Fri, 10 Nov 2023 11:21:14 +0100 Subject: [PATCH 17/18] Try terminator repacks --- prusti-encoder/src/encoders/mir_impure.rs | 113 +++++++++++++++++----- prusti-encoder/src/encoders/typ.rs | 8 +- vir/src/debug.rs | 8 +- vir/src/gendata.rs | 3 +- vir/src/reify.rs | 6 +- 5 files changed, 103 insertions(+), 35 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 38fa83250a5..914100a1319 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -10,7 +10,7 @@ use prusti_rustc_interface::{ // SsaAnalysis, //}; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; -use vir::{ExprData, ExprGenData, PredicateAppGenData}; +use vir::{ExprData, ExprGenData, PredicateAppGenData, StmtGenData}; pub struct MirImpureEncoder; @@ -337,11 +337,38 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { } */ - fn fpcs_location(&mut self, location: mir::Location) { - let repacks = self.current_fpcs.as_ref().unwrap().statements[location.statement_index] - .repacks - .clone(); + /// Do the same as [self::fpcs_repacks] but instead of adding the statements to [self.current_stmts] return them instead + fn collect_repacks(&mut self, repacks: Vec>) -> Vec<&'vir vir::StmtData<'vir>> { + let mut real_stmts = Some(Vec::new()); + std::mem::swap(&mut self.current_stmts, &mut real_stmts); + + // FIXME: debug delete + let has_elements = !repacks.is_empty(); + if has_elements { + self.stmt(StmtGenData::Comment( + self.vcx.alloc_str(&format!("collect_repacks start")), + )); + } + + self.fpcs_repacks(repacks); + + // FIXME: debug delete + if has_elements { + self.stmt(StmtGenData::Comment( + self.vcx.alloc_str(&format!("collect_repacks end")), + )); + } + + std::mem::swap(&mut self.current_stmts, &mut real_stmts); + + real_stmts.unwrap() + } + + fn fpcs_repacks(&mut self, repacks: Vec>) { for repack_op in repacks { + self.stmt(StmtGenData::Comment( + self.vcx.alloc_str(&format!("repack {repack_op:?}")), + )); match repack_op { RepackOp::Expand(place, _target, capability_kind) | RepackOp::Collapse(place, _target, capability_kind) => { @@ -400,11 +427,19 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { })), ))); } - unsupported_op => panic!("unsupported repack op: {unsupported_op:?}"), + unsupported_op => tracing::error!("unsupported repack op: {unsupported_op:?}"), } } } + fn fpcs_location(&mut self, location: mir::Location) { + let repacks = self.current_fpcs.as_ref().unwrap().statements[location.statement_index] + .repacks + .clone(); + + self.fpcs_repacks(repacks) + } + fn encode_operand_snap(&mut self, operand: &mir::Operand<'vir>) -> vir::Expr<'vir> { match operand { &mir::Operand::Move(source) => { @@ -552,11 +587,10 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'vir>) { self.current_fpcs = Some(self.fpcs_analysis.get_all_for_bb(block)); - - tracing::warn!( - "fpcs for {block:?} is {:#?}", - self.current_fpcs.as_ref().map(|e| &e.terminator) - ); + // tracing::warn!( + // "fpcs for {block:?} is {:#?}", + // self.current_fpcs.as_ref().map(|e| &e.terminator) + // ); self.current_stmts = Some(Vec::with_capacity( data.statements.len(), // TODO: not exact? @@ -758,7 +792,8 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { let cons_name = match kind { box mir::AggregateKind::Adt(_,vidx,_, _, _) if dest_ty_out.is_enum() => { - vir::vir_format!(self.vcx, "{}_{vidx:?}_cons", dest_ty_out.snapshot_name) + let v_ty = &dest_ty_out.expect_enum().variants[vidx.as_usize()]; + vir::vir_format!(self.vcx, "{}_cons", v_ty.snapshot_name) //TODO get differently } _ => vir::vir_format!(self.vcx, "{}_cons", dest_ty_out.snapshot_name) }; @@ -861,10 +896,30 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { | mir::TerminatorKind::FalseEdge { real_target: target, .. - } => self.vcx.alloc(vir::TerminatorStmtData::Goto( - self.vcx - .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), - )), + } => { + let len = { + let succs = &self.current_fpcs.as_ref().unwrap().terminator.succs; + assert!(succs.len() <= 2); + let real_succ = &succs[0]; + assert_eq!(&real_succ.location.block, target); + succs.len() + }; + for i in 0..len { + let repacks = self.current_fpcs.as_ref().unwrap().terminator.succs[i] + .repacks + .clone(); + let mut stmts = self.collect_repacks(repacks); + // TODO this is messy + for stmt in stmts.drain(..) { + self.current_stmts.as_mut().unwrap().push(stmt); + } + } + + self.vcx.alloc(vir::TerminatorStmtData::Goto( + self.vcx + .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())), + )) + } mir::TerminatorKind::SwitchInt { discr, targets } => { //let discr_version = self.ssa_analysis.version.get(&(location, discr.local)).unwrap(); //let discr_name = vir::vir_format!(self.vcx, "_{}s_{}", discr.local.index(), discr_version); @@ -877,7 +932,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { .unwrap(); tracing::warn!( - "terminator fpcs {:?}", + "goto terminator fpcs {:?}", self.current_fpcs.as_ref().unwrap().terminator ); @@ -889,15 +944,9 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { let terminator_fpcs = &self.current_fpcs.as_ref().unwrap().terminator.succs[idx]; - let extra_exprs = terminator_fpcs - .repacks - .iter() - .map(|repack| { - self.vcx.alloc(ExprGenData::Todo( - self.vcx.alloc_str(&format!("{repack:?}")), - )) - }) - .collect::>(); + assert_eq!(terminator_fpcs.location.block, target); + + let extra_exprs = self.collect_repacks(terminator_fpcs.repacks.clone()); ( ty_out.expr_from_u128(value), @@ -909,10 +958,20 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { }) .collect::>(), ); + let goto_otherwise = self.vcx.alloc(vir::CfgBlockLabelData::BasicBlock( targets.otherwise().as_usize(), )); + let otherwise_expr = { + let terminator_fpcs = + &self.current_fpcs.as_ref().unwrap().terminator.succs[goto_targets.len()]; + + assert_eq!(terminator_fpcs.location.block, targets.otherwise()); + + self.collect_repacks(terminator_fpcs.repacks.clone()) + }; + let discr_ex = self.encode_operand_snap(discr); self.vcx .alloc(vir::TerminatorStmtData::GotoIf(self.vcx.alloc( @@ -920,6 +979,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { value: discr_ex, // self.vcx.mk_local_ex(discr_name), targets: goto_targets, otherwise: goto_otherwise, + otherwise_extra: self.vcx.alloc_slice(&otherwise_expr), }, ))) } @@ -1045,6 +1105,7 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { otherwise: self .vcx .alloc(vir::CfgBlockLabelData::BasicBlock(otherwise.as_usize())), + otherwise_extra: &[], }), )) } diff --git a/prusti-encoder/src/encoders/typ.rs b/prusti-encoder/src/encoders/typ.rs index e47b862e79e..32e52b88fef 100644 --- a/prusti-encoder/src/encoders/typ.rs +++ b/prusti-encoder/src/encoders/typ.rs @@ -477,8 +477,8 @@ fn mk_enum<'vir>( let mut variants: Vec> = Vec::new(); for (idx, variant) in adt.variants().iter().enumerate() { - let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); - let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); + let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}_{}", variant.name.as_str()); + let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}_{}", variant.name.as_str()); let ref_sub_struct = mk_output_ref_sub_struct(name_p, name_s, variant.fields.len(), vcx); @@ -538,8 +538,8 @@ fn mk_enum<'vir>( let mut snap_cur = vcx.mk_func_app(vir::vir_format!(vcx, "{name_s}_unreachable"), &[]); for (idx, variant) in adt.variants().iter().enumerate() { - let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}"); - let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}"); + let name_s = vir::vir_format!(vcx, "s_Adt_{did_name}_{idx}_{}", variant.name.as_str()); + let name_p = vir::vir_format!(vcx, "p_Adt_{did_name}_{idx}_{}", variant.name.as_str()); let (_, cons_call) = mk_enum_variant( vcx, diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 1dd54b0300f..12710c31c88 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -333,7 +333,13 @@ impl<'vir, Curr, Next> Debug for TerminatorStmtGenData<'vir, Curr, Next> { write!(f, " goto {:?} }}\n else", target.1)?; } - write!(f, " {{ goto {:?} }}", data.otherwise) + write!(f, " {{ ")?; + + for extra in data.otherwise_extra { + write!(f, "{extra:?}")?; + } + + write!(f, "goto {:?} }}", data.otherwise) } } Self::Exit => write!(f, "// return"), diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index c66648b3a8f..5c523230dec 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -205,10 +205,11 @@ pub struct GotoIfGenData<'vir, Curr, Next> { pub targets: &'vir [( ExprGen<'vir, Curr, Next>, CfgBlockLabel<'vir>, - &'vir [ExprGen<'vir, Curr, Next>], + &'vir [StmtGen<'vir, Curr, Next>], )], #[reify_copy] pub otherwise: CfgBlockLabel<'vir>, + pub otherwise_extra: &'vir [StmtGen<'vir, Curr, Next>], } #[derive(Reify)] diff --git a/vir/src/reify.rs b/vir/src/reify.rs index f5267310bc3..8907849cce0 100644 --- a/vir/src/reify.rs +++ b/vir/src/reify.rs @@ -70,16 +70,16 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> } impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr, NextA, NextB> - for [( + for &[( ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>, CfgBlockLabel<'vir>, - &'vir [ExprGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>], + &'vir [StmtGen<'vir, Curr, ExprGen<'vir, NextA, NextB>>], )] { type Next = &'vir [( ExprGen<'vir, NextA, NextB>, CfgBlockLabel<'vir>, - &'vir [ExprGen<'vir, NextA, NextB>], + &'vir [StmtGen<'vir, NextA, NextB>], )]; fn reify(&self, vcx: &'vir VirCtxt<'vir>, lctx: Curr) -> Self::Next { vcx.alloc_slice( From 8136f8de929c4d6a5798c66cdd2c6327219dfcf5 Mon Sep 17 00:00:00 2001 From: Till Arnold Date: Mon, 13 Nov 2023 13:37:59 +0100 Subject: [PATCH 18/18] WIP --- prusti-encoder/src/encoders/mir_impure.rs | 37 +++++++++++++++-------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 914100a1319..4345119cf64 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -375,7 +375,9 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { if matches!(capability_kind, CapabilityKind::Write) { // Collapsing an already exhaled place is a no-op // TODO: unless it's through a Ref I imagine? - assert!(matches!(repack_op, RepackOp::Collapse(..))); + if !matches!(repack_op, RepackOp::Collapse(..)) { + tracing::error!("TODO: Not matching Collapse: {repack_op:?}") + } return; } let place_ty = place.ty(self.local_decls, self.vcx.tcx); @@ -427,7 +429,13 @@ impl<'vir, 'enc> EncoderVisitor<'vir, 'enc> { })), ))); } - unsupported_op => tracing::error!("unsupported repack op: {unsupported_op:?}"), + unsupported_op => { + self.stmt(vir::StmtData::Comment( + self.vcx + .alloc_str(&format!("unsupported repack op: {unsupported_op:?}")), + )); + tracing::error!("unsupported repack op: {unsupported_op:?}"); + } } } } @@ -900,21 +908,24 @@ impl<'vir, 'enc> mir::visit::Visitor<'vir> for EncoderVisitor<'vir, 'enc> { let len = { let succs = &self.current_fpcs.as_ref().unwrap().terminator.succs; assert!(succs.len() <= 2); - let real_succ = &succs[0]; - assert_eq!(&real_succ.location.block, target); succs.len() }; - for i in 0..len { - let repacks = self.current_fpcs.as_ref().unwrap().terminator.succs[i] - .repacks - .clone(); - let mut stmts = self.collect_repacks(repacks); - // TODO this is messy - for stmt in stmts.drain(..) { - self.current_stmts.as_mut().unwrap().push(stmt); - } + + { + let real_succ = &self.current_fpcs.as_ref().unwrap().terminator.succs[0]; + assert_eq!(&real_succ.location.block, target); + self.fpcs_repacks(real_succ.repacks.clone()); } + // TODO: do we really need all succs or just the real_succ one? + // for i in 0..len { + // let repacks = self.current_fpcs.as_ref().unwrap().terminator.succs[i] + // .repacks + // .clone(); + + // self.fpcs_repacks(repacks) + // } + self.vcx.alloc(vir::TerminatorStmtData::Goto( self.vcx .alloc(vir::CfgBlockLabelData::BasicBlock(target.as_usize())),