From 7ef61c9ebb8a40c86e8cb16bedbc3054e43930e3 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 15 Dec 2023 13:24:13 -0600 Subject: [PATCH] fix: sis doesn't use twiddles directly no more --- ecc/bls12-377/fr/sis/sis.go | 4 +- ecc/bls12-377/fr/sis/sis_fft.go | 225 +++++++++++++------- ecc/bls12-377/fr/sis/sis_test.go | 6 +- ecc/bn254/fr/sis/sis.go | 2 +- ecc/bn254/fr/sis/sis_fft.go | 225 +++++++++++++------- ecc/bn254/fr/sis/sis_test.go | 2 +- internal/generator/main.go | 2 +- internal/generator/sis/template/fft.go.tmpl | 36 ++-- 8 files changed, 326 insertions(+), 176 deletions(-) diff --git a/ecc/bls12-377/fr/sis/sis.go b/ecc/bls12-377/fr/sis/sis.go index 164dc1a0e1..1279c8145b 100644 --- a/ecc/bls12-377/fr/sis/sis.go +++ b/ecc/bls12-377/fr/sis/sis.go @@ -119,7 +119,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R LogTwoBound: logTwoBound, capacity: capacity, Degree: degree, - Domain: fft.NewDomain(uint64(degree), shift), + Domain: fft.NewDomain(uint64(degree), fft.WithShift(shift)), A: make([][]fr.Element, n), Ag: make([][]fr.Element, n), bufM: make(fr.Vector, degree*n), @@ -129,7 +129,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R } if r.LogTwoBound == 8 && r.Degree == 64 { // TODO @gbotrel fixme, that's dirty. - r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Twiddles, r.Domain.FrMultiplicativeGen) + r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Generator, r.Domain.FrMultiplicativeGen) } // filling A diff --git a/ecc/bls12-377/fr/sis/sis_fft.go b/ecc/bls12-377/fr/sis/sis_fft.go index 891b7e677b..ae351d5ec1 100644 --- a/ecc/bls12-377/fr/sis/sis_fft.go +++ b/ecc/bls12-377/fr/sis/sis_fft.go @@ -18,6 +18,7 @@ package sis import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/big" ) // fft64 is generated by gnark-crypto and contains the unrolled code for FFT (DIF) on 64 elements @@ -413,82 +414,154 @@ func fft64(a []fr.Element, twiddlesCoset []fr.Element) { // precomputeTwiddlesCoset precomputes twiddlesCoset from twiddles and coset table // it then return all elements in the correct order for the unrolled FFT. -func precomputeTwiddlesCoset(twiddles [][]fr.Element, shifter fr.Element) []fr.Element { - r := make([][]fr.Element, len(twiddles)) - for i := 0; i < len(twiddles); i++ { - r[i] = make([]fr.Element, len(twiddles[i])) - s := shifter - for k := 0; k < i; k++ { - s.Mul(&s, &s) - } - for j := 0; j < len(twiddles[i]); j++ { - r[i][j].Mul(&twiddles[i][j], &s) - } - } - toReturn := make([]fr.Element, 0, 63) +func precomputeTwiddlesCoset(generator, shifter fr.Element) []fr.Element { + toReturn := make([]fr.Element, 63) + var r, s fr.Element + e := new(big.Int) - toReturn = append(toReturn, r[5][0]) - toReturn = append(toReturn, r[4][0]) - toReturn = append(toReturn, r[4][1]) - toReturn = append(toReturn, r[3][0]) - toReturn = append(toReturn, r[3][2]) - toReturn = append(toReturn, r[3][1]) - toReturn = append(toReturn, r[3][3]) - toReturn = append(toReturn, r[2][0]) - toReturn = append(toReturn, r[2][4]) - toReturn = append(toReturn, r[2][2]) - toReturn = append(toReturn, r[2][6]) - toReturn = append(toReturn, r[2][1]) - toReturn = append(toReturn, r[2][5]) - toReturn = append(toReturn, r[2][3]) - toReturn = append(toReturn, r[2][7]) - toReturn = append(toReturn, r[1][0]) - toReturn = append(toReturn, r[1][8]) - toReturn = append(toReturn, r[1][4]) - toReturn = append(toReturn, r[1][12]) - toReturn = append(toReturn, r[1][2]) - toReturn = append(toReturn, r[1][10]) - toReturn = append(toReturn, r[1][6]) - toReturn = append(toReturn, r[1][14]) - toReturn = append(toReturn, r[1][1]) - toReturn = append(toReturn, r[1][9]) - toReturn = append(toReturn, r[1][5]) - toReturn = append(toReturn, r[1][13]) - toReturn = append(toReturn, r[1][3]) - toReturn = append(toReturn, r[1][11]) - toReturn = append(toReturn, r[1][7]) - toReturn = append(toReturn, r[1][15]) - toReturn = append(toReturn, r[0][0]) - toReturn = append(toReturn, r[0][16]) - toReturn = append(toReturn, r[0][8]) - toReturn = append(toReturn, r[0][24]) - toReturn = append(toReturn, r[0][4]) - toReturn = append(toReturn, r[0][20]) - toReturn = append(toReturn, r[0][12]) - toReturn = append(toReturn, r[0][28]) - toReturn = append(toReturn, r[0][2]) - toReturn = append(toReturn, r[0][18]) - toReturn = append(toReturn, r[0][10]) - toReturn = append(toReturn, r[0][26]) - toReturn = append(toReturn, r[0][6]) - toReturn = append(toReturn, r[0][22]) - toReturn = append(toReturn, r[0][14]) - toReturn = append(toReturn, r[0][30]) - toReturn = append(toReturn, r[0][1]) - toReturn = append(toReturn, r[0][17]) - toReturn = append(toReturn, r[0][9]) - toReturn = append(toReturn, r[0][25]) - toReturn = append(toReturn, r[0][5]) - toReturn = append(toReturn, r[0][21]) - toReturn = append(toReturn, r[0][13]) - toReturn = append(toReturn, r[0][29]) - toReturn = append(toReturn, r[0][3]) - toReturn = append(toReturn, r[0][19]) - toReturn = append(toReturn, r[0][11]) - toReturn = append(toReturn, r[0][27]) - toReturn = append(toReturn, r[0][7]) - toReturn = append(toReturn, r[0][23]) - toReturn = append(toReturn, r[0][15]) - toReturn = append(toReturn, r[0][31]) + s = shifter + for k := 0; k < 5; k++ { + s.Square(&s) + } + toReturn[0] = s + s = shifter + for k := 0; k < 4; k++ { + s.Square(&s) + } + toReturn[1] = s + r.Exp(generator, e.SetUint64(uint64(1<<4*1))) + toReturn[2].Mul(&r, &s) + s = shifter + for k := 0; k < 3; k++ { + s.Square(&s) + } + toReturn[3] = s + r.Exp(generator, e.SetUint64(uint64(1<<3*2))) + toReturn[4].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*1))) + toReturn[5].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*3))) + toReturn[6].Mul(&r, &s) + s = shifter + for k := 0; k < 2; k++ { + s.Square(&s) + } + toReturn[7] = s + r.Exp(generator, e.SetUint64(uint64(1<<2*4))) + toReturn[8].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*2))) + toReturn[9].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*6))) + toReturn[10].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*1))) + toReturn[11].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*5))) + toReturn[12].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*3))) + toReturn[13].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*7))) + toReturn[14].Mul(&r, &s) + s = shifter + for k := 0; k < 1; k++ { + s.Square(&s) + } + toReturn[15] = s + r.Exp(generator, e.SetUint64(uint64(1<<1*8))) + toReturn[16].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*4))) + toReturn[17].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*12))) + toReturn[18].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*2))) + toReturn[19].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*10))) + toReturn[20].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*6))) + toReturn[21].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*14))) + toReturn[22].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*1))) + toReturn[23].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*9))) + toReturn[24].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*5))) + toReturn[25].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*13))) + toReturn[26].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*3))) + toReturn[27].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*11))) + toReturn[28].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*7))) + toReturn[29].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*15))) + toReturn[30].Mul(&r, &s) + s = shifter + for k := 0; k < 0; k++ { + s.Square(&s) + } + toReturn[31] = s + r.Exp(generator, e.SetUint64(uint64(1<<0*16))) + toReturn[32].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*8))) + toReturn[33].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*24))) + toReturn[34].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*4))) + toReturn[35].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*20))) + toReturn[36].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*12))) + toReturn[37].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*28))) + toReturn[38].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*2))) + toReturn[39].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*18))) + toReturn[40].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*10))) + toReturn[41].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*26))) + toReturn[42].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*6))) + toReturn[43].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*22))) + toReturn[44].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*14))) + toReturn[45].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*30))) + toReturn[46].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*1))) + toReturn[47].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*17))) + toReturn[48].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*9))) + toReturn[49].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*25))) + toReturn[50].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*5))) + toReturn[51].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*21))) + toReturn[52].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*13))) + toReturn[53].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*29))) + toReturn[54].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*3))) + toReturn[55].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*19))) + toReturn[56].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*11))) + toReturn[57].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*27))) + toReturn[58].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*7))) + toReturn[59].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*23))) + toReturn[60].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*15))) + toReturn[61].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*31))) + toReturn[62].Mul(&r, &s) return toReturn } diff --git a/ecc/bls12-377/fr/sis/sis_test.go b/ecc/bls12-377/fr/sis/sis_test.go index ddccf9222d..e15626f8e1 100644 --- a/ecc/bls12-377/fr/sis/sis_test.go +++ b/ecc/bls12-377/fr/sis/sis_test.go @@ -142,7 +142,7 @@ func TestMulMod(t *testing.T) { // and random. var shift fr.Element shift.SetString("19540430494807482326159819597004422086093766032135589407132600596362845576832") - domain := fft.NewDomain(uint64(size), shift) + domain := fft.NewDomain(uint64(size), fft.WithShift(shift)) // mul mod domain.FFT(p, fft.DIF, fft.OnCoset()) @@ -415,7 +415,7 @@ func TestUnrolledFFT(t *testing.T) { const size = 64 assert := require.New(t) - domain := fft.NewDomain(size, shift) + domain := fft.NewDomain(size, fft.WithShift(shift)) k1 := make([]fr.Element, size) for i := 0; i < size; i++ { @@ -428,7 +428,7 @@ func TestUnrolledFFT(t *testing.T) { domain.FFT(k1, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1)) // unrolled FFT - twiddlesCoset := precomputeTwiddlesCoset(domain.Twiddles, domain.FrMultiplicativeGen) + twiddlesCoset := precomputeTwiddlesCoset(domain.Generator, domain.FrMultiplicativeGen) fft64(k2, twiddlesCoset) // compare results diff --git a/ecc/bn254/fr/sis/sis.go b/ecc/bn254/fr/sis/sis.go index 6ffdc5e999..dce215b141 100644 --- a/ecc/bn254/fr/sis/sis.go +++ b/ecc/bn254/fr/sis/sis.go @@ -129,7 +129,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R } if r.LogTwoBound == 8 && r.Degree == 64 { // TODO @gbotrel fixme, that's dirty. - r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Twiddles, r.Domain.FrMultiplicativeGen) + r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Generator, r.Domain.FrMultiplicativeGen) } // filling A diff --git a/ecc/bn254/fr/sis/sis_fft.go b/ecc/bn254/fr/sis/sis_fft.go index 336805ebe3..70b4d32d9d 100644 --- a/ecc/bn254/fr/sis/sis_fft.go +++ b/ecc/bn254/fr/sis/sis_fft.go @@ -18,6 +18,7 @@ package sis import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math/big" ) // fft64 is generated by gnark-crypto and contains the unrolled code for FFT (DIF) on 64 elements @@ -413,82 +414,154 @@ func fft64(a []fr.Element, twiddlesCoset []fr.Element) { // precomputeTwiddlesCoset precomputes twiddlesCoset from twiddles and coset table // it then return all elements in the correct order for the unrolled FFT. -func precomputeTwiddlesCoset(twiddles [][]fr.Element, shifter fr.Element) []fr.Element { - r := make([][]fr.Element, len(twiddles)) - for i := 0; i < len(twiddles); i++ { - r[i] = make([]fr.Element, len(twiddles[i])) - s := shifter - for k := 0; k < i; k++ { - s.Mul(&s, &s) - } - for j := 0; j < len(twiddles[i]); j++ { - r[i][j].Mul(&twiddles[i][j], &s) - } - } - toReturn := make([]fr.Element, 0, 63) +func precomputeTwiddlesCoset(generator, shifter fr.Element) []fr.Element { + toReturn := make([]fr.Element, 63) + var r, s fr.Element + e := new(big.Int) - toReturn = append(toReturn, r[5][0]) - toReturn = append(toReturn, r[4][0]) - toReturn = append(toReturn, r[4][1]) - toReturn = append(toReturn, r[3][0]) - toReturn = append(toReturn, r[3][2]) - toReturn = append(toReturn, r[3][1]) - toReturn = append(toReturn, r[3][3]) - toReturn = append(toReturn, r[2][0]) - toReturn = append(toReturn, r[2][4]) - toReturn = append(toReturn, r[2][2]) - toReturn = append(toReturn, r[2][6]) - toReturn = append(toReturn, r[2][1]) - toReturn = append(toReturn, r[2][5]) - toReturn = append(toReturn, r[2][3]) - toReturn = append(toReturn, r[2][7]) - toReturn = append(toReturn, r[1][0]) - toReturn = append(toReturn, r[1][8]) - toReturn = append(toReturn, r[1][4]) - toReturn = append(toReturn, r[1][12]) - toReturn = append(toReturn, r[1][2]) - toReturn = append(toReturn, r[1][10]) - toReturn = append(toReturn, r[1][6]) - toReturn = append(toReturn, r[1][14]) - toReturn = append(toReturn, r[1][1]) - toReturn = append(toReturn, r[1][9]) - toReturn = append(toReturn, r[1][5]) - toReturn = append(toReturn, r[1][13]) - toReturn = append(toReturn, r[1][3]) - toReturn = append(toReturn, r[1][11]) - toReturn = append(toReturn, r[1][7]) - toReturn = append(toReturn, r[1][15]) - toReturn = append(toReturn, r[0][0]) - toReturn = append(toReturn, r[0][16]) - toReturn = append(toReturn, r[0][8]) - toReturn = append(toReturn, r[0][24]) - toReturn = append(toReturn, r[0][4]) - toReturn = append(toReturn, r[0][20]) - toReturn = append(toReturn, r[0][12]) - toReturn = append(toReturn, r[0][28]) - toReturn = append(toReturn, r[0][2]) - toReturn = append(toReturn, r[0][18]) - toReturn = append(toReturn, r[0][10]) - toReturn = append(toReturn, r[0][26]) - toReturn = append(toReturn, r[0][6]) - toReturn = append(toReturn, r[0][22]) - toReturn = append(toReturn, r[0][14]) - toReturn = append(toReturn, r[0][30]) - toReturn = append(toReturn, r[0][1]) - toReturn = append(toReturn, r[0][17]) - toReturn = append(toReturn, r[0][9]) - toReturn = append(toReturn, r[0][25]) - toReturn = append(toReturn, r[0][5]) - toReturn = append(toReturn, r[0][21]) - toReturn = append(toReturn, r[0][13]) - toReturn = append(toReturn, r[0][29]) - toReturn = append(toReturn, r[0][3]) - toReturn = append(toReturn, r[0][19]) - toReturn = append(toReturn, r[0][11]) - toReturn = append(toReturn, r[0][27]) - toReturn = append(toReturn, r[0][7]) - toReturn = append(toReturn, r[0][23]) - toReturn = append(toReturn, r[0][15]) - toReturn = append(toReturn, r[0][31]) + s = shifter + for k := 0; k < 5; k++ { + s.Square(&s) + } + toReturn[0] = s + s = shifter + for k := 0; k < 4; k++ { + s.Square(&s) + } + toReturn[1] = s + r.Exp(generator, e.SetUint64(uint64(1<<4*1))) + toReturn[2].Mul(&r, &s) + s = shifter + for k := 0; k < 3; k++ { + s.Square(&s) + } + toReturn[3] = s + r.Exp(generator, e.SetUint64(uint64(1<<3*2))) + toReturn[4].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*1))) + toReturn[5].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*3))) + toReturn[6].Mul(&r, &s) + s = shifter + for k := 0; k < 2; k++ { + s.Square(&s) + } + toReturn[7] = s + r.Exp(generator, e.SetUint64(uint64(1<<2*4))) + toReturn[8].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*2))) + toReturn[9].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*6))) + toReturn[10].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*1))) + toReturn[11].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*5))) + toReturn[12].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*3))) + toReturn[13].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*7))) + toReturn[14].Mul(&r, &s) + s = shifter + for k := 0; k < 1; k++ { + s.Square(&s) + } + toReturn[15] = s + r.Exp(generator, e.SetUint64(uint64(1<<1*8))) + toReturn[16].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*4))) + toReturn[17].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*12))) + toReturn[18].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*2))) + toReturn[19].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*10))) + toReturn[20].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*6))) + toReturn[21].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*14))) + toReturn[22].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*1))) + toReturn[23].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*9))) + toReturn[24].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*5))) + toReturn[25].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*13))) + toReturn[26].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*3))) + toReturn[27].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*11))) + toReturn[28].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*7))) + toReturn[29].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*15))) + toReturn[30].Mul(&r, &s) + s = shifter + for k := 0; k < 0; k++ { + s.Square(&s) + } + toReturn[31] = s + r.Exp(generator, e.SetUint64(uint64(1<<0*16))) + toReturn[32].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*8))) + toReturn[33].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*24))) + toReturn[34].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*4))) + toReturn[35].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*20))) + toReturn[36].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*12))) + toReturn[37].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*28))) + toReturn[38].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*2))) + toReturn[39].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*18))) + toReturn[40].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*10))) + toReturn[41].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*26))) + toReturn[42].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*6))) + toReturn[43].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*22))) + toReturn[44].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*14))) + toReturn[45].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*30))) + toReturn[46].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*1))) + toReturn[47].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*17))) + toReturn[48].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*9))) + toReturn[49].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*25))) + toReturn[50].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*5))) + toReturn[51].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*21))) + toReturn[52].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*13))) + toReturn[53].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*29))) + toReturn[54].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*3))) + toReturn[55].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*19))) + toReturn[56].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*11))) + toReturn[57].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*27))) + toReturn[58].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*7))) + toReturn[59].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*23))) + toReturn[60].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*15))) + toReturn[61].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*31))) + toReturn[62].Mul(&r, &s) return toReturn } diff --git a/ecc/bn254/fr/sis/sis_test.go b/ecc/bn254/fr/sis/sis_test.go index 9cbba969e8..cde678d214 100644 --- a/ecc/bn254/fr/sis/sis_test.go +++ b/ecc/bn254/fr/sis/sis_test.go @@ -427,7 +427,7 @@ func TestUnrolledFFT(t *testing.T) { domain.FFT(k1, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1)) // unrolled FFT - twiddlesCoset := precomputeTwiddlesCoset(domain.Twiddles, domain.FrMultiplicativeGen) + twiddlesCoset := precomputeTwiddlesCoset(domain.Generator, domain.FrMultiplicativeGen) fft64(k2, twiddlesCoset) // compare results diff --git a/internal/generator/main.go b/internal/generator/main.go index 3c9903b8a2..389f96c2e0 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -92,7 +92,7 @@ func main() { // generate fft on fr assertNoError(fft.Generate(conf, filepath.Join(curveDir, "fr", "fft"), bgen)) - if conf.Equal(config.BN254) { + if conf.Equal(config.BN254) || conf.Equal(config.BLS12_377) { assertNoError(sis.Generate(conf, filepath.Join(curveDir, "fr", "sis"), bgen)) } diff --git a/internal/generator/sis/template/fft.go.tmpl b/internal/generator/sis/template/fft.go.tmpl index 36cdace6ce..91938f81d0 100644 --- a/internal/generator/sis/template/fft.go.tmpl +++ b/internal/generator/sis/template/fft.go.tmpl @@ -1,5 +1,6 @@ import ( - "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" + "math/big" ) // fft64 is generated by gnark-crypto and contains the unrolled code for FFT (DIF) on 64 elements @@ -48,28 +49,31 @@ func fft64(a []fr.Element, twiddlesCoset []fr.Element) { // precomputeTwiddlesCoset precomputes twiddlesCoset from twiddles and coset table // it then return all elements in the correct order for the unrolled FFT. -func precomputeTwiddlesCoset(twiddles [][]fr.Element, shifter fr.Element) []fr.Element { - r := make([][]fr.Element, len(twiddles)) - for i := 0; i < len(twiddles); i++ { - r[i] = make([]fr.Element, len(twiddles[i])) - s := shifter - for k:=0; k