Skip to content

Commit

Permalink
Fix generic module failures
Browse files Browse the repository at this point in the history
  • Loading branch information
mamonet committed Nov 12, 2024
1 parent 2c93acc commit 354e3ef
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ let _ =
let open Libcrux_ml_kem.Vector.Traits in
()

#push-options "--z3rlimit 300"

let validate_private_key
(v_K v_SECRET_KEY_SIZE v_CIPHERTEXT_SIZE: usize)
(#v_Hasher: Type0)
Expand Down Expand Up @@ -44,6 +46,8 @@ let validate_private_key
in
t =. expected

#pop-options

#push-options "--z3rlimit 150"

let serialize_kem_secret_key
Expand Down
224 changes: 125 additions & 99 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ind_cpa.fst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,53 @@ let _ =
let open Libcrux_ml_kem.Vector.Traits in
()

#push-options "--z3rlimit 200"

let prf_input_inc (v_K: usize) (prf_inputs: t_Array (t_Array u8 (sz 33)) v_K) (domain_separator: u8) =
let v__domain_separator_init:u8 = domain_separator in
let v__prf_inputs_init:t_Array (t_Array u8 (sz 33)) v_K =
Core.Clone.f_clone #(t_Array (t_Array u8 (sz 33)) v_K)
#FStar.Tactics.Typeclasses.solve
prf_inputs
in
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i /\
(v i < v v_K ==>
(forall (j: nat).
(j >= v i /\ j < v v_K) ==> prf_inputs.[ sz j ] == v__prf_inputs_init.[ sz j ])) /\
(forall (j: nat).
j < v i ==>
v (Seq.index (Seq.index prf_inputs j) 32) == v v__domain_separator_init + j /\
Seq.slice (Seq.index prf_inputs j) 0 32 ==
Seq.slice (Seq.index v__prf_inputs_init j) 0 32))
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize prf_inputs
i
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (prf_inputs.[ i ]
<:
t_Array u8 (sz 33))
(sz 32)
domain_separator
<:
t_Array u8 (sz 33))
in
let domain_separator:u8 = domain_separator +! 1uy in
domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
in
let hax_temp_output:u8 = domain_separator in
prf_inputs, hax_temp_output <: (t_Array (t_Array u8 (sz 33)) v_K & u8)

#pop-options

#push-options "--ext context_pruning"

let deserialize_secret_key
Expand Down Expand Up @@ -80,7 +127,7 @@ let deserialize_secret_key

#pop-options

#push-options "--max_fuel 10 --z3rlimit 1000 --ext context_pruning --z3refresh --split_queries always"
#push-options "--max_fuel 15 --z3rlimit 1500 --ext context_pruning --z3refresh --split_queries always"

let sample_ring_element_cbd
(v_K v_ETA2_RANDOMNESS_SIZE v_ETA2: usize)
Expand All @@ -105,40 +152,11 @@ let sample_ring_element_cbd
in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K = Rust_primitives.Hax.repeat prf_input v_K in
let v__domain_separator_init:u8 = domain_separator in
let v__prf_inputs_init:t_Array (t_Array u8 (sz 33)) v_K = prf_inputs in
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i /\
(v i < v v_K ==>
(forall (j: nat).
(j >= v i /\ j < v v_K) ==> prf_inputs.[ sz j ] == v__prf_inputs_init.[ sz j ])) /\
(forall (j: nat).
j < v i ==>
v (Seq.index (Seq.index prf_inputs j) 32) == v v__domain_separator_init + j /\
Seq.slice (Seq.index prf_inputs j) 0 32 ==
Seq.slice (Seq.index v__prf_inputs_init j) 0 32))
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize prf_inputs
i
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (prf_inputs.[ i ]
<:
t_Array u8 (sz 33))
(sz 32)
domain_separator
<:
t_Array u8 (sz 33))
in
let domain_separator:u8 = domain_separator +! 1uy in
domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
let tmp0, out:(t_Array (t_Array u8 (sz 33)) v_K & u8) =
prf_input_inc v_K prf_inputs domain_separator
in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K = tmp0 in
let domain_separator:u8 = out in
let _:Prims.unit =
let lemma_aux (i: nat{i < v v_K})
: Lemma
Expand Down Expand Up @@ -212,7 +230,60 @@ let sample_ring_element_cbd

#pop-options

