From c3b2a75ff0edc587c471140af15cf505ffe2a336 Mon Sep 17 00:00:00 2001 From: Aurel Date: Tue, 14 May 2024 17:36:40 +0200 Subject: [PATCH] Support cycles in encoders (#47) * domain fields don't need a full encoding of their type yet * type alias for do_encode_full result * parameterise TaskEncoderDependencies by the owning encoder * remove some dependency unwraps * remove 'tcx lifetime, use 'vir * check for cycles when requesting dependencies or emitting output ref * add try operators for some emit output refs --- .../src/encoder_traits/function_enc.rs | 16 +- .../src/encoder_traits/impure_function_enc.rs | 34 +-- .../src/encoder_traits/pure_func_app_enc.rs | 30 +- .../src/encoder_traits/pure_function_enc.rs | 30 +- prusti-encoder/src/encoders/const.rs | 34 +-- prusti-encoder/src/encoders/generic.rs | 25 +- prusti-encoder/src/encoders/local_def.rs | 29 +- prusti-encoder/src/encoders/mir_builtin.rs | 59 ++-- prusti-encoder/src/encoders/mir_impure.rs | 93 +++--- .../src/encoders/mir_poly_impure.rs | 22 +- prusti-encoder/src/encoders/mir_pure.rs | 79 +++-- .../src/encoders/mir_pure_function.rs | 19 +- .../src/encoders/mono/mir_impure.rs | 22 +- .../src/encoders/mono/mir_pure_function.rs | 19 +- prusti-encoder/src/encoders/pure/spec.rs | 30 +- prusti-encoder/src/encoders/spec.rs | 17 +- prusti-encoder/src/encoders/type/domain.rs | 108 ++++--- .../encoders/type/lifted/aggregate_cast.rs | 21 +- .../src/encoders/type/lifted/cast.rs | 51 ++-- .../encoders/type/lifted/cast_functions.rs | 32 +- .../src/encoders/type/lifted/casters.rs | 44 +-- .../type/lifted/func_app_ty_params.rs | 21 +- .../type/lifted/func_def_ty_params.rs | 21 +- .../src/encoders/type/lifted/generic.rs | 25 +- .../src/encoders/type/lifted/rust_ty_cast.rs | 51 ++-- prusti-encoder/src/encoders/type/lifted/ty.rs | 46 +-- .../encoders/type/lifted/ty_constructor.rs | 25 +- prusti-encoder/src/encoders/type/predicate.rs | 44 ++- .../src/encoders/type/rust_ty_predicates.rs | 30 +- .../src/encoders/type/rust_ty_snapshots.rs | 30 +- prusti-encoder/src/encoders/type/snapshot.rs | 27 +- .../src/encoders/type/viper_tuple.rs | 19 +- prusti-encoder/src/lib.rs | 2 +- task-encoder/src/lib.rs | 277 ++++++++++++------ 34 files changed, 634 insertions(+), 798 deletions(-) diff --git a/prusti-encoder/src/encoder_traits/function_enc.rs b/prusti-encoder/src/encoder_traits/function_enc.rs index 1fedfbf363f..1ae34f70b09 100644 --- a/prusti-encoder/src/encoder_traits/function_enc.rs +++ b/prusti-encoder/src/encoder_traits/function_enc.rs @@ -19,10 +19,10 @@ pub trait FunctionEnc /// this should be the identity substitution obtained from the DefId of the /// function. For the monomorphic encoding, the substitutions at the call /// site should be used. - fn get_substs<'tcx>( - vcx: &vir::VirCtxt<'tcx>, - substs_src: &Self::TaskKey<'tcx>, - ) -> &'tcx GenericArgs<'tcx>; + fn get_substs<'vir>( + vcx: &vir::VirCtxt<'vir>, + substs_src: &Self::TaskKey<'vir>, + ) -> &'vir GenericArgs<'vir>; } /// Implementation for polymorphic encoding @@ -35,10 +35,10 @@ impl TaskEncoder = DefId>> FunctionEnc for None } - fn get_substs<'tcx>( - vcx: &vir::VirCtxt<'tcx>, - def_id: &Self::TaskKey<'tcx>, - ) -> &'tcx GenericArgs<'tcx> { + fn get_substs<'vir>( + vcx: &vir::VirCtxt<'vir>, + def_id: &Self::TaskKey<'vir>, + ) -> &'vir GenericArgs<'vir> { GenericArgs::identity_for_item(vcx.tcx(), *def_id) } diff --git a/prusti-encoder/src/encoder_traits/impure_function_enc.rs b/prusti-encoder/src/encoder_traits/impure_function_enc.rs index 5a355143646..0eff53dce5c 100644 --- a/prusti-encoder/src/encoder_traits/impure_function_enc.rs +++ b/prusti-encoder/src/encoder_traits/impure_function_enc.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::middle::mir; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; use vir::{MethodIdent, UnknownArity, ViperIdent}; use crate::encoders::{ @@ -33,15 +33,18 @@ where { /// Generates the identifier for the method; for a monomorphic encoding, /// this should be a name including (mangled) type arguments - fn mk_method_ident<'vir, 'tcx>( - vcx: &'vir vir::VirCtxt<'tcx>, - task_key: &Self::TaskKey<'tcx>, + fn mk_method_ident<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + task_key: &Self::TaskKey<'vir>, ) -> ViperIdent<'vir>; - fn encode<'vir, 'tcx: 'vir>( - task_key: Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> ImpureFunctionEncOutput<'vir> { + fn encode<'vir>( + task_key: Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> Result< + ImpureFunctionEncOutput<'vir>, + EncodeFullError<'vir, Self>, + > { let def_id = Self::get_def_id(&task_key); let caller_def_id = Self::get_caller_def_id(&task_key); let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| { @@ -52,8 +55,7 @@ where use mir::visit::Visitor; let substs = Self::get_substs(vcx, &task_key); let local_defs = deps - .require_local::((def_id, substs, caller_def_id)) - .unwrap(); + .require_local::((def_id, substs, caller_def_id))?; // Argument count for the Viper method: // - one (`Ref`) for the return place; @@ -70,15 +72,14 @@ where let method_name = Self::mk_method_ident(vcx, &task_key); let mut args = vec![&vir::TypeData::Ref; arg_count]; let param_ty_decls = deps - .require_local::(substs) - .unwrap() + .require_local::(substs)? .iter() .map(|g| g.decl()) .collect::>(); args.extend(param_ty_decls.iter().map(|decl| decl.ty)); let args = UnknownArity::new(vcx.alloc_slice(&args)); let method_ref = MethodIdent::new(method_name, args); - deps.emit_output_ref::(task_key, ImpureFunctionEncOutputRef { method_ref }); + deps.emit_output_ref(task_key, ImpureFunctionEncOutputRef { method_ref })?; // Do not encode the method body if it is external, trusted or just // a call stub. @@ -157,8 +158,7 @@ where }; let spec = deps - .require_local::((def_id, substs, None, false)) - .unwrap(); + .require_local::((def_id, substs, None, false))?; let (spec_pres, spec_posts) = (spec.pres, spec.posts); let mut pres = Vec::with_capacity(arg_count - 1); @@ -177,7 +177,7 @@ where posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred); posts.extend(spec_posts); - ImpureFunctionEncOutput { + Ok(ImpureFunctionEncOutput { method: vcx.mk_method( method_ref, vcx.alloc_slice(&args), @@ -186,7 +186,7 @@ where vcx.alloc_slice(&posts), blocks, ), - } + }) }) } } diff --git a/prusti-encoder/src/encoder_traits/pure_func_app_enc.rs b/prusti-encoder/src/encoder_traits/pure_func_app_enc.rs index 04202f73937..739eda044b0 100644 --- a/prusti-encoder/src/encoder_traits/pure_func_app_enc.rs +++ b/prusti-encoder/src/encoder_traits/pure_func_app_enc.rs @@ -5,7 +5,7 @@ use prusti_rustc_interface::{ }, span::def_id::DefId, }; -use task_encoder::TaskEncoderDependencies; +use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use crate::encoders::{ lifted::{ @@ -16,7 +16,7 @@ use crate::encoders::{ /// Encoders (such as [`crate::encoders::MirPureEnc`], /// [`crate::encoders::MirImpureEnc`]) implement this trait to encode /// applications of Rust functions annotated as pure. -pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> { +pub trait PureFuncAppEnc<'vir, E: TaskEncoder + 'vir + ?Sized> { /// Extra arguments required for the encoder to encode an argument to the /// function (in mir this is an `Operand`) type EncodeOperandArgs; @@ -29,32 +29,32 @@ pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> { /// The type of the data source that can provide local declarations; this is used /// when getting the type of the function. - type LocalDeclsSrc: ?Sized + HasLocalDecls<'tcx>; + type LocalDeclsSrc: ?Sized + HasLocalDecls<'vir>; // Are we monomorphizing functions? fn monomorphize(&self) -> bool; /// Task encoder dependencies are required for encoding Viper casts between /// generic and concrete types. - fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir>; + fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir, E>; /// The data source that can provide local declarations, necesary for determining /// the function type fn local_decls_src(&self) -> &Self::LocalDeclsSrc; - fn vcx(&self) -> &'vir vir::VirCtxt<'tcx>; + fn vcx(&self) -> &'vir vir::VirCtxt<'vir>; /// Encodes an operand (an argument to a function) as a pure Viper expression. fn encode_operand( &mut self, args: &Self::EncodeOperandArgs, - operand: &mir::Operand<'tcx>, + operand: &mir::Operand<'vir>, ) -> vir::ExprGen<'vir, Self::Curr, Self::Next>; /// Obtains the function's definition ID and the substitutions made at the callsite fn get_def_id_and_caller_substs( &self, - func: &mir::Operand<'tcx>, - ) -> (DefId, &'tcx List>) { + func: &mir::Operand<'vir>, + ) -> (DefId, &'vir List>) { let func_ty = func.ty(self.local_decls_src(), self.vcx().tcx()); match func_ty.kind() { &ty::TyKind::FnDef(def_id, arg_tys) => (def_id, arg_tys), @@ -67,9 +67,9 @@ pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> { /// are inserted to convert from/to generic and concrete arguments as necessary. fn encode_fn_args( &mut self, - sig: Binder<'tcx, FnSig<'tcx>>, - substs: &'tcx List>, - args: &[mir::Operand<'tcx>], + sig: Binder<'vir, FnSig<'vir>>, + substs: &'vir List>, + args: &[mir::Operand<'vir>], encode_operand_args: &Self::EncodeOperandArgs, ) -> Vec> { let mono = self.monomorphize(); @@ -118,10 +118,10 @@ pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> { fn encode_pure_func_app( &mut self, def_id: DefId, - sig: Binder<'tcx, FnSig<'tcx>>, - substs: &'tcx List>, - args: &Vec>, - destination: &mir::Place<'tcx>, + sig: Binder<'vir, FnSig<'vir>>, + substs: &'vir List>, + args: &Vec>, + destination: &mir::Place<'vir>, caller_def_id: DefId, encode_operand_args: &Self::EncodeOperandArgs, ) -> vir::ExprGen<'vir, Self::Curr, Self::Next> { diff --git a/prusti-encoder/src/encoder_traits/pure_function_enc.rs b/prusti-encoder/src/encoder_traits/pure_function_enc.rs index 11c8cffb082..6a7269653a7 100644 --- a/prusti-encoder/src/encoder_traits/pure_function_enc.rs +++ b/prusti-encoder/src/encoder_traits/pure_function_enc.rs @@ -1,6 +1,7 @@ -use prusti_rustc_interface:: - middle::{mir, ty::Ty} -; +use prusti_rustc_interface::{ + middle::{mir, ty::{GenericArgs, Ty}}, + span::def_id::DefId, +}; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; use vir::{CallableIdent, ExprGen, FunctionIdent, Reify, UnknownArity, ViperIdent}; @@ -33,21 +34,20 @@ where /// Generates the identifier for the function; for a monomorphic encoding, /// this should be a name including (mangled) type arguments - fn mk_function_ident<'vir, 'tcx>( - vcx: &'vir vir::VirCtxt<'tcx>, - task_key: &Self::TaskKey<'tcx>, + fn mk_function_ident<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + task_key: &Self::TaskKey<'vir>, ) -> ViperIdent<'vir>; - /// Adds an assertion connecting the type of an argument (or return) of the /// function with the appropriate type based on the param, e.g. in f(u: U) -> T, this would be called to require that the type of `u` be /// `U` - fn mk_type_assertion<'vir, 'tcx: 'vir, Curr, Next>( - vcx: &'vir vir::VirCtxt<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, + fn mk_type_assertion<'vir, Curr, Next>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, arg: ExprGen<'vir, Curr, Next>, // Snapshot encoded argument - ty: Ty<'tcx>, + ty: Ty<'vir>, ) -> Option> { let lifted_ty = deps .require_local::>(ty) @@ -77,9 +77,9 @@ where } } - fn encode<'vir, 'tcx: 'vir>( - task_key: Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, + fn encode<'vir>( + task_key: Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, ) -> MirFunctionEncOutput<'vir> { let def_id = Self::get_def_id(&task_key); let caller_def_id = Self::get_caller_def_id(&task_key); @@ -106,7 +106,7 @@ where let ident_args = UnknownArity::new(vcx.alloc_slice(&ident_args)); let return_type = local_defs.locals[mir::RETURN_PLACE].ty; let function_ref = FunctionIdent::new(function_ident, ident_args, return_type.snapshot); - deps.emit_output_ref::(task_key, MirFunctionEncOutputRef { function_ref }); + deps.emit_output_ref(task_key, MirFunctionEncOutputRef { function_ref }); let spec = deps .require_local::((def_id, substs, None, true)) diff --git a/prusti-encoder/src/encoders/const.rs b/prusti-encoder/src/encoders/const.rs index 25733deb340..6043fb0ad08 100644 --- a/prusti-encoder/src/encoders/const.rs +++ b/prusti-encoder/src/encoders/const.rs @@ -6,6 +6,7 @@ use rustc_middle::mir::interpret::{ConstValue, Scalar, GlobalAlloc}; use task_encoder::{ TaskEncoder, TaskEncoderDependencies, + EncodeFullResult, }; use vir::{CallableIdent, Arity}; @@ -25,8 +26,8 @@ use super::{lifted::{casters::CastTypePure, rust_ty_cast::RustTyCastersEnc}, rus impl TaskEncoder for ConstEnc { task_encoder::encoder_cache!(ConstEnc); - type TaskDescription<'tcx> = ( - mir::ConstantKind<'tcx>, + type TaskDescription<'vir> = ( + mir::ConstantKind<'vir>, usize, // current encoding depth DefId, // DefId of the current function ); @@ -37,21 +38,15 @@ impl TaskEncoder for ConstEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; let (const_, encoding_depth, def_id) = *task_key; let res = match const_ { mir::ConstantKind::Val(val, ty) => { - let kind = deps.require_local::(ty).unwrap().generic_snapshot.specifics; + let kind = deps.require_local::(ty)?.generic_snapshot.specifics; match val { ConstValue::Scalar(Scalar::Int(int)) => { let prim = kind.expect_primitive(); @@ -84,12 +79,11 @@ impl TaskEncoder for ConstEnc { let ref_ty = kind.expect_structlike(); let str_ty = ty.peel_refs(); let str_snap = deps - .require_local::(str_ty) - .unwrap() + .require_local::(str_ty)? .generic_snapshot .specifics .expect_structlike(); - let cast = deps.require_local::>(str_ty).unwrap(); + let cast = deps.require_local::>(str_ty)?; vir::with_vcx(|vcx| { // first, we create a string snapshot let snap = str_snap.field_snaps_to_snap.apply(vcx, &[]); @@ -112,10 +106,10 @@ impl TaskEncoder for ConstEnc { kind: PureKind::Constant(uneval.promoted.unwrap()), caller_def_id: Some(def_id) }; - let expr = deps.require_local::(task).unwrap().expr; + let expr = deps.require_local::(task)?.expr; use vir::Reify; - expr.reify(vcx, (uneval.def, &[])) - }), + Ok(expr.reify(vcx, (uneval.def, &[]))) + })?, mir::ConstantKind::Ty(_) => todo!("ConstantKind::Ty"), }; Ok((res, ())) diff --git a/prusti-encoder/src/encoders/generic.rs b/prusti-encoder/src/encoders/generic.rs index 19361bb2a4a..ad562a82095 100644 --- a/prusti-encoder/src/encoders/generic.rs +++ b/prusti-encoder/src/encoders/generic.rs @@ -1,4 +1,4 @@ -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{ BinaryArity, CallableIdent, DomainIdent, DomainParamData, FunctionIdent, KnownArityAny, NullaryArity, PredicateIdent, TypeData, UnaryArity, ViperIdent, @@ -39,7 +39,7 @@ const SNAPSHOT_PARAM_DOMAIN: TypeData<'static> = TypeData::Domain("s_Param", &[] impl TaskEncoder for GenericEnc { task_encoder::encoder_cache!(GenericEnc); - type TaskDescription<'tcx> = (); // ? + type TaskDescription<'vir> = (); // ? type OutputRef<'vir> = GenericEncOutputRef<'vir>; type OutputFullLocal<'vir> = GenericEncOutput<'vir>; @@ -51,19 +51,10 @@ impl TaskEncoder for GenericEnc { } #[allow(non_snake_case)] - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { let ref_to_pred = PredicateIdent::new(ViperIdent::new("p_Param"), BinaryArity::new(&[&TypeData::Ref, &TYP_DOMAIN])); let type_domain_ident = DomainIdent::nullary(ViperIdent::new("Type")); @@ -98,10 +89,10 @@ impl TaskEncoder for GenericEnc { }; #[allow(clippy::unit_arg)] - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, output_ref - ); + )?; let typ = FunctionIdent::new( ViperIdent::new("typ"), diff --git a/prusti-encoder/src/encoders/local_def.rs b/prusti-encoder/src/encoders/local_def.rs index 08070db5b7a..de10e2379bc 100644 --- a/prusti-encoder/src/encoders/local_def.rs +++ b/prusti-encoder/src/encoders/local_def.rs @@ -4,7 +4,7 @@ use prusti_rustc_interface::{ span::def_id::DefId }; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use crate::encoders::{ rust_ty_predicates::{RustTyPredicatesEnc, RustTyPredicatesEncOutputRef}, @@ -31,9 +31,9 @@ pub struct LocalDef<'vir> { impl TaskEncoder for MirLocalDefEnc { task_encoder::encoder_cache!(MirLocalDefEnc); - type TaskDescription<'tcx> = ( + type TaskDescription<'vir> = ( DefId, // ID of the function - ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs + ty::GenericArgsRef<'vir>, // ? this should be the "signature", after applying the env/substs Option, // ID of the caller function, if any ); @@ -45,24 +45,15 @@ impl TaskEncoder for MirLocalDefEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { let (def_id, substs, caller_def_id) = *task_key; - deps.emit_output_ref::(*task_key, ()); + deps.emit_output_ref(*task_key, ())?; - fn mk_local_def<'vir, 'tcx>( - vcx: &'vir vir::VirCtxt<'tcx>, + fn mk_local_def<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, name: &'vir str, ty: RustTyPredicatesEncOutputRef<'vir>, ) -> LocalDef<'vir> { diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index e7a87fd05d4..ef1d664939e 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -5,6 +5,7 @@ use prusti_rustc_interface::{ use task_encoder::{ TaskEncoder, TaskEncoderDependencies, + EncodeFullResult, }; use vir::{UnknownArity, FunctionIdent, CallableIdent}; @@ -51,16 +52,10 @@ impl TaskEncoder for MirBuiltinEnc { task.clone() } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { vir::with_vcx(|vcx| { match *task_key { MirBuiltinEncTask::UnOp(res_ty, op, operand_ty) => { @@ -92,12 +87,12 @@ fn int_name<'tcx>(ty: ty::Ty<'tcx>) -> &'static str { } impl MirBuiltinEnc { - fn handle_un_op<'vir, 'tcx>( - vcx: &'vir vir::VirCtxt<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - key: ::TaskKey<'tcx>, + fn handle_un_op<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + key: ::TaskKey<'vir>, op: mir::UnOp, - ty: ty::Ty<'tcx>, + ty: ty::Ty<'vir>, ) -> vir::Function<'vir> { let e_ty = deps .require_local::(ty) @@ -107,7 +102,7 @@ impl MirBuiltinEnc { let name = vir::vir_format_identifier!(vcx, "mir_unop_{op:?}_{}", int_name(ty)); let arity = UnknownArity::new(vcx.alloc_slice(&[e_ty.snapshot])); let function = FunctionIdent::new(name, arity, e_ty.snapshot); - deps.emit_output_ref::(key, MirBuiltinEncOutputRef { + deps.emit_output_ref(key, MirBuiltinEncOutputRef { function, }); @@ -140,14 +135,14 @@ impl MirBuiltinEnc { ) } - fn handle_bin_op<'vir, 'tcx>( - vcx: &'vir vir::VirCtxt<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - key: ::TaskKey<'tcx>, - res_ty: ty::Ty<'tcx>, + fn handle_bin_op<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + key: ::TaskKey<'vir>, + res_ty: ty::Ty<'vir>, op: mir::BinOp, - l_ty: ty::Ty<'tcx>, - r_ty: ty::Ty<'tcx>, + l_ty: ty::Ty<'vir>, + r_ty: ty::Ty<'vir>, ) -> vir::Function<'vir> { use mir::BinOp::*; let e_l_ty = deps @@ -169,7 +164,7 @@ impl MirBuiltinEnc { let name = vir::vir_format_identifier!(vcx, "mir_binop_{op:?}_{}_{}", int_name(l_ty), int_name(r_ty)); let arity = UnknownArity::new(vcx.alloc_slice(&[e_l_ty.snapshot, e_r_ty.snapshot])); let function = FunctionIdent::new(name, arity, e_res_ty.snapshot); - deps.emit_output_ref::(key, MirBuiltinEncOutputRef { + deps.emit_output_ref(key, MirBuiltinEncOutputRef { function, }); let lhs = prim_l_ty.snap_to_prim.apply(vcx, @@ -268,14 +263,14 @@ impl MirBuiltinEnc { ) } - fn handle_checked_bin_op<'vir, 'tcx>( - vcx: &'vir vir::VirCtxt<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - key: ::TaskKey<'tcx>, - res_ty: ty::Ty<'tcx>, + fn handle_checked_bin_op<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + key: ::TaskKey<'vir>, + res_ty: ty::Ty<'vir>, op: mir::BinOp, - l_ty: ty::Ty<'tcx>, - r_ty: ty::Ty<'tcx>, + l_ty: ty::Ty<'vir>, + r_ty: ty::Ty<'vir>, ) -> vir::Function<'vir> { // `op` can only be `Add`, `Sub` or `Mul` assert!(matches!( @@ -303,7 +298,7 @@ impl MirBuiltinEnc { .unwrap() .generic_snapshot; let function = FunctionIdent::new(name, arity, e_res_ty.snapshot); - deps.emit_output_ref::(key, MirBuiltinEncOutputRef { function }); + deps.emit_output_ref(key, MirBuiltinEncOutputRef { function }); let e_res_ty = deps .require_local::(res_ty) diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 8a8c9d11278..dcb9fa8b796 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -13,7 +13,7 @@ use prusti_rustc_interface::{ //use mir_ssa_analysis::{ // SsaAnalysis, //}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{MethodIdent, UnknownArity}; pub struct MirImpureEnc; @@ -74,7 +74,7 @@ impl MirImpureEnc { impl TaskEncoder for MirImpureEnc { task_encoder::encoder_cache!(MirImpureEnc); - type TaskDescription<'tcx> = FunctionCallTaskDescription<'tcx>; + type TaskDescription<'vir> = FunctionCallTaskDescription<'vir>; type OutputRef<'vir> = ImpureFunctionEncOutputRef<'vir>; type OutputFullLocal<'vir> = ImpureFunctionEncOutput<'vir>; @@ -85,52 +85,43 @@ impl TaskEncoder for MirImpureEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { let monomorphize = Self::monomorphize(); let output_ref = if monomorphize { - deps.require_ref::(*task_key).unwrap() + deps.require_ref::(*task_key)? } else { - deps.require_ref::(task_key.def_id).unwrap() + deps.require_ref::(task_key.def_id)? }; - deps.emit_output_ref::(*task_key, output_ref); + deps.emit_output_ref(*task_key, output_ref); let output: ImpureFunctionEncOutput<'_> = if monomorphize { - deps.require_local::(*task_key).unwrap() + deps.require_local::(*task_key)? } else { - deps.require_local::(task_key.def_id).unwrap() + deps.require_local::(task_key.def_id)? }; Ok((output, ())) } } -pub struct ImpureEncVisitor<'tcx, 'vir, 'enc> +pub struct ImpureEncVisitor<'vir, 'enc, E: TaskEncoder> where 'vir: 'enc { - pub vcx: &'vir vir::VirCtxt<'tcx>, + pub vcx: &'vir vir::VirCtxt<'vir>, // Are we monomorphizing functions? pub monomorphize: bool, - pub deps: &'enc mut TaskEncoderDependencies<'vir>, + pub deps: &'enc mut TaskEncoderDependencies<'vir, E>, pub def_id: DefId, - pub local_decls: &'enc mir::LocalDecls<'tcx>, + pub local_decls: &'enc mir::LocalDecls<'vir>, //ssa_analysis: SsaAnalysis, - pub fpcs_analysis: FreePcsAnalysis<'enc, 'tcx>, + pub fpcs_analysis: FreePcsAnalysis<'enc, 'vir>, pub local_defs: crate::encoders::MirLocalDefEncOutput<'vir>, pub tmp_ctr: usize, // for the current basic block - pub current_fpcs: Option>, + pub current_fpcs: Option>, pub current_stmts: Option>>, pub current_terminator: Option>, @@ -138,16 +129,16 @@ pub struct ImpureEncVisitor<'tcx, 'vir, 'enc> pub encoded_blocks: Vec>, // TODO: use IndexVec ? } -impl<'tcx: 'vir, 'vir> PureFuncAppEnc<'tcx, 'vir> for ImpureEncVisitor<'tcx, 'vir, '_> { +impl<'vir, E: TaskEncoder> PureFuncAppEnc<'vir, E> for ImpureEncVisitor<'vir, '_, E> { type EncodeOperandArgs = (); type Curr = !; type Next = !; - type LocalDeclsSrc = mir::LocalDecls<'tcx>; - fn vcx(&self) -> &'vir vir::VirCtxt<'tcx> { + type LocalDeclsSrc = mir::LocalDecls<'vir>; + fn vcx(&self) -> &'vir vir::VirCtxt<'vir> { self.vcx } - fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir> { + fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir, E> { self.deps } @@ -158,7 +149,7 @@ impl<'tcx: 'vir, 'vir> PureFuncAppEnc<'tcx, 'vir> for ImpureEncVisitor<'tcx, 'vi fn encode_operand( &mut self, _args: &Self::EncodeOperandArgs, - operand: &mir::Operand<'tcx>, + operand: &mir::Operand<'vir>, ) -> vir::ExprGen<'vir, Self::Curr, Self::Next> { self.encode_operand_snap(operand) } @@ -189,7 +180,7 @@ impl<'vir> EncodePlaceResult<'vir> { } } -impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { +impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { fn stmt(&mut self, stmt: vir::Stmt<'vir>) { self.current_stmts .as_mut() @@ -261,7 +252,7 @@ impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { fn fpcs_repacks( &mut self, - repacks: &[RepackOp<'tcx>], + repacks: &[RepackOp<'vir>], ) { for &repack_op in repacks { match repack_op { @@ -334,7 +325,7 @@ impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { result.undo_casts.iter().for_each(|stmt| self.stmt(stmt)); } - fn encode_operand_snap(&mut self, operand: &mir::Operand<'tcx>) -> vir::Expr<'vir> { + fn encode_operand_snap(&mut self, operand: &mir::Operand<'vir>) -> vir::Expr<'vir> { let ty = operand.ty(self.local_decls, self.vcx.tcx()); match operand { &mir::Operand::Move(source) => { @@ -368,7 +359,7 @@ impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { fn encode_operand( &mut self, - operand: &mir::Operand<'tcx>, + operand: &mir::Operand<'vir>, ) -> vir::Expr<'vir> { let ty = operand.ty(self.local_decls, self.vcx.tcx()); let (encode_place_result, ty_out) = match operand { @@ -393,7 +384,7 @@ impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { fn encode_place( &mut self, - place: Place<'tcx>, + place: Place<'vir>, ) -> EncodePlaceResult<'vir> { let mut place_ty = mir::tcx::PlaceTy::from_ty(self.local_decls[place.local].ty); let mut result = EncodePlaceResult::new(self.local_defs.locals[place.local].local_ex); @@ -414,8 +405,8 @@ impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { // it. fn encode_place_element( &mut self, - place_ty: mir::tcx::PlaceTy<'tcx>, - elem: mir::PlaceElem<'tcx>, + place_ty: mir::tcx::PlaceTy<'vir>, + elem: mir::PlaceElem<'vir>, expr: vir::Expr<'vir> ) -> (vir::Expr<'vir>, Option>) { match elem { @@ -495,14 +486,14 @@ impl<'tcx, 'vir, 'enc> ImpureEncVisitor<'tcx, 'vir, 'enc> { } } -impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for ImpureEncVisitor<'tcx, 'vir, 'enc> { +impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor<'vir, 'enc, E> { // fn visit_body(&mut self, body: &mir::Body<'tcx>) { // println!("visiting body!"); // } fn visit_basic_block_data( &mut self, block: mir::BasicBlock, - data: &mir::BasicBlockData<'tcx>, + data: &mir::BasicBlockData<'vir>, ) { self.current_fpcs = Some(self.fpcs_analysis.get_all_for_bb(block)); @@ -563,7 +554,7 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for ImpureEncVisitor<'tcx, 'vir fn visit_statement( &mut self, - statement: &mir::Statement<'tcx>, + statement: &mir::Statement<'vir>, location: mir::Location, ) { // TODO: these should not be ignored, but should havoc the local instead @@ -606,12 +597,12 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for ImpureEncVisitor<'tcx, 'vir let expr = match rvalue { mir::Rvalue::Use(op) => self.encode_operand_snap(op), - //mir::Rvalue::Repeat(Operand<'tcx>, Const<'tcx>) => {} - //mir::Rvalue::Ref(Region<'tcx>, BorrowKind, Place<'tcx>) => {} + //mir::Rvalue::Repeat(Operand<'vir>, Const<'vir>) => {} + //mir::Rvalue::Ref(Region<'vir>, BorrowKind, Place<'vir>) => {} //mir::Rvalue::ThreadLocalRef(DefId) => {} - //mir::Rvalue::AddressOf(Mutability, Place<'tcx>) => {} - //mir::Rvalue::Len(Place<'tcx>) => {} - //mir::Rvalue::Cast(CastKind, Operand<'tcx>, Ty<'tcx>) => {} + //mir::Rvalue::AddressOf(Mutability, Place<'vir>) => {} + //mir::Rvalue::Len(Place<'vir>) => {} + //mir::Rvalue::Cast(CastKind, Operand<'vir>, Ty<'vir>) => {} rv@mir::Rvalue::BinaryOp(op, box (l, r)) | rv@mir::Rvalue::CheckedBinaryOp(op, box (l, r)) => { @@ -632,7 +623,7 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for ImpureEncVisitor<'tcx, 'vir ]) } - //mir::Rvalue::NullaryOp(NullOp, Ty<'tcx>) => {} + //mir::Rvalue::NullaryOp(NullOp, Ty<'vir>) => {} mir::Rvalue::UnaryOp(unop, operand) => { let operand_ty = operand.ty(self.local_decls, self.vcx.tcx()); @@ -703,9 +694,9 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for ImpureEncVisitor<'tcx, 'vir } } - //mir::Rvalue::Discriminant(Place<'tcx>) => {} - //mir::Rvalue::ShallowInitBox(Operand<'tcx>, Ty<'tcx>) => {} - //mir::Rvalue::CopyForDeref(Place<'tcx>) => {} + //mir::Rvalue::Discriminant(Place<'vir>) => {} + //mir::Rvalue::ShallowInitBox(Operand<'vir>, Ty<'vir>) => {} + //mir::Rvalue::CopyForDeref(Place<'vir>) => {} other => { tracing::error!("unsupported rvalue {other:?}"); self.vcx.mk_todo_expr(vir::vir_format!(self.vcx, "rvalue {rvalue:?}")) @@ -743,7 +734,7 @@ impl<'tcx, 'vir, 'enc> mir::visit::Visitor<'tcx> for ImpureEncVisitor<'tcx, 'vir fn visit_terminator( &mut self, - terminator: &mir::Terminator<'tcx>, + terminator: &mir::Terminator<'vir>, location: mir::Location, ) { self.stmt(self.vcx.mk_comment_stmt( diff --git a/prusti-encoder/src/encoders/mir_poly_impure.rs b/prusti-encoder/src/encoders/mir_poly_impure.rs index bc5028e3d74..8bd2f075518 100644 --- a/prusti-encoder/src/encoders/mir_poly_impure.rs +++ b/prusti-encoder/src/encoders/mir_poly_impure.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::span::def_id::DefId; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use vir::{MethodIdent, UnknownArity}; /// Encodes a Rust function as a Viper method using the polymorphic encoding of generics. @@ -45,19 +45,11 @@ impl TaskEncoder for MirPolyImpureEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - def_id: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - Ok((::encode(*def_id, deps), ())) + fn do_encode_full<'vir>( + def_id: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + ::encode(*def_id, deps) + .map(|r| (r, ())) } } diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 40242cdec70..cebf1c40f04 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -8,6 +8,7 @@ use prusti_rustc_interface::{ use task_encoder::{ TaskEncoder, TaskEncoderDependencies, + EncodeFullResult, }; use vir::add_debug_note; use std::collections::HashMap; @@ -52,27 +53,27 @@ pub enum PureKind { } #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct MirPureEncTask<'tcx> { +pub struct MirPureEncTask<'vir> { // TODO: depth of encoding should be in the lazy context rather than here; // can we integrate the lazy context into the identifier system? pub encoding_depth: usize, pub kind: PureKind, pub parent_def_id: DefId, // ID of the function - pub param_env: ty::ParamEnv<'tcx>, // param environment at the usage site - pub substs: ty::GenericArgsRef<'tcx>, // type substitutions at the usage site + pub param_env: ty::ParamEnv<'vir>, // param environment at the usage site + pub substs: ty::GenericArgsRef<'vir>, // type substitutions at the usage site pub caller_def_id: Option, // ID of the caller function, if any } impl TaskEncoder for MirPureEnc { task_encoder::encoder_cache!(MirPureEnc); - type TaskDescription<'tcx> = MirPureEncTask<'tcx>; + type TaskDescription<'vir> = MirPureEncTask<'vir>; - type TaskKey<'tcx> = ( + type TaskKey<'vir> = ( usize, // encoding depth PureKind, // encoding a pure function? DefId, // ID of the function - ty::GenericArgsRef<'tcx>, // ? this should be the "signature", after applying the env/substs + ty::GenericArgsRef<'vir>, // ? this should be the "signature", after applying the env/substs Option, // Caller/Use DefID ); @@ -80,7 +81,7 @@ impl TaskEncoder for MirPureEnc { type EncodingError = MirPureEncError; - fn task_to_key<'tcx>(task: &Self::TaskDescription<'tcx>) -> Self::TaskKey<'tcx> { + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> { ( // TODO task.encoding_depth, @@ -91,17 +92,11 @@ impl TaskEncoder for MirPureEnc { ) } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; let (_, kind, def_id, substs, caller_def_id) = *task_key; @@ -174,22 +169,21 @@ impl<'vir> Update<'vir> { } } -struct Enc<'tcx, 'vir: 'enc, 'enc> -{ +struct Enc<'vir: 'enc, 'enc> { monomorphize: bool, - vcx: &'vir vir::VirCtxt<'tcx>, + vcx: &'vir vir::VirCtxt<'vir>, encoding_depth: usize, def_id: DefId, - body: &'enc mir::Body<'tcx>, + body: &'enc mir::Body<'vir>, rev_doms: rev_doms::ReverseDominators, - deps: &'enc mut TaskEncoderDependencies<'vir>, + deps: &'enc mut TaskEncoderDependencies<'vir, MirPureEnc>, visited: IndexVec, version_ctr: IndexVec, phi_ctr: usize, } -impl <'tcx: 'vir, 'vir, 'enc> PureFuncAppEnc<'tcx, 'vir> for Enc<'tcx, 'vir, 'enc> { - fn vcx(&self) -> &'vir vir::VirCtxt<'tcx> { +impl <'vir, 'enc> PureFuncAppEnc<'vir, MirPureEnc> for Enc<'vir, 'enc> { + fn vcx(&self) -> &'vir vir::VirCtxt<'vir> { self.vcx } @@ -199,9 +193,9 @@ impl <'tcx: 'vir, 'vir, 'enc> PureFuncAppEnc<'tcx, 'vir> for Enc<'tcx, 'vir, 'en type Next = vir::ExprKind<'vir>; - type LocalDeclsSrc = Body<'tcx>; + type LocalDeclsSrc = Body<'vir>; - fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir> { + fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir, MirPureEnc> { self.deps } @@ -212,7 +206,7 @@ impl <'tcx: 'vir, 'vir, 'enc> PureFuncAppEnc<'tcx, 'vir> for Enc<'tcx, 'vir, 'en fn encode_operand( &mut self, args: &Self::EncodeOperandArgs, - operand: &mir::Operand<'tcx>, + operand: &mir::Operand<'vir>, ) -> vir::ExprGen<'vir, Self::Curr, Self::Next> { self.encode_operand(args, operand) } @@ -222,15 +216,14 @@ impl <'tcx: 'vir, 'vir, 'enc> PureFuncAppEnc<'tcx, 'vir> for Enc<'tcx, 'vir, 'en } } -impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> -{ +impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { fn new( - vcx: &'vir vir::VirCtxt<'tcx>, + vcx: &'vir vir::VirCtxt<'vir>, monomorphize: bool, encoding_depth: usize, def_id: DefId, - body: &'enc mir::Body<'tcx>, - deps: &'enc mut TaskEncoderDependencies<'vir>, + body: &'enc mir::Body<'vir>, + deps: &'enc mut TaskEncoderDependencies<'vir, MirPureEnc>, ) -> Self { assert!(!body.basic_blocks.is_cfg_cyclic(), "MIR pure encoding does not support loops"); let rev_doms = rev_doms::ReverseDominators::new(&body.basic_blocks); @@ -537,7 +530,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> fn encode_stmt( &mut self, curr_ver: &HashMap, - stmt: &mir::Statement<'tcx>, + stmt: &mir::Statement<'vir>, ) -> Update<'vir> { let mut update = Update::new(); match &stmt.kind { @@ -563,7 +556,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> fn encode_rvalue( &mut self, curr_ver: &HashMap, - rvalue: &mir::Rvalue<'tcx>, + rvalue: &mir::Rvalue<'vir>, ) -> ExprRet<'vir> { let rvalue_ty = rvalue.ty(self.body, self.vcx.tcx()); match rvalue { @@ -718,7 +711,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> fn encode_operand( &mut self, curr_ver: &HashMap, - operand: &mir::Operand<'tcx>, + operand: &mir::Operand<'vir>, ) -> ExprRet<'vir> { match operand { mir::Operand::Copy(place) @@ -731,7 +724,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> fn encode_place( &mut self, curr_ver: &HashMap, - place: &mir::Place<'tcx>, + place: &mir::Place<'vir>, ) -> ExprRet<'vir> { self.encode_place_with_ref(curr_ver, place).0 } @@ -739,7 +732,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> fn encode_place_with_ref( &mut self, curr_ver: &HashMap, - place: &mir::Place<'tcx>, + place: &mir::Place<'vir>, ) -> (ExprRet<'vir>, Option>) { // TODO: remove (debug) assert!(curr_ver.contains_key(&place.local)); @@ -759,8 +752,8 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> fn encode_place_element( &mut self, - place_ty: mir::tcx::PlaceTy<'tcx>, - elem: mir::PlaceElem<'tcx>, + place_ty: mir::tcx::PlaceTy<'vir>, + elem: mir::PlaceElem<'vir>, expr: ExprRet<'vir>, place_ref: Option>, ) -> (ExprRet<'vir>, Option>) { @@ -836,7 +829,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> } } - fn encode_prusti_builtin(&mut self, curr_ver: &HashMap, def_id: DefId, arg_tys: ty::GenericArgsRef<'tcx>, args: &Vec>) -> ExprRet<'vir> { + fn encode_prusti_builtin(&mut self, curr_ver: &HashMap, def_id: DefId, arg_tys: ty::GenericArgsRef<'vir>, args: &Vec>) -> ExprRet<'vir> { #[derive(Debug)] enum PrustiBuiltin { Forall, @@ -985,7 +978,7 @@ mod rev_doms { pub end: mir::BasicBlock, } impl ReverseDominators { - pub fn new<'a, 'tcx>(blocks: &'a mir::BasicBlocks<'tcx>) -> Self { + pub fn new<'a, 'vir>(blocks: &'a mir::BasicBlocks<'vir>) -> Self { let no_succ_blocks = blocks.iter_enumerated().filter(|(_, data)| data.terminator().successors().next().is_none() ).map(|(bb, _)| bb).collect(); @@ -1006,7 +999,7 @@ mod rev_doms { /// A wrapper around `mir::BasicBlocks` which reverses the direction of the /// edges. Implements `ControlFlowGraph` such that we can call `dominators`. - struct RevBasicBlocks<'a, 'tcx>(&'a mir::BasicBlocks<'tcx>, Vec); + struct RevBasicBlocks<'a, 'vir>(&'a mir::BasicBlocks<'vir>, Vec); impl DirectedGraph for RevBasicBlocks<'_, '_> { type Node = mir::BasicBlock; } diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index e035adba750..db2074ef718 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -1,6 +1,6 @@ use prusti_rustc_interface::span::def_id::DefId; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use crate::encoder_traits::pure_function_enc::{ MirFunctionEncOutput, MirFunctionEncOutputRef, PureFunctionEnc @@ -40,19 +40,10 @@ impl TaskEncoder for MirFunctionEnc { task.def_id } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { Ok((::encode(*task_key, deps), ())) } } diff --git a/prusti-encoder/src/encoders/mono/mir_impure.rs b/prusti-encoder/src/encoders/mono/mir_impure.rs index 84e681c6b1d..b5975a02ec5 100644 --- a/prusti-encoder/src/encoders/mono/mir_impure.rs +++ b/prusti-encoder/src/encoders/mono/mir_impure.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::{middle::ty::GenericArgs, span::def_id::DefId}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{MethodIdent, UnknownArity}; /// Encodes a Rust function as a Viper method using the monomorphic encoding of generics. pub struct MirMonoImpureEnc; @@ -66,19 +66,11 @@ impl TaskEncoder for MirMonoImpureEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - Ok((::encode(*task_key, deps), ())) + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + ::encode(*task_key, deps) + .map(|r| (r, ())) } } diff --git a/prusti-encoder/src/encoders/mono/mir_pure_function.rs b/prusti-encoder/src/encoders/mono/mir_pure_function.rs index 36530879fdd..e984a882fa5 100644 --- a/prusti-encoder/src/encoders/mono/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mono/mir_pure_function.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::{middle::ty::GenericArgs, span::def_id::DefId}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use crate::{ encoder_traits::{function_enc::FunctionEnc, pure_function_enc::{MirFunctionEncOutput, MirFunctionEncOutputRef, PureFunctionEnc}}, @@ -59,19 +59,10 @@ impl TaskEncoder for MirMonoFunctionEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { Ok((::encode(*task_key, deps), ())) } } diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 35626836d85..77c4af7ba67 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -3,7 +3,7 @@ use prusti_rustc_interface::{ span::def_id::DefId, }; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::Reify; use crate::encoders::{mir_pure::PureKind, rust_ty_predicates::RustTyPredicatesEnc, MirPureEnc}; @@ -35,32 +35,21 @@ impl TaskEncoder for MirSpecEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { let (def_id, substs, caller_def_id, pure) = *task_key; - deps.emit_output_ref::(*task_key, ()); + deps.emit_output_ref(*task_key, ())?; let local_defs = deps .require_local::(( def_id, substs, caller_def_id, - )) - .unwrap(); + ))?; let specs = deps - .require_local::(crate::encoders::SpecEncTask { def_id }) - .unwrap(); + .require_local::(crate::encoders::SpecEncTask { def_id })?; vir::with_vcx(|vcx| { let local_iter = (1..=local_defs.arg_count).map(mir::Local::from); @@ -88,8 +77,7 @@ impl TaskEncoder for MirSpecEnc { }; let to_bool = deps - .require_ref::(vcx.tcx().types.bool) - .unwrap() + .require_ref::(vcx.tcx().types.bool)? .generic_predicate .expect_prim() .snap_to_prim; diff --git a/prusti-encoder/src/encoders/spec.rs b/prusti-encoder/src/encoders/spec.rs index e0418dc38c0..6d2431da65b 100644 --- a/prusti-encoder/src/encoders/spec.rs +++ b/prusti-encoder/src/encoders/spec.rs @@ -6,6 +6,7 @@ use prusti_interface::specs::typed::{DefSpecificationMap, ProcedureSpecification use task_encoder::{ TaskEncoder, TaskEncoderDependencies, + EncodeFullResult, }; pub struct SpecEnc; @@ -75,17 +76,11 @@ impl TaskEncoder for SpecEnc { ) } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(task_key.clone(), ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(task_key.clone(), ())?; vir::with_vcx(|vcx| { with_def_spec(|def_spec| { let specs = def_spec.get_proc_spec(&task_key.0); diff --git a/prusti-encoder/src/encoders/type/domain.rs b/prusti-encoder/src/encoders/type/domain.rs index 532258334e4..56d166a0531 100644 --- a/prusti-encoder/src/encoders/type/domain.rs +++ b/prusti-encoder/src/encoders/type/domain.rs @@ -5,8 +5,7 @@ use prusti_rustc_interface::{ }; use rustc_middle::ty::ParamTy; use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, + EncodeFullError, EncodeFullResult, TaskEncoder, TaskEncoderDependencies }; use vir::{ BinaryArity, CallableIdent, DomainParamData, FunctionIdent, NullaryArityAny, ToKnownArity, UnaryArity, UnknownArity @@ -135,16 +134,10 @@ impl TaskEncoder for DomainEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { vir::with_vcx(|vcx| { let base_name = task_key.get_vir_base_name(vcx); match task_key.kind() { @@ -157,7 +150,7 @@ impl TaskEncoder for DomainEnc { }; let mut enc = DomainEncData::new(vcx, task_key, vec![], deps); - enc.deps.emit_output_ref::(*task_key, enc.output_ref(base_name)); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; let specifics = enc.mk_prim_specifics( task_key.ty(), prim_type @@ -172,18 +165,18 @@ impl TaskEncoder for DomainEnc { .map(|ty| deps.require_local::>(ty).unwrap().expect_generic()) .collect(); let mut enc = DomainEncData::new(vcx, task_key, generics, deps); - enc.deps.emit_output_ref::(*task_key, enc.output_ref(base_name)); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; match adt.adt_kind() { ty::AdtKind::Struct => { let fields = if !adt.is_box() { let variant = adt.non_enum_variant(); - enc.mk_field_tys(variant, params) + enc.mk_field_tys(variant, params)? } else { // Box special case (this should be replaced by an // extern spec in the future) vec![ FieldTy { - ty: enc.deps.require_ref::(()).unwrap().param_snapshot, + ty: enc.deps.require_ref::(())?.param_snapshot, rust_ty_data: None } ] @@ -194,9 +187,9 @@ impl TaskEncoder for DomainEnc { ty::AdtKind::Enum => { let variants: Vec<_> = adt.discriminants(vcx.tcx()).map(|(v, d)| { let variant = adt.variant(v); - let field_tys = enc.mk_field_tys(variant, params); - (variant.name, v, field_tys, d) - }).collect(); + let field_tys = enc.mk_field_tys(variant, params)?; + Ok((variant.name, v, field_tys, d)) + }).collect::, _>>()?; let variants = if variants.is_empty() { None } else { @@ -206,8 +199,7 @@ impl TaskEncoder for DomainEnc { .any(|v| matches!(v.discr, ty::VariantDiscr::Explicit(_))); let discr_ty = adt.repr().discr_type().to_ty(vcx.tcx()); let discr_ty = enc.deps - .require_local::(discr_ty) - .unwrap() + .require_local::(discr_ty)? .generic_snapshot; Some(VariantData { discr_ty: discr_ty.snapshot, @@ -228,28 +220,30 @@ impl TaskEncoder for DomainEnc { .map(|ty| deps.require_local::>(ty).unwrap().expect_generic()) .collect(); let mut enc = DomainEncData::new(vcx, task_key, generics, deps); - enc.deps.emit_output_ref::(*task_key, enc.output_ref(base_name)); - let field_tys = params.iter().map(|ty| FieldTy::from_ty(vcx, enc.deps, ty)).collect(); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; + let field_tys = params.iter() + .map(|ty| FieldTy::from_ty(vcx, enc.deps, ty)) + .collect::, _>>()?; let specifics = enc.mk_struct_specifics(field_tys); Ok((Some(enc.finalize(task_key)), specifics)) } TyKind::Never => { let mut enc = DomainEncData::new(vcx, task_key, vec![], deps); - enc.deps.emit_output_ref::(*task_key, enc.output_ref(base_name)); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; let specifics = enc.mk_enum_specifics(None); Ok((Some(enc.finalize(task_key)), specifics)) } &TyKind::Ref(_, inner, _) => { - let generics = vec![deps.require_local::>(inner).unwrap().expect_generic()]; + let generics = vec![deps.require_local::>(inner)?.expect_generic()]; let mut enc = DomainEncData::new(vcx, task_key, generics, deps); - enc.deps.emit_output_ref::(*task_key, enc.output_ref(base_name)); - let field_tys = vec![FieldTy::from_ty(vcx, enc.deps, inner)]; + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; + let field_tys = vec![FieldTy::from_ty(vcx, enc.deps, inner)?]; let specifics = enc.mk_struct_specifics(field_tys); Ok((Some(enc.finalize(task_key)), specifics)) } &TyKind::Param(_) => { - let out = deps.require_ref::(()).unwrap(); - deps.emit_output_ref::( + let out = deps.require_ref::(())?; + deps.emit_output_ref( *task_key, DomainEncOutputRef { base_name, @@ -257,13 +251,13 @@ impl TaskEncoder for DomainEnc { ty_param_accessors: &[], typeof_function: out.param_type_function, }, - ); + )?; Ok((None, DomainEncSpecifics::Param)) } &TyKind::Str => { let mut enc = DomainEncData::new(vcx, task_key, vec![], deps); let base_name = String::from("String"); - enc.deps.emit_output_ref::(*task_key, enc.output_ref(base_name)); + enc.deps.emit_output_ref(*task_key, enc.output_ref(base_name))?; let specifics = enc.mk_struct_specifics(vec![]); Ok((Some(enc.finalize(task_key)), specifics)) } @@ -273,16 +267,16 @@ impl TaskEncoder for DomainEnc { } } -pub struct VariantData<'vir, 'tcx> { +pub struct VariantData<'vir> { discr_ty: vir::Type<'vir>, discr_prim: DomainDataPrim<'vir>, /// Do any of the variants have an explicit discriminant value? has_explicit: bool, - variants: Vec<(symbol::Symbol, abi::VariantIdx, Vec>, ty::util::Discr<'tcx>)>, + variants: Vec<(symbol::Symbol, abi::VariantIdx, Vec>, ty::util::Discr<'vir>)>, } -struct DomainEncData<'vir, 'tcx, 'enc> { - vcx: &'vir vir::VirCtxt<'tcx>, +struct DomainEncData<'vir, 'enc> { + vcx: &'vir vir::VirCtxt<'vir>, domain: vir::DomainIdent<'vir, NullaryArityAny<'vir, DomainParamData<'vir>>>, generics: Vec<(ParamTy, vir::FunctionIdent<'vir, UnaryArity<'vir>>)>, typeof_function: vir::FunctionIdent<'vir, UnaryArity<'vir>>, @@ -292,15 +286,15 @@ struct DomainEncData<'vir, 'tcx, 'enc> { axioms: Vec>, functions: Vec>, generic_enc: GenericEncOutputRef<'vir>, - deps: &'enc mut TaskEncoderDependencies<'vir>, + deps: &'enc mut TaskEncoderDependencies<'vir, DomainEnc>, } -impl<'vir, 'tcx: 'vir, 'enc> DomainEncData<'vir, 'tcx, 'enc> { +impl<'vir, 'enc> DomainEncData<'vir, 'enc> { // Creation fn new( - vcx: &'vir vir::VirCtxt<'tcx>, - ty: &MostGenericTy<'tcx>, + vcx: &'vir vir::VirCtxt<'vir>, + ty: &MostGenericTy<'vir>, generics: Vec, - deps: &'enc mut TaskEncoderDependencies<'vir>, + deps: &'enc mut TaskEncoderDependencies<'vir, DomainEnc>, ) -> Self { let domain = ty.get_vir_domain_ident(vcx); let self_ty = domain.apply(vcx, []); @@ -350,20 +344,23 @@ impl<'vir, 'tcx: 'vir, 'enc> DomainEncData<'vir, 'tcx, 'enc> { pub fn mk_field_tys( &mut self, variant: &ty::VariantDef, - params: ty::GenericArgsRef<'tcx>, - ) -> Vec> { + params: ty::GenericArgsRef<'vir>, + ) -> Result< + Vec>, + EncodeFullError<'vir, DomainEnc>, + > { variant .fields .iter() .map(|f| f.ty(self.vcx.tcx(), params)) .map(|ty| FieldTy::from_ty(self.vcx, self.deps, ty)) - .collect() + .collect::, _>>() } // Creating specifics pub fn mk_prim_specifics( &mut self, - ty: ty::Ty<'tcx>, + ty: ty::Ty<'vir>, prim_type: vir::Type<'vir>, ) -> DomainEncSpecifics<'vir> { let prim_type_args = vec![FieldTy { @@ -398,7 +395,7 @@ impl<'vir, 'tcx: 'vir, 'enc> DomainEncData<'vir, 'tcx, 'enc> { } pub fn mk_enum_specifics( &mut self, - data: Option>, + data: Option>, ) -> DomainEncSpecifics<'vir> { let specifics = data.map(|data| { let discr_vals: Vec<_> = data.variants.iter().map(|(_, _, _, discr)| data.discr_prim.expr_from_bits(discr.ty, discr.val)).collect(); @@ -665,7 +662,7 @@ impl<'vir, 'tcx: 'vir, 'enc> DomainEncData<'vir, 'tcx, 'enc> { ), } } - fn finalize(mut self, ty: &MostGenericTy<'tcx>) -> vir::Domain<'vir> { + fn finalize(mut self, ty: &MostGenericTy<'vir>) -> vir::Domain<'vir> { // If this type has generics, assert a bijectivity axiom on the type // constructor: For any value of type T, with type parameters T1, ..., @@ -733,7 +730,7 @@ impl<'vir> DomainEncSpecifics<'vir> { } } impl<'vir> DomainDataPrim<'vir> { - pub fn expr_from_bits<'tcx>(&self, ty: ty::Ty<'tcx>, value: u128) -> vir::Expr<'vir> { + pub fn expr_from_bits(&self, ty: ty::Ty<'vir>, value: u128) -> vir::Expr<'vir> { match *self.prim_type { vir::TypeData::Bool => vir::with_vcx(|vcx| vcx.mk_const_expr(vir::ConstData::Bool(value != 0))), vir::TypeData::Int => { @@ -763,7 +760,7 @@ impl<'vir> DomainDataPrim<'vir> { ref k => unreachable!("{k:?}"), } } - fn bounds<'tcx>(&self, ty: ty::Ty<'tcx>) -> Option<(vir::Expr<'vir>, vir::Expr<'vir>)> { + fn bounds(&self, ty: ty::Ty<'vir>) -> Option<(vir::Expr<'vir>, vir::Expr<'vir>)> { match *self.prim_type { vir::TypeData::Bool => None, ref int@vir::TypeData::Int { .. } => { @@ -797,17 +794,18 @@ struct LiftedRustTyData<'vir> { } impl <'vir> FieldTy<'vir> { - fn from_ty<'tcx: 'vir>(vcx: &'vir vir::VirCtxt<'tcx>, deps: &mut TaskEncoderDependencies, ty: ty::Ty<'tcx>) -> FieldTy<'vir> { - let vir_ty = deps.require_local::(ty) - .unwrap() + fn from_ty(vcx: &'vir vir::VirCtxt<'vir>, deps: &mut TaskEncoderDependencies<'vir, DomainEnc>, ty: ty::Ty<'vir>) -> Result< + FieldTy<'vir>, + EncodeFullError<'vir, DomainEnc>, + > { + let vir_ty = deps.require_ref::(ty)? .generic_snapshot .snapshot; let typeof_function = deps.require_ref::( extract_type_params(vcx.tcx(), ty).0 - ).unwrap().typeof_function; - let lifted_ty = deps.require_local::>(ty) - .unwrap(); - FieldTy {ty: vir_ty, rust_ty_data: Some(LiftedRustTyData {lifted_ty, typeof_function})} + )?.typeof_function; + let lifted_ty = deps.require_local::>(ty)?; + Ok(FieldTy { ty: vir_ty, rust_ty_data: Some(LiftedRustTyData { lifted_ty, typeof_function }) }) } } diff --git a/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs b/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs index 4054c19365c..710b09d8be5 100644 --- a/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs @@ -3,7 +3,7 @@ use prusti_rustc_interface::{ middle::{mir, ty::{GenericArgs, Ty}}, span::def_id::DefId, }; -use task_encoder::TaskEncoder; +use task_encoder::{TaskEncoder, EncodeFullResult}; use crate::encoders::lifted::cast::{CastArgs, CastToEnc}; @@ -79,20 +79,11 @@ impl TaskEncoder for AggregateSnapArgsCastEnc { task.clone() } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(task_key.clone(), ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(task_key.clone(), ())?; vir::with_vcx(|vcx| { let cast_functions: Vec>> = match task_key.aggregate_type { diff --git a/prusti-encoder/src/encoders/type/lifted/cast.rs b/prusti-encoder/src/encoders/type/lifted/cast.rs index 5afe520261b..81d5e0943a0 100644 --- a/prusti-encoder/src/encoders/type/lifted/cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/cast.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::middle::ty; -use task_encoder::{TaskEncoder, TaskEncoderDependencies, TaskEncoderError}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, TaskEncoderError, EncodeFullResult}; use vir::{FunctionIdent, MethodIdent, StmtGen, UnknownArity, VirCtxt}; use super::{ @@ -158,15 +158,16 @@ pub struct CastToEnc(std::marker::PhantomData); impl CastToEnc where - RustTyCastersEnc: for<'tcx, 'vir> TaskEncoder< - TaskDescription<'tcx> = ty::Ty<'tcx>, + Self: TaskEncoder, + RustTyCastersEnc: for<'vir> TaskEncoder< + TaskDescription<'vir> = ty::Ty<'vir>, OutputFullLocal<'vir> = RustTyGenericCastEncOutput<'vir, Casters<'vir, T>>, >, TaskEncoderError>: Sized, { - fn encode_cast<'tcx: 'vir, 'vir>( - task_key: CastArgs<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, + fn encode_cast<'vir>( + task_key: CastArgs<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, ) -> GenericCastOutputRef<'vir, T::CastApplicator<'vir>> { let expected_is_param = matches!(task_key.expected.kind(), ty::Param(_)); let actual_is_param = matches!(task_key.actual.kind(), ty::Param(_)); @@ -210,21 +211,12 @@ impl TaskEncoder for CastToEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { let output_ref = Self::encode_cast(*task_key, deps); - deps.emit_output_ref::(*task_key, output_ref); + deps.emit_output_ref(*task_key, output_ref)?; Ok(((), ())) } } @@ -240,21 +232,12 @@ impl TaskEncoder for CastToEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { let output_ref = Self::encode_cast(*task_key, deps); - deps.emit_output_ref::(*task_key, output_ref); + deps.emit_output_ref(*task_key, output_ref)?; Ok(((), ())) } } diff --git a/prusti-encoder/src/encoders/type/lifted/cast_functions.rs b/prusti-encoder/src/encoders/type/lifted/cast_functions.rs index 6449249fe53..866f7106d7e 100644 --- a/prusti-encoder/src/encoders/type/lifted/cast_functions.rs +++ b/prusti-encoder/src/encoders/type/lifted/cast_functions.rs @@ -1,4 +1,4 @@ -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{CallableIdent, FunctionIdent, UnaryArity, UnknownArity}; use crate::encoders::{ @@ -100,31 +100,21 @@ impl TaskEncoder for CastFunctionsEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - ty: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + ty: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { if ty.is_generic() { - deps.emit_output_ref::(*ty, CastFunctionsOutputRef::AlreadyGeneric); + deps.emit_output_ref(*ty, CastFunctionsOutputRef::AlreadyGeneric)?; return Ok((&[], ())); } vir::with_vcx(|vcx| { - let domain_ref = deps.require_ref::(*ty).unwrap(); - let generic_ref = deps.require_ref::(()).unwrap(); + let domain_ref = deps.require_ref::(*ty)?; + let generic_ref = deps.require_ref::(())?; let self_ty = domain_ref.domain.apply(vcx, []); let base_name = &domain_ref.base_name; let ty_constructor = deps - .require_ref::(*ty) - .unwrap() + .require_ref::(*ty)? .ty_constructor; let make_generic_arg_tys = [self_ty]; @@ -149,13 +139,13 @@ impl TaskEncoder for CastFunctionsEnc { self_ty, ); - deps.emit_output_ref::( + deps.emit_output_ref( *ty, CastFunctionsOutputRef::CastFunctions { make_generic: make_generic_ident, make_concrete: make_concrete_ident, }, - ); + )?; let make_generic_arg = vcx.mk_local_decl("self", self_ty); let make_generic_expr = vcx.mk_local_ex(make_generic_arg.name, make_generic_arg.ty); diff --git a/prusti-encoder/src/encoders/type/lifted/casters.rs b/prusti-encoder/src/encoders/type/lifted/casters.rs index 187ab4500e0..a57b36e36ba 100644 --- a/prusti-encoder/src/encoders/type/lifted/casters.rs +++ b/prusti-encoder/src/encoders/type/lifted/casters.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{Arity, CallableIdent, FunctionIdent, MethodIdent, TypeData, UnaryArity, UnknownArity}; use crate::encoders::{ @@ -221,21 +221,12 @@ impl TaskEncoder for CastersEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - ty: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + ty: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { if ty.is_generic() { - deps.emit_output_ref::(*ty, CastFunctionsOutputRef::AlreadyGeneric); + deps.emit_output_ref(*ty, CastFunctionsOutputRef::AlreadyGeneric); return Ok((&[], ())); } vir::with_vcx(|vcx| { @@ -271,7 +262,7 @@ impl TaskEncoder for CastersEnc { self_ty, ); - deps.emit_output_ref::( + deps.emit_output_ref( *ty, CastFunctionsOutputRef::Casters { make_generic: make_generic_ident, @@ -374,21 +365,12 @@ impl TaskEncoder for CastersEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - ty: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + ty: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { if ty.is_generic() { - deps.emit_output_ref::(*ty, CastMethodsOutputRef::AlreadyGeneric); + deps.emit_output_ref(*ty, CastMethodsOutputRef::AlreadyGeneric); return Ok((&[], ())); } vir::with_vcx(|vcx| { @@ -413,7 +395,7 @@ impl TaskEncoder for CastersEnc { UnknownArity::new(arg_tys), ); - deps.emit_output_ref::( + deps.emit_output_ref( *ty, CastMethodsOutputRef::Casters { make_generic: make_generic_ident, diff --git a/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs b/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs index 22267a67135..b5e380bcd29 100644 --- a/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs +++ b/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use prusti_rustc_interface::middle::ty::{GenericArgsRef, Ty, TyKind}; -use task_encoder::TaskEncoder; +use task_encoder::{TaskEncoder, EncodeFullResult}; use super::{ generic::LiftedGeneric, @@ -28,20 +28,11 @@ impl TaskEncoder for LiftedFuncAppTyParamsEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; vir::with_vcx(|vcx| { let (monomorphize, substs) = task_key; let tys = substs.iter().filter_map(|arg| arg.as_type()); diff --git a/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs b/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs index 0f9041c4e18..1cbb437fd3e 100644 --- a/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs +++ b/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs @@ -1,6 +1,6 @@ use prusti_rustc_interface::middle::ty::{self, ParamTy, Ty, TyKind}; use std::collections::HashSet; -use task_encoder::TaskEncoder; +use task_encoder::{TaskEncoder, EncodeFullResult}; use super::generic::{LiftedGeneric, LiftedGenericEnc}; @@ -25,20 +25,11 @@ impl TaskEncoder for LiftedTyParamsEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; vir::with_vcx(|vcx| { let ty_args = task_key .iter() diff --git a/prusti-encoder/src/encoders/type/lifted/generic.rs b/prusti-encoder/src/encoders/type/lifted/generic.rs index 0dbaabeaceb..41c3af077e3 100644 --- a/prusti-encoder/src/encoders/type/lifted/generic.rs +++ b/prusti-encoder/src/encoders/type/lifted/generic.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::middle::ty; -use task_encoder::{OutputRefAny, TaskEncoder}; +use task_encoder::{OutputRefAny, TaskEncoder, EncodeFullResult}; use vir::with_vcx; use crate::encoders::GenericEnc; @@ -44,25 +44,16 @@ impl TaskEncoder for LiftedGenericEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { with_vcx(|vcx| { let output_ref = vcx.mk_local_decl( - task_key.name.as_str(), - deps.require_ref::(()).unwrap().type_snapshot, + vcx.alloc_str(task_key.name.as_str()), + deps.require_ref::(())?.type_snapshot, ); - deps.emit_output_ref::(*task_key, LiftedGeneric(output_ref)); + deps.emit_output_ref(*task_key, LiftedGeneric(output_ref))?; Ok(((), ())) }) } diff --git a/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs b/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs index 4c099910e27..7164279ba67 100644 --- a/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use prusti_rustc_interface::middle::ty; -use task_encoder::{TaskEncoder, TaskEncoderError}; +use task_encoder::{TaskEncoder, TaskEncoderError, EncodeFullResult}; use vir::with_vcx; use crate::encoders::most_generic_ty::{extract_type_params, MostGenericTy}; @@ -69,15 +69,16 @@ impl<'vir, T> task_encoder::OutputRefAny for RustTyGenericCastEncOutput<'vir, T> impl RustTyCastersEnc where - CastersEnc: for<'vir, 'tcx> TaskEncoder< - TaskDescription<'tcx> = MostGenericTy<'tcx>, + Self: TaskEncoder, + CastersEnc: for<'vir> TaskEncoder< + TaskDescription<'vir> = MostGenericTy<'vir>, OutputRef<'vir> = Casters<'vir, T>, >, TaskEncoderError>: Sized, { - fn encode<'tcx: 'vir, 'vir>( - task_key: &ty::Ty<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, + fn encode<'vir>( + task_key: &ty::Ty<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, ) -> RustTyGenericCastEncOutput<'vir, Casters<'vir, T>> { with_vcx(|vcx| { let (generic_ty, args) = extract_type_params(vcx.tcx(), *task_key); @@ -112,20 +113,11 @@ impl TaskEncoder for RustTyCastersEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; Ok((Self::encode(task_key, deps), ())) } } @@ -145,20 +137,11 @@ impl TaskEncoder for RustTyCastersEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; Ok((Self::encode(task_key, deps), ())) } } diff --git a/prusti-encoder/src/encoders/type/lifted/ty.rs b/prusti-encoder/src/encoders/type/lifted/ty.rs index 86913dc7b29..3be3475ca8d 100644 --- a/prusti-encoder/src/encoders/type/lifted/ty.rs +++ b/prusti-encoder/src/encoders/type/lifted/ty.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use prusti_rustc_interface::middle::ty::{self, ParamTy, TyKind}; -use task_encoder::TaskEncoder; +use task_encoder::{TaskEncoder, EncodeFullResult}; use vir::{with_vcx, FunctionIdent, UnknownArity}; use crate::encoders::{ @@ -124,24 +124,14 @@ impl TaskEncoder for LiftedTyEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; with_vcx(|vcx| { let result = deps - .require_local::>(*task_key) - .unwrap(); + .require_local::>(*task_key)?; let result = result.map(vcx, &mut |g| { deps.require_ref::(g).unwrap() }); @@ -167,28 +157,18 @@ impl TaskEncoder for LiftedTyEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; with_vcx(|vcx| { if let TyKind::Param(p) = task_key.kind() { return Ok((LiftedTy::Generic(*p), ())); } let (ty_constructor, args) = extract_type_params(vcx.tcx(), *task_key); let ty_constructor = deps - .require_ref::(ty_constructor) - .unwrap() + .require_ref::(ty_constructor)? .ty_constructor; let args = args .into_iter() diff --git a/prusti-encoder/src/encoders/type/lifted/ty_constructor.rs b/prusti-encoder/src/encoders/type/lifted/ty_constructor.rs index e47254e421f..70b6bea875b 100644 --- a/prusti-encoder/src/encoders/type/lifted/ty_constructor.rs +++ b/prusti-encoder/src/encoders/type/lifted/ty_constructor.rs @@ -1,4 +1,4 @@ -use task_encoder::{OutputRefAny, TaskEncoder}; +use task_encoder::{OutputRefAny, TaskEncoder, EncodeFullResult}; use vir::{ vir_format_identifier, CallableIdent, FunctionIdent, UnaryArity, UnknownArity }; @@ -53,20 +53,11 @@ impl TaskEncoder for TyConstructorEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - let generic_ref = deps.require_ref::(()).unwrap(); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + let generic_ref = deps.require_ref::(())?; let mut functions = vec![]; let mut axioms = vec![]; vir::with_vcx(|vcx| { @@ -111,13 +102,13 @@ impl TaskEncoder for TyConstructorEnc { ) }) .collect::>(); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, TyConstructorEncOutputRef { ty_constructor: type_function_ident, ty_param_accessors: vcx.alloc_slice(&ty_accessor_functions), }, - ); + )?; let axiom_qvars = vcx.alloc_slice(&ty_arg_decls); let axiom_triggers = vcx.alloc_slice( diff --git a/prusti-encoder/src/encoders/type/predicate.rs b/prusti-encoder/src/encoders/type/predicate.rs index e8eefe5644b..ac0a526a0f1 100644 --- a/prusti-encoder/src/encoders/type/predicate.rs +++ b/prusti-encoder/src/encoders/type/predicate.rs @@ -2,7 +2,7 @@ use prusti_rustc_interface::{ abi, middle::ty::{self, TyKind}, }; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{ add_debug_note, CallableIdent, FunctionIdent, MethodIdent, NullaryArity, PredicateIdent, TypeData, UnaryArity, UnknownArity, VirCtxt, @@ -198,21 +198,12 @@ impl TaskEncoder for PredicateEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { - let snap = deps.require_local::(*task_key).unwrap(); - let generic_output_ref = deps.require_ref::(()).unwrap(); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + let snap = deps.require_local::(*task_key)?; + let generic_output_ref = deps.require_ref::(())?; let mut enc = vir::with_vcx(|vcx| { PredicateEncValues::new(vcx, &snap.base_name, snap.snapshot, snap.generics) }); @@ -228,7 +219,7 @@ impl TaskEncoder for PredicateEnc { ])), ) }); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, PredicateEncOutputRef { ref_to_pred: generic_output_ref.ref_to_pred.as_unknown_arity(), @@ -240,7 +231,7 @@ impl TaskEncoder for PredicateEnc { generics: &[], }, ); - let dep = deps.require_local::(()).unwrap(); + let dep = deps.require_local::(())?; vir::with_vcx(|vcx| { let method_assign = mk_method_assign( vcx, @@ -265,13 +256,13 @@ impl TaskEncoder for PredicateEnc { } TyKind::Bool | TyKind::Char | TyKind::Int(_) | TyKind::Uint(_) | TyKind::Float(_) => { let specifics = PredicateEncData::Primitive(snap.specifics.expect_primitive()); - deps.emit_output_ref::(*task_key, enc.output_ref(specifics)); + deps.emit_output_ref(*task_key, enc.output_ref(specifics)); Ok((enc.mk_prim(&snap.base_name), ())) } TyKind::Tuple(tys) => { let snap_data = snap.specifics.expect_structlike(); let specifics = enc.mk_struct_ref(None, snap_data); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, enc.output_ref(PredicateEncData::StructLike(specifics)), ); @@ -289,7 +280,7 @@ impl TaskEncoder for PredicateEnc { ty::AdtKind::Struct => { let snap_data = snap.specifics.expect_structlike(); let specifics = enc.mk_struct_ref(None, snap_data); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, enc.output_ref(PredicateEncData::StructLike(specifics)), ); @@ -316,7 +307,7 @@ impl TaskEncoder for PredicateEnc { } ty::AdtKind::Enum => { let specifics = enc.mk_enum_ref(snap.specifics.expect_enumlike()); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, enc.output_ref(PredicateEncData::EnumLike(specifics)), ); @@ -348,7 +339,7 @@ impl TaskEncoder for PredicateEnc { TyKind::Never => { let specifics = enc.mk_enum_ref(snap.specifics.expect_enumlike()); assert!(specifics.is_none()); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, enc.output_ref(PredicateEncData::EnumLike(None)), ); @@ -358,22 +349,21 @@ impl TaskEncoder for PredicateEnc { &TyKind::Ref(_, inner, m) => { let snap_data = snap.specifics.expect_structlike(); let specifics = enc.mk_ref_ref(snap_data, m.is_mut()); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, enc.output_ref(PredicateEncData::Ref(specifics)), ); let lifted_ty = deps.require_local::>(inner).unwrap(); let inner = deps - .require_ref::(inner) - .unwrap() + .require_ref::(inner)? .generic_predicate; Ok((enc.mk_ref(inner, lifted_ty, specifics), ())) } TyKind::Str => { let specifics = enc.mk_struct_ref(None, snap.specifics.expect_structlike()); - deps.emit_output_ref::( + deps.emit_output_ref( *task_key, enc.output_ref(PredicateEncData::StructLike(specifics)), ); diff --git a/prusti-encoder/src/encoders/type/rust_ty_predicates.rs b/prusti-encoder/src/encoders/type/rust_ty_predicates.rs index d83831719ac..2f4c44eba39 100644 --- a/prusti-encoder/src/encoders/type/rust_ty_predicates.rs +++ b/prusti-encoder/src/encoders/type/rust_ty_predicates.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::middle::ty::{self}; -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use vir::{with_vcx, Type, TypeData}; use crate::encoders::{PredicateEnc, PredicateEncOutputRef}; @@ -89,34 +89,24 @@ impl TaskEncoder for RustTyPredicatesEnc { type OutputRef<'vir> = RustTyPredicatesEncOutputRef<'vir>; type OutputFullLocal<'vir> = (); - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { with_vcx(|vcx| { let (generic_ty, args) = extract_type_params(vcx.tcx(), *task_key); - let generic_predicate = deps.require_ref::(generic_ty).unwrap(); + let generic_predicate = deps.require_ref::(generic_ty)?; let ty = deps - .require_local::>(*task_key) - .unwrap(); - deps.emit_output_ref::( + .require_local::>(*task_key)?; + deps.emit_output_ref( *task_key, RustTyPredicatesEncOutputRef { generic_predicate, ty, }, - ); + )?; for arg in args { - deps.require_ref::(arg).unwrap(); + deps.require_ref::(arg)?; } Ok(((), ())) }) diff --git a/prusti-encoder/src/encoders/type/rust_ty_snapshots.rs b/prusti-encoder/src/encoders/type/rust_ty_snapshots.rs index ae8ae0518ab..60a50c6a909 100644 --- a/prusti-encoder/src/encoders/type/rust_ty_snapshots.rs +++ b/prusti-encoder/src/encoders/type/rust_ty_snapshots.rs @@ -1,5 +1,5 @@ use prusti_rustc_interface::middle::ty; -use task_encoder::TaskEncoder; +use task_encoder::{TaskEncoder, EncodeFullResult}; use vir::with_vcx; use crate::encoders::SnapshotEnc; @@ -36,32 +36,22 @@ impl TaskEncoder for RustTySnapshotsEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut task_encoder::TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { with_vcx(|vcx| { let (generic_ty, args) = extract_type_params(vcx.tcx(), *task_key); - let generic_snapshot = deps.require_ref::(generic_ty).unwrap(); - deps.emit_output_ref::( + let generic_snapshot = deps.require_ref::(generic_ty)?; + deps.emit_output_ref( *task_key, RustTySnapshotsEncOutputRef { generic_snapshot }, - ); + )?; for arg in args { - deps.require_ref::(arg).unwrap(); + deps.require_ref::(arg)?; } let generic_snapshot = deps - .require_local::(generic_ty) - .unwrap(); + .require_local::(generic_ty)?; Ok((RustTySnapshotsEncOutput { generic_snapshot }, ())) }) } diff --git a/prusti-encoder/src/encoders/type/snapshot.rs b/prusti-encoder/src/encoders/type/snapshot.rs index 6265f72b86d..e150f9abc71 100644 --- a/prusti-encoder/src/encoders/type/snapshot.rs +++ b/prusti-encoder/src/encoders/type/snapshot.rs @@ -1,4 +1,4 @@ -use task_encoder::{TaskEncoder, TaskEncoderDependencies}; +use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult}; use super::{domain::{DomainEnc, DomainEncSpecifics}, lifted::generic::{LiftedGeneric, LiftedGenericEnc}, most_generic_ty::MostGenericTy}; @@ -33,29 +33,20 @@ impl TaskEncoder for SnapshotEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - ty: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result< - ( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), - ( - Self::EncodingError, - Option>, - ), - > { + fn do_encode_full<'vir>( + ty: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { vir::with_vcx(|vcx| { - let out = deps.require_ref::(*ty).unwrap(); + let out = deps.require_ref::(*ty)?; let snapshot = out.domain.apply(vcx, []); - deps.emit_output_ref::( + deps.emit_output_ref( *ty, SnapshotEncOutputRef { snapshot, }, - ); - let specifics = deps.require_dep::(*ty).unwrap(); + )?; + let specifics = deps.require_dep::(*ty)?; let generics = vcx.alloc_slice( &ty.generics() .into_iter() diff --git a/prusti-encoder/src/encoders/type/viper_tuple.rs b/prusti-encoder/src/encoders/type/viper_tuple.rs index 63598a91335..82ad8f83d32 100644 --- a/prusti-encoder/src/encoders/type/viper_tuple.rs +++ b/prusti-encoder/src/encoders/type/viper_tuple.rs @@ -1,6 +1,7 @@ use task_encoder::{ TaskEncoder, TaskEncoderDependencies, + EncodeFullResult, }; use super::{domain::{DomainDataStruct, DomainEnc}, most_generic_ty::MostGenericTy}; @@ -47,21 +48,15 @@ impl TaskEncoder for ViperTupleEnc { *task } - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )> { - deps.emit_output_ref::(*task_key, ()); + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; if *task_key == 1 { Ok((ViperTupleEncOutput { tuple: None }, ())) } else { - let ret = deps.require_dep::(MostGenericTy::tuple(*task_key)).unwrap(); + let ret = deps.require_dep::(MostGenericTy::tuple(*task_key))?; Ok((ViperTupleEncOutput { tuple: Some(ret.expect_structlike()) }, ())) } } diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index 5b12cef1ceb..bb01c9cf99b 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -52,7 +52,7 @@ pub fn test_entrypoint<'tcx>( }).unwrap_or_default(); if !(is_trusted && is_pure) { - let res = MirPolyImpureEnc::encode(def_id); + let res = MirPolyImpureEnc::encode(def_id, false); assert!(res.is_ok()); } } diff --git a/task-encoder/src/lib.rs b/task-encoder/src/lib.rs index c2fab0eba50..1e39303851c 100644 --- a/task-encoder/src/lib.rs +++ b/task-encoder/src/lib.rs @@ -1,7 +1,7 @@ #![feature(associated_type_defaults)] use hashlink::LinkedHashMap; -use std::cell::RefCell; +use std::{cell::RefCell, marker::PhantomData}; pub trait OutputRefAny {} impl OutputRefAny for () {} @@ -23,7 +23,7 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { /// TODO: can still collect errors? Encoded { output_ref: ::OutputRef<'vir>, - deps: TaskEncoderDependencies<'vir>, + deps: TaskEncoderDependencies<'vir, E>, output_local: ::OutputFullLocal<'vir>, output_dep: ::OutputFullDependency<'vir>, }, @@ -41,7 +41,7 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { /// to encode its signature, to be included in dependents' programs. ErrorEncode { output_ref: ::OutputRef<'vir>, - deps: TaskEncoderDependencies<'vir>, + deps: TaskEncoderDependencies<'vir, E>, error: TaskEncoderError, output_dep: Option<::OutputFullDependency<'vir>>, }, @@ -49,11 +49,11 @@ pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { /// Cache for a task encoder. See `TaskEncoderCacheState` for a description of /// the possible values in the encoding process. -pub type Cache<'tcx, 'vir, E> = LinkedHashMap< - ::TaskKey<'tcx>, +pub type Cache<'vir, E> = LinkedHashMap< + ::TaskKey<'vir>, TaskEncoderCacheState<'vir, E>, >; -pub type CacheRef<'tcx, 'vir, E> = RefCell>; +pub type CacheRef<'vir, E> = RefCell>; pub type CacheStatic = LinkedHashMap< ::TaskKey<'static>, @@ -77,6 +77,43 @@ impl<'vir, E: TaskEncoder> TaskEncoderOutput<'vir, E> { } } */ + +/// The result of the actual encoder implementation (`do_encode_full`). +pub type EncodeFullResult<'vir, E: TaskEncoder + 'vir + ?Sized> = Result<( + E::OutputFullLocal<'vir>, + E::OutputFullDependency<'vir>, +), EncodeFullError<'vir, E>>; + +/// An unsuccessful result occurring in `do_encode_full`. +pub enum EncodeFullError<'vir, E: TaskEncoder + 'vir + ?Sized> { + /// Indicates that the current task has already been encoded. This can + /// occur when there are cyclic dependencies between multiple encoders. + /// This error is specifically returned when one encoder depends on + /// another encoder (using e.g. `TaskEncoderDependencies::require_ref`), + /// that latter encoder then depending on the former again, causing the + /// former encoder to complete its full encoding in the inner invocation. + /// The outer invocation remains on the stack, but will be aborted early + /// as soon as the control flow returns to it. + AlreadyEncoded, + + /// An actual error occurred during encoding. + EncodingError(::EncodingError, Option>), + + DependencyError, +} + +// Manual implementation, since neither `E` nor `E::OutputFullDependency` are +// required to be `Debug`. +impl<'vir, E: TaskEncoder + 'vir + ?Sized> std::fmt::Debug for EncodeFullError<'vir, E> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::AlreadyEncoded => write!(f, "AlreadyEncoded"), + Self::EncodingError(err, _output_dep) => f.debug_tuple("EncodingError").field(err)/*.field(output_dep)*/.finish(), + Self::DependencyError => write!(f, "DependencyError"), + } + } +} + pub enum TaskEncoderError { EnqueueingError(::EnqueueingError), EncodingError(::EncodingError), @@ -109,64 +146,102 @@ impl Clone for TaskEncoderError { } } -#[derive(Default)] -pub struct TaskEncoderDependencies<'a> { - pub deps_local: Vec<&'a dyn OutputRefAny>, - pub deps_dep: Vec<&'a dyn OutputRefAny>, +pub struct TaskEncoderDependencies<'vir, E: TaskEncoder + 'vir + ?Sized> { + _marker: PhantomData, + task_key: Option>, + pub deps_local: Vec<&'vir dyn OutputRefAny>, + pub deps_dep: Vec<&'vir dyn OutputRefAny>, } -impl<'a> TaskEncoderDependencies<'a> { - pub fn require_ref<'vir, 'tcx: 'vir, E: TaskEncoder>( +impl<'vir, E: TaskEncoder + 'vir + ?Sized> TaskEncoderDependencies<'vir, E> { + fn check_cycle(&self) -> Result<(), EncodeFullError<'vir, E>> { + if let Some(task_key) = self.task_key.as_ref() { + if E::with_cache(move |cache| matches!( + cache.borrow().get(task_key), + Some(TaskEncoderCacheState::Encoded { .. } + | TaskEncoderCacheState::ErrorEncode { .. } + | TaskEncoderCacheState::ErrorEnqueue { .. }), + )) { + return Err(EncodeFullError::AlreadyEncoded); + } + } + Ok(()) + } + + pub fn require_ref( &mut self, - task: ::TaskDescription<'tcx>, + task: ::TaskDescription<'vir>, ) -> Result< - ::OutputRef<'vir>, - TaskEncoderError, + ::OutputRef<'vir>, + EncodeFullError<'vir, E>, > { - E::encode_ref(task) + EOther::encode_ref(task) + .map_err(|_| EncodeFullError::DependencyError) + .and_then(|result| { + self.check_cycle()?; + Ok(result) + }) } - pub fn require_local<'vir, 'tcx: 'vir, E: TaskEncoder + 'vir>( + pub fn require_local( &mut self, - task: ::TaskDescription<'tcx>, + task: ::TaskDescription<'vir>, ) -> Result< - ::OutputFullLocal<'vir>, - TaskEncoderError, + ::OutputFullLocal<'vir>, + EncodeFullError<'vir, E>, > { - E::encode(task).map(|(_output_ref, output_local, _output_dep)| output_local) + EOther::encode(task, true) + .map(Option::unwrap) + .map(|(_output_ref, output_local, _output_dep)| output_local) + .map_err(|_| EncodeFullError::DependencyError) + .and_then(|result| { + self.check_cycle()?; + Ok(result) + }) } - pub fn require_dep<'vir, 'tcx: 'vir, E: TaskEncoder + 'vir>( + pub fn require_dep( &mut self, - task: ::TaskDescription<'tcx>, + task: ::TaskDescription<'vir>, ) -> Result< - ::OutputFullDependency<'vir>, - TaskEncoderError, + ::OutputFullDependency<'vir>, + EncodeFullError<'vir, E>, > { - E::encode(task).map(|(_output_ref, _output_local, output_dep)| output_dep) + EOther::encode(task, true) + .map(Option::unwrap) + .map(|(_output_ref, _output_local, output_dep)| output_dep) + .map_err(|_| EncodeFullError::DependencyError) + .and_then(|result| { + self.check_cycle()?; + Ok(result) + }) } - pub fn emit_output_ref<'vir, 'tcx: 'vir, E: TaskEncoder + 'vir>( + pub fn emit_output_ref( &mut self, - task_key: E::TaskKey<'tcx>, + task_key: E::TaskKey<'vir>, output_ref: E::OutputRef<'vir>, - ) { + ) -> Result<(), EncodeFullError<'vir, E>> { + assert!(self.task_key.replace(task_key.clone()).is_none(), "output ref already set for task key {task_key:?}"); + self.check_cycle()?; assert!(E::with_cache(move |cache| matches!(cache.borrow_mut().insert( task_key, TaskEncoderCacheState::Started { output_ref }, - ), Some(TaskEncoderCacheState::Enqueued)))); + ), Some(TaskEncoderCacheState::Enqueued + | TaskEncoderCacheState::Started { .. })))); + Ok(()) } } pub trait TaskEncoder { /// Description of a task to be performed. Should be easily obtained by /// clients of this encoder. - type TaskDescription<'tcx>: std::hash::Hash + Eq + Clone + std::fmt::Debug; + type TaskDescription<'vir>: std::hash::Hash + Eq + Clone + std::fmt::Debug; /// Cache key for a task to be performed. May differ from `TaskDescription`, /// for example if the description should be normalised or some non-trivial /// resolution needs to happen. In other words, multiple descriptions may /// lead to the same key and hence the same output. - type TaskKey<'tcx>: std::hash::Hash + Eq + Clone + std::fmt::Debug = Self::TaskDescription<'tcx>; + type TaskKey<'vir>: std::hash::Hash + Eq + Clone + std::fmt::Debug = Self::TaskDescription<'vir>; /// A reference to an encoded item. Should be non-unit for tasks which can /// be "referred" to from other parts of a program, as opposed to tasks @@ -177,20 +252,22 @@ pub trait TaskEncoder { /// Fully encoded output for this task. When encoding items which can be /// dependencies (such as methods), this output should only be emitted in /// one Viper program. - type OutputFullLocal<'vir>: Clone; + type OutputFullLocal<'vir>: Clone + where Self: 'vir; /// Fully encoded output for this task for dependents. When encoding items /// which can be dependencies (such as methods), this output should be /// emitted in each Viper program that depends on this task. - type OutputFullDependency<'vir>: Clone = (); + type OutputFullDependency<'vir>: Clone = () + where Self: 'vir; type EnqueueingError: Clone + std::fmt::Debug = (); type EncodingError: Clone + std::fmt::Debug; /// Enters the given function with a reference to the cache for this /// encoder. - fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where Self: 'vir, F: FnOnce(&'vir CacheRef<'tcx, 'vir, Self>) -> R; + fn with_cache<'vir, F, R>(f: F) -> R + where Self: 'vir, F: FnOnce(&'vir CacheRef<'vir, Self>) -> R; //fn get_all_outputs() -> Self::CacheRef<'vir> { // todo!() @@ -215,7 +292,7 @@ pub trait TaskEncoder { ).is_none())); } - fn encode_ref<'tcx: 'vir, 'vir>(task: Self::TaskDescription<'tcx>) -> Result< + fn encode_ref<'vir>(task: Self::TaskDescription<'vir>) -> Result< Self::OutputRef<'vir>, TaskEncoderError, > @@ -234,22 +311,13 @@ pub trait TaskEncoder { return Ok(output_ref); } - // is the task enqueued already? - let task_key_clone = task_key.clone(); - if Self::with_cache(move |cache| cache.borrow().contains_key(&task_key_clone)) { - // Cyclic dependency error because: - // 1. An ouput ref was requested for the task, - // 2. the task was already enqueued, and - // 3. there is not an output ref available. - // - // This would happen if the current encoder directly or indirectly - // requested the encoding for a task it is already working on, - // before it called the `emit_output_ref` method. - return Err(TaskEncoderError::CyclicError); - } - - // otherwise, we need to start the encoding - Self::encode(task)?; + // Otherwise, we need to start the encoding. Note that this is done + // even if the encoding was started previously, i.e. if the cache + // contains a `Enqueued` entry for this task. This can happen if the + // same task was (recursively) requested from the same encoder, before + // its first invocation reached a call to `emit_output_ref`. + // TODO: we should still make sure that *some* progress is done, because an actual cyclic dependency could cause a stack overflow? + Self::encode(task, false)?; let task_key_clone = task_key.clone(); if let Some(output_ref) = Self::with_cache(move |cache| match cache.borrow().get(&task_key_clone) { @@ -264,11 +332,11 @@ pub trait TaskEncoder { panic!("output ref not found after encoding") // TODO: error? } - fn encode<'tcx: 'vir, 'vir>(task: Self::TaskDescription<'tcx>) -> Result<( + fn encode<'vir>(task: Self::TaskDescription<'vir>, need_output: bool) -> Result, Self::OutputFullLocal<'vir>, Self::OutputFullDependency<'vir>, - ), TaskEncoderError> + )>, TaskEncoderError> where Self: 'vir { let task_key = Self::task_to_key(&task); @@ -285,13 +353,17 @@ pub trait TaskEncoder { output_local, output_dep, .. - } => Some(Ok(( - output_ref.clone(), - output_local.clone(), - output_dep.clone(), - ))), - TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => - panic!("Encoding already started or enqueued"), + } => if need_output { + Some(Ok(Some(( + output_ref.clone(), + output_local.clone(), + output_dep.clone(), + )))) + } else { + Some(Ok(None)) + } + // TODO: should we return Some(Ok(None)) for `Started`, if `!need_output` ? + TaskEncoderCacheState::Enqueued | TaskEncoderCacheState::Started { .. } => None, }, None => { // enqueue @@ -304,29 +376,68 @@ pub trait TaskEncoder { return in_cache; } - let mut deps = TaskEncoderDependencies::default(); + let mut deps = TaskEncoderDependencies { + _marker: PhantomData, + task_key: None, + deps_local: vec![], + deps_dep: vec![], + }; let encode_result = Self::do_encode_full(&task_key, &mut deps); let output_ref = Self::with_cache(|cache| match cache.borrow().get(&task_key) { - Some(TaskEncoderCacheState::Started { output_ref }) => output_ref.clone(), + Some(TaskEncoderCacheState::Started { output_ref } + | TaskEncoderCacheState::Encoded { output_ref, .. }) => output_ref.clone(), _ => panic!("encoder did not provide output ref for task {task_key:?}"), }); match encode_result { Ok((output_local, output_dep)) => { - Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { - output_ref: output_ref.clone(), - deps, - output_local: output_local.clone(), - output_dep: output_dep.clone(), - })); - Ok(( + if need_output { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { + output_ref: output_ref.clone(), + deps, + output_local: output_local.clone(), + output_dep: output_dep.clone(), + })); + Ok(Some(( + output_ref, + output_local, + output_dep, + ))) + } else { + Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::Encoded { + output_ref: output_ref, + deps, + output_local: output_local, + output_dep: output_dep, + })); + Ok(None) + } + } + Err(EncodeFullError::AlreadyEncoded) => Self::with_cache(|cache| match cache.borrow().get(&task_key).unwrap() { + TaskEncoderCacheState::Encoded { output_ref, output_local, output_dep, - )) - } - Err((err, maybe_output_dep)) => { + .. + } => if need_output { + Ok(Some(( + // TODO: does it even make sense for an encoder to request the full encoding + // when a cycle can occur? + output_ref.clone(), + output_local.clone(), + output_dep.clone(), + ))) + } else { + Ok(None) + }, + TaskEncoderCacheState::ErrorEnqueue { error } + | TaskEncoderCacheState::ErrorEncode { error, .. } => Err(error.clone()), + TaskEncoderCacheState::Started { .. } + | TaskEncoderCacheState::Enqueued => panic!("encoder did not finish for task {task_key:?}"), + }), + Err(EncodeFullError::DependencyError) => todo!(), + Err(EncodeFullError::EncodingError(err, maybe_output_dep)) => { Self::with_cache(|cache| cache.borrow_mut().insert(task_key, TaskEncoderCacheState::ErrorEncode { output_ref: output_ref.clone(), deps, @@ -449,16 +560,10 @@ pub trait TaskEncoder { /// Given a task description, create a reference to the output. fn task_to_output_ref<'vir>(task: &Self::TaskDescription<'vir>) -> Self::OutputRef<'vir>; */ - fn do_encode_full<'tcx: 'vir, 'vir>( - task_key: &Self::TaskKey<'tcx>, - deps: &mut TaskEncoderDependencies<'vir>, - ) -> Result<( - Self::OutputFullLocal<'vir>, - Self::OutputFullDependency<'vir>, - ), ( - Self::EncodingError, - Option>, - )>; + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self>; fn all_outputs<'vir>() -> Vec> where Self: 'vir @@ -489,8 +594,8 @@ pub trait TaskEncoder { #[macro_export] macro_rules! encoder_cache { ($encoder: ty) => { - fn with_cache<'tcx: 'vir, 'vir, F, R>(f: F) -> R - where F: FnOnce(&'vir $crate::CacheRef<'tcx, 'vir, $encoder>) -> R, + fn with_cache<'vir, F, R>(f: F) -> R + where F: FnOnce(&'vir $crate::CacheRef<'vir, $encoder>) -> R, { ::std::thread_local! { static CACHE: $crate::CacheStaticRef<$encoder> = ::std::cell::RefCell::new(Default::default());