From 37a79fa44d7674432acfedf159c792a858ba2ca1 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Wed, 5 Feb 2025 10:11:27 -0600 Subject: [PATCH] chore: integrate gnark-crypto refactors --- go.mod | 2 +- go.sum | 4 +- std/hash/hash.go | 4 +- std/hash/poseidon2/poseidon2_test.go | 2 +- std/hash/poseidon2/posiedon2.go | 4 +- std/permutation/poseidon2/poseidon2.go | 111 ++++++++++---------- std/permutation/poseidon2/poseidon2_test.go | 34 +++--- 7 files changed, 81 insertions(+), 80 deletions(-) diff --git a/go.mod b/go.mod index 0175a3a03..b7b70697e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.27 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 + github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 94084a48c..200c8c71e 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.27 h1:j6hKUrGAy/H+gpNrpLU3I26n1yc+VMGmd6ID5+gAh github.com/consensys/bavard v0.1.27/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 h1:PuRSTn2hpFm+mqysWl/hjTU2AvXYMNZT1nvxQT5j5PY= -github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= +github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca h1:u6iXwMBfbXODF+hDSwKSTBg6yfD3+eMX6o3PILAK474= +github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/hash/hash.go b/std/hash/hash.go index ed14d8c6d..66452a797 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -90,7 +90,7 @@ type BinaryFixedLengthHasher interface { // CompressionFunction is a 2 to 1 function type CompressionFunction interface { - Apply(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable // TODO @Tabaie @ThomasPiellard better name + Compress(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable } type merkleDamgardHasher struct { @@ -117,7 +117,7 @@ func (h *merkleDamgardHasher) Reset() { func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { for _, d := range data { - h.state = h.f.Apply(h.api, h.state, d) + h.state = h.f.Compress(h.api, h.state, d) } } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 001521854..72293926f 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -11,7 +11,7 @@ import ( func TestPoseidon2Hash(t *testing.T) { // prepare expected output - h := poseidon2.NewPoseidon2() + h := poseidon2.NewMerkleDamgardHasher() for i := range 5 { _, err := h.Write([]byte{byte(i)}) require.NoError(t, err) diff --git a/std/hash/poseidon2/posiedon2.go b/std/hash/poseidon2/posiedon2.go index 8e19ebc47..cad8cea68 100644 --- a/std/hash/poseidon2/posiedon2.go +++ b/std/hash/poseidon2/posiedon2.go @@ -7,7 +7,6 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/hash" poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" - "strings" ) func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { @@ -16,8 +15,7 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { if !ok { return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String()) } - seed := fmt.Sprintf("Poseidon2 hash for %s with t=2, rF=%d, rP=%d, d=%d", strings.ToUpper(curve.String()), params.rF, params.rP, params.d) - f := poseidon2.NewHash(2, params.d, params.rF, params.rP, seed, curve) + f := poseidon2.NewHash(2, params.d, params.rF, params.rP, curve) return hash.NewMerkleDamgardHasher(api, &f, 0), nil } diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 51ba0553a..ddfdbaa5d 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -2,6 +2,7 @@ package poseidon import ( "errors" + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -19,7 +20,7 @@ var ( ErrInvalidSizebuffer = errors.New("the size of the input should match the size of the hash buffer") ) -type Hash struct { +type Permutation struct { params parameters } @@ -45,77 +46,79 @@ type parameters struct { roundKeys [][]big.Int } -func NewHash(t, d, rf, rp int, seed string, curve ecc.ID) Hash { +func NewHash(t, d, rf, rp int, curve ecc.ID) Permutation { params := parameters{t: t, d: d, rF: rf, rP: rp} if curve == ecc.BN254 { - rc := poseidonbn254.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbn254.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS12_381 { - rc := poseidonbls12381.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls12381.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS12_377 { - rc := poseidonbls12377.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls12377.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BW6_761 { - rc := poseidonbw6761.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbw6761.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BW6_633 { - rc := poseidonbw6633.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbw6633.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS24_315 { - rc := poseidonbls24315.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls24315.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } } else if curve == ecc.BLS24_317 { - rc := poseidonbls24317.InitRC(seed, rf, rp, t) - params.roundKeys = make([][]big.Int, len(rc)) - for i := 0; i < len(rc); i++ { - params.roundKeys[i] = make([]big.Int, len(rc[i])) - for j := 0; j < len(rc[i]); j++ { - rc[i][j].BigInt(¶ms.roundKeys[i][j]) + concreteParams := poseidonbls24317.NewParameters(t, rf, rp) + params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.roundKeys { + params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.roundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) } } + } else { + panic(fmt.Errorf("curve %s not supported", curve.String())) } - return Hash{params: params} + return Permutation{params: params} } // sBox applies the sBox on buffer[index] -func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) { +func (h *Permutation) sBox(api frontend.API, index int, input []frontend.Variable) { tmp := input[index] if h.params.d == 3 { input[index] = api.Mul(input[index], input[index]) @@ -149,7 +152,7 @@ func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) { // (1 1 4 6) // on chunks of 4 elements on each part of the buffer // see https://eprint.iacr.org/2023/323.pdf appendix B for the addition chain -func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) { +func (h *Permutation) matMulM4InPlace(api frontend.API, s []frontend.Variable) { c := len(s) / 4 for i := 0; i < c; i++ { t0 := api.Add(s[4*i], s[4*i+1]) // s0+s1 @@ -176,7 +179,7 @@ func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) { // // when t=0[4], the buffer is multiplied by circ(2M4,M4,..,M4) // see https://eprint.iacr.org/2023/323.pdf -func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable) { +func (h *Permutation) matMulExternalInPlace(api frontend.API, input []frontend.Variable) { if h.params.t == 2 { tmp := api.Add(input[0], input[1]) @@ -213,7 +216,7 @@ func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable // when t=2,3 the matrix are respectibely [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, -func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable) { +func (h *Permutation) matMulInternalInPlace(api frontend.API, input []frontend.Variable) { if h.params.t == 2 { sum := api.Add(input[0], input[1]) input[0] = api.Add(input[0], sum) @@ -241,13 +244,13 @@ func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable } // addRoundKeyInPlace adds the round-th key to the buffer -func (h *Hash) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) { +func (h *Permutation) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) { for i := 0; i < len(h.params.roundKeys[round]); i++ { input[i] = api.Add(input[i], h.params.roundKeys[round][i]) } } -func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error { +func (h *Permutation) Permutation(api frontend.API, input []frontend.Variable) error { if len(input) != h.params.t { return ErrInvalidSizebuffer } @@ -283,11 +286,11 @@ func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error { return nil } -// Apply aliases Permutation in the t=2 case +// Compress aliases Permutation in the t=2 case // implements hash.CompressionFunction -func (h *Hash) Apply(api frontend.API, l, r frontend.Variable) frontend.Variable { +func (h *Permutation) Compress(api frontend.API, l, r frontend.Variable) frontend.Variable { if h.params.t != 2 { - panic("poseidon2: Apply can only be used when t=2") + panic("poseidon2: Compress can only be used when t=2") } vars := [2]frontend.Variable{l, r} if err := h.Permutation(api, vars[:]); err != nil { diff --git a/std/permutation/poseidon2/poseidon2_test.go b/std/permutation/poseidon2/poseidon2_test.go index f14c07813..66ed81488 100644 --- a/std/permutation/poseidon2/poseidon2_test.go +++ b/std/permutation/poseidon2/poseidon2_test.go @@ -44,7 +44,7 @@ type circuitParams struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.seed, c.params.id) + h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.id) h.Permutation(api, c.Input) for i := 0; i < len(c.Input); i++ { api.AssertIsEqual(c.Output[i], c.Input[i]) @@ -68,11 +68,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbn254.NewHash( + h := poseidonbn254.NewPermutation( params[ecc.BN254].t, params[ecc.BN254].rf, params[ecc.BN254].rp, - "seed") + ) var in, out [3]frbn254.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -101,11 +101,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls12377.NewHash( + h := poseidonbls12377.NewPermutation( params[ecc.BLS12_377].t, params[ecc.BLS12_377].rf, params[ecc.BLS12_377].rp, - "seed") + ) var in, out [3]frbls12377.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -134,11 +134,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls12381.NewHash( + h := poseidonbls12381.NewPermutation( params[ecc.BLS12_381].t, params[ecc.BLS12_381].rf, params[ecc.BLS12_381].rp, - "seed") + ) var in, out [3]frbls12381.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -167,11 +167,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6633.NewHash( + h := poseidonbw6633.NewPermutation( params[ecc.BW6_633].t, params[ecc.BW6_633].rf, params[ecc.BW6_633].rp, - "seed") + ) var in, out [3]frbw6633.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -200,11 +200,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6633.NewHash( + h := poseidonbw6633.NewPermutation( params[ecc.BW6_633].t, params[ecc.BW6_633].rf, params[ecc.BW6_633].rp, - "seed") + ) var in, out [3]frbw6633.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -233,11 +233,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbw6761.NewHash( + h := poseidonbw6761.NewPermutation( params[ecc.BW6_761].t, params[ecc.BW6_761].rf, params[ecc.BW6_761].rp, - "seed") + ) var in, out [3]frbw6761.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -266,11 +266,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls24315.NewHash( + h := poseidonbls24315.NewPermutation( params[ecc.BLS24_315].t, params[ecc.BLS24_315].rf, params[ecc.BLS24_315].rp, - "seed") + ) var in, out [3]frbls24315.Element for i := 0; i < 3; i++ { in[i].SetRandom() @@ -299,11 +299,11 @@ func TestPoseidon2(t *testing.T) { { var circuit, validWitness Poseidon2Circuit - h := poseidonbls24317.NewHash( + h := poseidonbls24317.NewPermutation( params[ecc.BLS24_317].t, params[ecc.BLS24_317].rf, params[ecc.BLS24_317].rp, - "seed") + ) var in, out [3]frbls24317.Element for i := 0; i < 3; i++ { in[i].SetRandom()