#push-options "--max_fuel 10 --z3rlimit 1000 --ext context_pruning --z3refresh --split_queries always"
let sample_vector_cbd_then_ntt_helper_1
(v_K: usize)
(prf_inputs: t_Array (t_Array u8 (sz 33)) v_K)
(prf_input: t_Array u8 (sz 33))
(domain_separator: u8) : Lemma
(requires Spec.MLKEM.is_rank v_K /\ v domain_separator < 2 * v v_K /\
(forall (i: nat). i < v v_K ==>
v (Seq.index (Seq.index prf_inputs i) 32) == v domain_separator + i /\
Seq.slice (Seq.index prf_inputs i) 0 32 == Seq.slice prf_input 0 32))
(ensures prf_inputs == createi v_K
(Spec.MLKEM.sample_vector_cbd1_prf_input #v_K
(Seq.slice prf_input 0 32) (sz (v domain_separator))))
=
let lemma_aux (i: nat{i < v v_K}) : Lemma
(prf_inputs.[ sz i ] == (Seq.append (Seq.slice prf_input 0 32) (Seq.create 1
(mk_int #u8_inttype (v (domain_separator +! (mk_int #u8_inttype i))))))) =
Lib.Sequence.eq_intro #u8 #33 prf_inputs.[ sz i ]
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1 (mk_int #u8_inttype (v domain_separator + i))))
in
Classical.forall_intro lemma_aux;
Lib.Sequence.eq_intro #(t_Array u8 (sz 33)) #(v v_K) prf_inputs
(createi v_K (Spec.MLKEM.sample_vector_cbd1_prf_input #v_K
(Seq.slice prf_input 0 32) (sz (v domain_separator))))

let sample_vector_cbd_then_ntt_helper_2
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
i2:
Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector)
(re_as_ntt: t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K)
(prf_input: t_Array u8 (sz 33))
(domain_separator: u8) : Lemma
(requires Spec.MLKEM.is_rank v_K /\ v_ETA == Spec.MLKEM.v_ETA1 v_K /\
v_ETA_RANDOMNESS_SIZE == Spec.MLKEM.v_ETA1_RANDOMNESS_SIZE v_K /\
v domain_separator < 2 * v v_K /\
(let prf_outputs = Spec.MLKEM.v_PRFxN v_K v_ETA_RANDOMNESS_SIZE
(createi v_K (Spec.MLKEM.sample_vector_cbd1_prf_input #v_K
(Seq.slice prf_input 0 32) (sz (v domain_separator)))) in
forall (i: nat). i < v v_K ==>
Libcrux_ml_kem.Polynomial.to_spec_poly_t #v_Vector re_as_ntt.[ sz i ] ==
Spec.MLKEM.poly_ntt (Spec.MLKEM.sample_poly_cbd v_ETA prf_outputs.[ sz i ])))
(ensures Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector re_as_ntt ==
(Spec.MLKEM.sample_vector_cbd_then_ntt #v_K
(Seq.slice prf_input 0 32) (sz (v domain_separator))))
=
reveal_opaque (`%Spec.MLKEM.sample_vector_cbd_then_ntt) (Spec.MLKEM.sample_vector_cbd_then_ntt #v_K);
Lib.Sequence.eq_intro #(Spec.MLKEM.polynomial) #(v v_K)
(Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector re_as_ntt)
(Spec.MLKEM.sample_vector_cbd_then_ntt #v_K
(Seq.slice prf_input 0 32) (sz (v domain_separator)))

#push-options "--max_fuel 15 --z3rlimit 1500 --ext context_pruning --z3refresh --split_queries always"

let sample_vector_cbd_then_ntt
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
Expand All @@ -229,61 +300,13 @@ let sample_vector_cbd_then_ntt
=
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K = Rust_primitives.Hax.repeat prf_input v_K in
let v__domain_separator_init:u8 = domain_separator in
let v__prf_inputs_init:t_Array (t_Array u8 (sz 33)) v_K = prf_inputs in
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i /\
(v i < v v_K ==>
(forall (j: nat).
(j >= v i /\ j < v v_K) ==> prf_inputs.[ sz j ] == v__prf_inputs_init.[ sz j ])) /\
(forall (j: nat).
j < v i ==>
v (Seq.index (Seq.index prf_inputs j) 32) == v v__domain_separator_init + j /\
Seq.slice (Seq.index prf_inputs j) 0 32 ==
Seq.slice (Seq.index v__prf_inputs_init j) 0 32))
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize prf_inputs
i
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (prf_inputs.[ i ]
<:
t_Array u8 (sz 33))
(sz 32)
domain_separator
<:
t_Array u8 (sz 33))
in
let domain_separator:u8 = domain_separator +! 1uy in
domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
let tmp0, out:(t_Array (t_Array u8 (sz 33)) v_K & u8) =
prf_input_inc v_K prf_inputs domain_separator
in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K = tmp0 in
let domain_separator:u8 = out in
let _:Prims.unit =
let lemma_aux (i: nat{i < v v_K})
: Lemma
(prf_inputs.[ sz i ] ==
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1
(mk_int #u8_inttype (v (v__domain_separator_init +! (mk_int #u8_inttype i))))))) =
Lib.Sequence.eq_intro #u8
#33
prf_inputs.[ sz i ]
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1 (mk_int #u8_inttype (v v__domain_separator_init + i))))
in
Classical.forall_intro lemma_aux;
Lib.Sequence.eq_intro #(t_Array u8 (sz 33))
#(v v_K)
prf_inputs
(createi v_K
(Spec.MLKEM.sample_vector_cbd1_prf_input #v_K
(Seq.slice prf_input 0 32)
(sz (v v__domain_separator_init))))
sample_vector_cbd_then_ntt_helper_1 v_K prf_inputs prf_input v__domain_separator_init
in
let (prf_outputs: t_Array (t_Array u8 v_ETA_RANDOMNESS_SIZE) v_K):t_Array
(t_Array u8 v_ETA_RANDOMNESS_SIZE) v_K =
Expand All @@ -304,7 +327,8 @@ let sample_vector_cbd_then_ntt
forall (j: nat).
j < v i ==>
Libcrux_ml_kem.Polynomial.to_spec_poly_t #v_Vector re_as_ntt.[ sz j ] ==
Spec.MLKEM.poly_ntt (Spec.MLKEM.sample_poly_cbd v_ETA prf_outputs.[ sz j ]))
Spec.MLKEM.poly_ntt (Spec.MLKEM.sample_poly_cbd v_ETA prf_outputs.[ sz j ]) /\
Libcrux_ml_kem.Serialize.coefficients_field_modulus_range #v_Vector re_as_ntt.[ sz j ])
re_as_ntt
(fun re_as_ntt i ->
let re_as_ntt:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
Expand All @@ -331,12 +355,13 @@ let sample_vector_cbd_then_ntt
re_as_ntt)
in
let _:Prims.unit =
Lib.Sequence.eq_intro #(Spec.MLKEM.polynomial)
#(v v_K)
(Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector re_as_ntt)
(Spec.MLKEM.sample_vector_cbd_then_ntt #v_K
(Seq.slice prf_input 0 32)
(sz (v v__domain_separator_init)))
sample_vector_cbd_then_ntt_helper_2 v_K
v_ETA
v_ETA_RANDOMNESS_SIZE
#v_Vector
re_as_ntt
prf_input
v__domain_separator_init
in
let hax_temp_output:u8 = domain_separator in
re_as_ntt, hax_temp_output
Expand Down Expand Up @@ -526,7 +551,7 @@ let generate_keypair_unpacked

#pop-options

#push-options "--z3rlimit 200 --ext context_pruning --z3refresh"
#push-options "--z3rlimit 800 --ext context_pruning --z3refresh"

let compress_then_serialize_u
(v_K v_OUT_LEN v_COMPRESSION_FACTOR v_BLOCK_LEN: usize)
Expand Down Expand Up @@ -804,7 +829,7 @@ let encrypt
Lib.Sequence.eq_intro #u8
#32
seed
(Seq.slice (Libcrux_ml_kem.Utils.into_padded_array (Rust_primitives.mk_usize 34) seed) 0 32)
(Seq.slice (Libcrux_ml_kem.Utils.into_padded_array (sz 34) seed) 0 32)
in
let unpacked_public_key:Libcrux_ml_kem.Ind_cpa.Unpacked.t_IndCpaPublicKeyUnpacked v_K v_Vector =
{
Expand All @@ -825,7 +850,7 @@ let encrypt
v_U_COMPRESSION_FACTOR v_V_COMPRESSION_FACTOR v_BLOCK_LEN v_ETA1 v_ETA1_RANDOMNESS_SIZE v_ETA2
v_ETA2_RANDOMNESS_SIZE #v_Vector #v_Hasher unpacked_public_key message randomness

#push-options "--ext context_pruning"
#push-options "--z3rlimit 800 --ext context_pruning"

let deserialize_then_decompress_u
(v_K v_CIPHERTEXT_SIZE v_U_COMPRESSION_FACTOR: usize)
Expand All @@ -837,7 +862,7 @@ let deserialize_then_decompress_u
=
let _:Prims.unit =
assert (v ((Libcrux_ml_kem.Constants.v_COEFFICIENTS_IN_RING_ELEMENT *! v_U_COMPRESSION_FACTOR) /!
Rust_primitives.mk_usize 8) ==
sz 8) ==
v (Spec.MLKEM.v_C1_BLOCK_SIZE v_K))
in
let u_as_ntt:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
Expand Down Expand Up @@ -1062,6 +1087,7 @@ let serialize_secret_key
#v_Vector
key) ==
Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector key);
reveal_opaque (`%Spec.MLKEM.vector_encode_12) (Spec.MLKEM.vector_encode_12 #v_K);
Lib.Sequence.eq_intro #u8
#(v v_OUT_LEN)
out
Expand Down
17 changes: 15 additions & 2 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ind_cpa.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ let _ =
let open Libcrux_ml_kem.Vector.Traits in
()

val prf_input_inc (v_K: usize) (prf_inputs: t_Array (t_Array u8 (sz 33)) v_K) (domain_separator: u8)
: Prims.Pure (t_Array (t_Array u8 (sz 33)) v_K & u8)
(requires range (v domain_separator + v v_K) u8_inttype)
(ensures
fun temp_0_ ->
let prf_inputs_future, ds:(t_Array (t_Array u8 (sz 33)) v_K & u8) = temp_0_ in
v ds == v domain_separator + v v_K /\
(forall (i: nat).
i < v v_K ==>
v (Seq.index (Seq.index prf_inputs_future i) 32) == v domain_separator + i /\
Seq.slice (Seq.index prf_inputs_future i) 0 32 ==
Seq.slice (Seq.index prf_inputs i) 0 32))

/// Call [`deserialize_to_uncompressed_ring_element`] for each ring element.
val deserialize_secret_key
(v_K: usize)
Expand Down Expand Up @@ -82,8 +95,8 @@ val sample_vector_cbd_then_ntt
(sz (v domain_separator)) /\
(forall (i: nat).
i < v v_K ==>
Libcrux_ml_kem.Serialize.coefficients_field_modulus_range (Seq.index re_as_ntt_future
i)))
Libcrux_ml_kem.Serialize.coefficients_field_modulus_range #v_Vector
(Seq.index re_as_ntt_future i)))

val sample_vector_cbd_then_ntt_out
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ let inv_ntt_layer_int_vec_step_reduce
let b:v_Vector = Libcrux_ml_kem.Vector.Traits.montgomery_multiply_fe #v_Vector a_minus_b zeta_r in
a, b <: (v_Vector & v_Vector)

#push-options "--z3rlimit 200 --ext context_pruning"

let invert_ntt_at_layer_1_
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand Down Expand Up @@ -107,6 +109,10 @@ let invert_ntt_at_layer_1_
let hax_temp_output:Prims.unit = () <: Prims.unit in
zeta_i, re <: (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)

#pop-options

#push-options "--z3rlimit 200 --ext context_pruning"

let invert_ntt_at_layer_2_
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand Down Expand Up @@ -182,6 +188,10 @@ let invert_ntt_at_layer_2_
let hax_temp_output:Prims.unit = () <: Prims.unit in
zeta_i, re <: (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)

#pop-options

#push-options "--z3rlimit 200 --ext context_pruning"

let invert_ntt_at_layer_3_
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand Down Expand Up @@ -255,6 +265,8 @@ let invert_ntt_at_layer_3_
let hax_temp_output:Prims.unit = () <: Prims.unit in
zeta_i, re <: (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)

#pop-options

#push-options "--admit_smt_queries true"

let invert_ntt_at_layer_4_plus
Expand Down
Loading

0 comments on commit 354e3ef

Please sign in to comment.