Skip to content

Commit

Permalink
chore: integrate gnark-crypto refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabaie committed Feb 5, 2025
1 parent 1b57817 commit 37a79fa
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 80 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/consensys/bavard v0.1.27
github.com/consensys/compress v0.2.5
github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1
github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca
github.com/fxamacker/cbor/v2 v2.7.0
github.com/google/go-cmp v0.6.0
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.27 h1:j6hKUrGAy/H+gpNrpLU3I26n1yc+VMGmd6ID5+gAh
github.com/consensys/bavard v0.1.27/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs=
github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk=
github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk=
github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1 h1:PuRSTn2hpFm+mqysWl/hjTU2AvXYMNZT1nvxQT5j5PY=
github.com/consensys/gnark-crypto v0.15.1-0.20250203033118-19afe00d3be1/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU=
github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca h1:u6iXwMBfbXODF+hDSwKSTBg6yfD3+eMX6o3PILAK474=
github.com/consensys/gnark-crypto v0.16.1-0.20250205153847-10a243d332ca/go.mod h1:Ke3j06ndtPTVvo++PhGNgvm+lgpLvzbcE2MqljY7diU=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
Expand Down
4 changes: 2 additions & 2 deletions std/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ type BinaryFixedLengthHasher interface {

// CompressionFunction is a 2 to 1 function
type CompressionFunction interface {
Apply(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable // TODO @Tabaie @ThomasPiellard better name
Compress(frontend.API, frontend.Variable, frontend.Variable) frontend.Variable
}

type merkleDamgardHasher struct {
Expand All @@ -117,7 +117,7 @@ func (h *merkleDamgardHasher) Reset() {

func (h *merkleDamgardHasher) Write(data ...frontend.Variable) {
for _, d := range data {
h.state = h.f.Apply(h.api, h.state, d)
h.state = h.f.Compress(h.api, h.state, d)
}
}

Expand Down
2 changes: 1 addition & 1 deletion std/hash/poseidon2/poseidon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

func TestPoseidon2Hash(t *testing.T) {
// prepare expected output
h := poseidon2.NewPoseidon2()
h := poseidon2.NewMerkleDamgardHasher()
for i := range 5 {
_, err := h.Write([]byte{byte(i)})
require.NoError(t, err)
Expand Down
4 changes: 1 addition & 3 deletions std/hash/poseidon2/posiedon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/std/hash"
poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2"
"strings"
)

func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) {
Expand All @@ -16,8 +15,7 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) {
if !ok {
return nil, fmt.Errorf("poseidon2 hash for curve \"%s\" not yet supported", curve.String())
}
seed := fmt.Sprintf("Poseidon2 hash for %s with t=2, rF=%d, rP=%d, d=%d", strings.ToUpper(curve.String()), params.rF, params.rP, params.d)
f := poseidon2.NewHash(2, params.d, params.rF, params.rP, seed, curve)
f := poseidon2.NewHash(2, params.d, params.rF, params.rP, curve)
return hash.NewMerkleDamgardHasher(api, &f, 0), nil
}

Expand Down
111 changes: 57 additions & 54 deletions std/permutation/poseidon2/poseidon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package poseidon

import (
"errors"
"fmt"
"math/big"

"github.com/consensys/gnark-crypto/ecc"
Expand All @@ -19,7 +20,7 @@ var (
ErrInvalidSizebuffer = errors.New("the size of the input should match the size of the hash buffer")
)

type Hash struct {
type Permutation struct {
params parameters
}

Expand All @@ -45,77 +46,79 @@ type parameters struct {
roundKeys [][]big.Int
}

func NewHash(t, d, rf, rp int, seed string, curve ecc.ID) Hash {
func NewHash(t, d, rf, rp int, curve ecc.ID) Permutation {
params := parameters{t: t, d: d, rF: rf, rP: rp}
if curve == ecc.BN254 {
rc := poseidonbn254.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbn254.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else if curve == ecc.BLS12_381 {
rc := poseidonbls12381.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbls12381.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else if curve == ecc.BLS12_377 {
rc := poseidonbls12377.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbls12377.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else if curve == ecc.BW6_761 {
rc := poseidonbw6761.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbw6761.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else if curve == ecc.BW6_633 {
rc := poseidonbw6633.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbw6633.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else if curve == ecc.BLS24_315 {
rc := poseidonbls24315.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbls24315.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else if curve == ecc.BLS24_317 {
rc := poseidonbls24317.InitRC(seed, rf, rp, t)
params.roundKeys = make([][]big.Int, len(rc))
for i := 0; i < len(rc); i++ {
params.roundKeys[i] = make([]big.Int, len(rc[i]))
for j := 0; j < len(rc[i]); j++ {
rc[i][j].BigInt(&params.roundKeys[i][j])
concreteParams := poseidonbls24317.NewParameters(t, rf, rp)
params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys))
for i := range params.roundKeys {
params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i]))
for j := range params.roundKeys[i] {
concreteParams.RoundKeys[i][j].BigInt(&params.roundKeys[i][j])
}
}
} else {
panic(fmt.Errorf("curve %s not supported", curve.String()))
}
return Hash{params: params}
return Permutation{params: params}
}

// sBox applies the sBox on buffer[index]
func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) {
func (h *Permutation) sBox(api frontend.API, index int, input []frontend.Variable) {
tmp := input[index]
if h.params.d == 3 {
input[index] = api.Mul(input[index], input[index])
Expand Down Expand Up @@ -149,7 +152,7 @@ func (h *Hash) sBox(api frontend.API, index int, input []frontend.Variable) {
// (1 1 4 6)
// on chunks of 4 elements on each part of the buffer
// see https://eprint.iacr.org/2023/323.pdf appendix B for the addition chain
func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) {
func (h *Permutation) matMulM4InPlace(api frontend.API, s []frontend.Variable) {
c := len(s) / 4
for i := 0; i < c; i++ {
t0 := api.Add(s[4*i], s[4*i+1]) // s0+s1
Expand All @@ -176,7 +179,7 @@ func (h *Hash) matMulM4InPlace(api frontend.API, s []frontend.Variable) {
//
// when t=0[4], the buffer is multiplied by circ(2M4,M4,..,M4)
// see https://eprint.iacr.org/2023/323.pdf
func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable) {
func (h *Permutation) matMulExternalInPlace(api frontend.API, input []frontend.Variable) {

if h.params.t == 2 {
tmp := api.Add(input[0], input[1])
Expand Down Expand Up @@ -213,7 +216,7 @@ func (h *Hash) matMulExternalInPlace(api frontend.API, input []frontend.Variable

// when t=2,3 the matrix are respectibely [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]]
// otherwise the matrix is filled with ones except on the diagonal,
func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable) {
func (h *Permutation) matMulInternalInPlace(api frontend.API, input []frontend.Variable) {
if h.params.t == 2 {
sum := api.Add(input[0], input[1])
input[0] = api.Add(input[0], sum)
Expand Down Expand Up @@ -241,13 +244,13 @@ func (h *Hash) matMulInternalInPlace(api frontend.API, input []frontend.Variable
}

// addRoundKeyInPlace adds the round-th key to the buffer
func (h *Hash) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) {
func (h *Permutation) addRoundKeyInPlace(api frontend.API, round int, input []frontend.Variable) {
for i := 0; i < len(h.params.roundKeys[round]); i++ {
input[i] = api.Add(input[i], h.params.roundKeys[round][i])
}
}

func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error {
func (h *Permutation) Permutation(api frontend.API, input []frontend.Variable) error {
if len(input) != h.params.t {
return ErrInvalidSizebuffer
}
Expand Down Expand Up @@ -283,11 +286,11 @@ func (h *Hash) Permutation(api frontend.API, input []frontend.Variable) error {
return nil
}

// Apply aliases Permutation in the t=2 case
// Compress aliases Permutation in the t=2 case
// implements hash.CompressionFunction
func (h *Hash) Apply(api frontend.API, l, r frontend.Variable) frontend.Variable {
func (h *Permutation) Compress(api frontend.API, l, r frontend.Variable) frontend.Variable {
if h.params.t != 2 {
panic("poseidon2: Apply can only be used when t=2")
panic("poseidon2: Compress can only be used when t=2")
}
vars := [2]frontend.Variable{l, r}
if err := h.Permutation(api, vars[:]); err != nil {
Expand Down
34 changes: 17 additions & 17 deletions std/permutation/poseidon2/poseidon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type circuitParams struct {
}

func (c *Poseidon2Circuit) Define(api frontend.API) error {
h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.seed, c.params.id)
h := NewHash(c.params.t, c.params.d, c.params.rf, c.params.rp, c.params.id)
h.Permutation(api, c.Input)
for i := 0; i < len(c.Input); i++ {
api.AssertIsEqual(c.Output[i], c.Input[i])
Expand All @@ -68,11 +68,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbn254.NewHash(
h := poseidonbn254.NewPermutation(
params[ecc.BN254].t,
params[ecc.BN254].rf,
params[ecc.BN254].rp,
"seed")
)
var in, out [3]frbn254.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -101,11 +101,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbls12377.NewHash(
h := poseidonbls12377.NewPermutation(
params[ecc.BLS12_377].t,
params[ecc.BLS12_377].rf,
params[ecc.BLS12_377].rp,
"seed")
)
var in, out [3]frbls12377.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -134,11 +134,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbls12381.NewHash(
h := poseidonbls12381.NewPermutation(
params[ecc.BLS12_381].t,
params[ecc.BLS12_381].rf,
params[ecc.BLS12_381].rp,
"seed")
)
var in, out [3]frbls12381.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -167,11 +167,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbw6633.NewHash(
h := poseidonbw6633.NewPermutation(
params[ecc.BW6_633].t,
params[ecc.BW6_633].rf,
params[ecc.BW6_633].rp,
"seed")
)
var in, out [3]frbw6633.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -200,11 +200,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbw6633.NewHash(
h := poseidonbw6633.NewPermutation(
params[ecc.BW6_633].t,
params[ecc.BW6_633].rf,
params[ecc.BW6_633].rp,
"seed")
)
var in, out [3]frbw6633.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -233,11 +233,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbw6761.NewHash(
h := poseidonbw6761.NewPermutation(
params[ecc.BW6_761].t,
params[ecc.BW6_761].rf,
params[ecc.BW6_761].rp,
"seed")
)
var in, out [3]frbw6761.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -266,11 +266,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbls24315.NewHash(
h := poseidonbls24315.NewPermutation(
params[ecc.BLS24_315].t,
params[ecc.BLS24_315].rf,
params[ecc.BLS24_315].rp,
"seed")
)
var in, out [3]frbls24315.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down Expand Up @@ -299,11 +299,11 @@ func TestPoseidon2(t *testing.T) {
{
var circuit, validWitness Poseidon2Circuit

h := poseidonbls24317.NewHash(
h := poseidonbls24317.NewPermutation(
params[ecc.BLS24_317].t,
params[ecc.BLS24_317].rf,
params[ecc.BLS24_317].rp,
"seed")
)
var in, out [3]frbls24317.Element
for i := 0; i < 3; i++ {
in[i].SetRandom()
Expand Down

0 comments on commit 37a79fa

Please sign in to comment.