From f20314490d4a5e4a9ff345bb737655fee1c20065 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Fri, 11 Dec 2020 14:54:56 -0800 Subject: [PATCH] ModifiedRescueGate generator --- src/field/field.rs | 4 ++ src/gates2/gate.rs | 1 + src/gates2/limb_sum.rs | 2 +- src/gates2/rescue.rs | 109 ++++++++++++++++++++++++++++++++++++++++- src/mds.rs | 41 +++++++++++----- src/rescue.rs | 1 + 6 files changed, 145 insertions(+), 13 deletions(-) diff --git a/src/field/field.rs b/src/field/field.rs index ad7fc58..9682f55 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -340,6 +340,10 @@ pub trait Field: self.exp(Self::from_canonical_usize(power)) } + fn kth_root_usize(&self, k: usize) -> Self { + self.kth_root(Self::from_canonical_usize(k)) + } + fn kth_root_u32(&self, k: u32) -> Self { self.kth_root(Self::from_canonical_u32(k)) } diff --git a/src/gates2/gate.rs b/src/gates2/gate.rs index ed5b21d..32a6a6b 100644 --- a/src/gates2/gate.rs +++ b/src/gates2/gate.rs @@ -4,6 +4,7 @@ use std::rc::Rc; use crate::{CircuitConfig, ConstraintPolynomial, Field, WitnessGenerator2}; /// A custom gate. +// TODO: Remove CircuitConfig params? Could just use fields within each struct. pub trait Gate2: 'static { fn id(&self) -> String; diff --git a/src/gates2/limb_sum.rs b/src/gates2/limb_sum.rs index af9a1b3..49313a2 100644 --- a/src/gates2/limb_sum.rs +++ b/src/gates2/limb_sum.rs @@ -16,7 +16,7 @@ impl LimbSumGate { impl DeterministicGate for LimbSumGate { fn id(&self) -> String { - format!("LimbSumGate-{}x{}", self.base, self.num_limbs) + format!("LimbSumGate[base={}, num_limbs={}]", self.base, self.num_limbs) } fn outputs(&self, _config: CircuitConfig) -> Vec<(usize, ConstraintPolynomial)> { diff --git a/src/gates2/rescue.rs b/src/gates2/rescue.rs index 44f3a25..bbad052 100644 --- a/src/gates2/rescue.rs +++ b/src/gates2/rescue.rs @@ -1,3 +1,110 @@ +use crate::{apply_mds, CircuitConfig, ConstraintPolynomial, Field, Gate2, PartialWitness2, SimpleGenerator, Target2, Wire, WitnessGenerator2}; + /// Implements a round of the Rescue permutation, modified with a different key schedule to reduce /// the number of constants involved. -pub struct ModifiedRescueGate; +#[derive(Copy, Clone)] +pub struct ModifiedRescueGate { + width: usize, + alpha: usize, +} + +impl ModifiedRescueGate { + /// Returns the index of the `i`th accumulator wire. + pub fn wire_acc(&self, i: usize) -> usize { + debug_assert!(i < self.width); + i + } + + /// Returns the index of the `i`th root wire. + pub fn wire_root(&self, i: usize) -> usize { + debug_assert!(i < self.width); + self.width + i + } +} + +impl Gate2 for ModifiedRescueGate { + fn id(&self) -> String { + format!("ModifiedRescueGate[width={}, alpha={}]", self.width, self.alpha) + } + + fn constraints(&self, _config: CircuitConfig) -> Vec> { + unimplemented!() + } + + fn generators( + &self, + _config: CircuitConfig, + gate_index: usize, + local_constants: Vec, + _next_constants: Vec, + ) -> Vec>> { + let gen = ModifiedRescueGenerator:: { + gate: *self, + gate_index, + constants: local_constants.clone(), + }; + vec![Box::new(gen)] + } +} + +struct ModifiedRescueGenerator { + gate: ModifiedRescueGate, + gate_index: usize, + constants: Vec, +} + +impl SimpleGenerator for ModifiedRescueGenerator { + fn dependencies(&self) -> Vec> { + (0..self.gate.width) + .map(|i| Target2::Wire(Wire { gate: self.gate_index, input: self.gate.wire_acc(i) })) + .collect() + } + + fn run_once(&mut self, witness: &PartialWitness2) -> PartialWitness2 { + let w = self.gate.width; + + // Load inputs. + let layer_0 = (0..w) + .map(|i| witness.get_wire( + Wire { gate: self.gate_index, input: self.gate.wire_acc(i) })) + .collect::>(); + + // Take alpha'th roots. + let layer_1 = layer_0.iter() + .map(|x| x.kth_root_usize(self.gate.alpha)) + .collect::>(); + let layer_roots = layer_1.clone(); + + // Apply MDS matrix. + let layer_2 = apply_mds(layer_1); + + // Add a constant to the first element. + let mut layer_3 = layer_2; + layer_3[0] = layer_3[0] + self.constants[0]; + + // Raise to the alpha'th power. + let layer_4 = layer_3.iter() + .map(|x| x.exp_usize(self.gate.alpha)) + .collect::>(); + + // Apply MDS matrix. + let layer_5 = apply_mds(layer_4); + + // Add a constant to the first element. + let mut layer_6 = layer_5; + layer_6[0] = layer_6[0] + self.constants[1]; + + let mut result = PartialWitness2::new(); + for i in 0..w { + // Set the i'th root wire. + result.set_wire( + Wire { gate: self.gate_index, input: self.gate.wire_root(i) }, + layer_roots[i]); + // Set the i'th output wire. + result.set_wire( + Wire { gate: self.gate_index + 1, input: self.gate.wire_acc(i) }, + layer_6[i]); + } + result + } +} diff --git a/src/mds.rs b/src/mds.rs index a8fc6eb..cd178c6 100644 --- a/src/mds.rs +++ b/src/mds.rs @@ -1,4 +1,4 @@ -use crate::Field; +use crate::{Field, ConstraintPolynomial}; use std::any::TypeId; use once_cell::sync::Lazy; use std::sync::Mutex; @@ -28,6 +28,12 @@ pub struct MdsMatrix { } impl MdsMatrix { + /// Returns the width and height of this matrix. + pub fn size(&self) -> usize { + self.unparameterized.rows.len() + } + + /// Returns the entry at row `r` and column `c`. pub fn get(&self, r: usize, c: usize) -> F { F::from_canonical_u64_vec(self.unparameterized.rows[r][c].clone()) } @@ -39,17 +45,30 @@ struct UnparameterizedMdsMatrix { rows: Vec>>, } -/// Apply an MDS matrix to the given state vector. -pub(crate) fn apply_mds(inputs: Vec) -> Vec { - let n = inputs.len(); - let mut result = vec![F::ZERO; n]; +/// Apply an MDS matrix to the given vector of field elements. +pub(crate) fn apply_mds(vec: Vec) -> Vec { + let n = vec.len(); let mds = mds_matrix::(n); - for r in 0..n { - for c in 0..n { - result[r] = result[r] + mds.get(r, c) * inputs[c]; - } - } - result + + (0..n) + .map(|r| (0..n) + .map(|c| mds.get(r, c) * vec[c]) + .fold(F::ZERO, |acc, x| acc + x)) + .collect() +} + +/// Applies an MDS matrix to the given vector of constraint polynomials. +pub(crate) fn apply_mds_constraint_polys( + vec: Vec>, +) -> Vec> { + let n = vec.len(); + let mds = mds_matrix::(n); + + (0..n) + .map(|r| (0..n) + .map(|c| &vec[c] * mds.get(r, c)) + .sum()) + .collect() } /// Returns entry `(r, c)` of an `n` by `n` MDS matrix. diff --git a/src/rescue.rs b/src/rescue.rs index 6cd74fd..2aa0bb9 100644 --- a/src/rescue.rs +++ b/src/rescue.rs @@ -99,6 +99,7 @@ pub(crate) fn generate_rescue_constants( security_bits: usize, ) -> Vec<(Vec, Vec)> { // TODO: This should use deterministic randomness. + // TODO: Reject subgroup elements. // FIX: Use ChaCha CSPRNG with a seed. This is somewhat similar to official implementation // at https://github.com/KULeuven-COSIC/Marvellous/blob/master/instance_generator.sage where they // use SHAKE256 with a seed to generate randomness.