Skip to content

Commit

Permalink
Merge pull request #1049 from Consensys/perf/jointScalarMulGeneric
Browse files Browse the repository at this point in the history
perf(sw_emulated): optimize jointScalarMulGeneric
  • Loading branch information
yelhousni authored Feb 9, 2024
2 parents 9b8efda + 382bd8e commit 4c3ef85
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 10 deletions.
64 changes: 56 additions & 8 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,16 +664,64 @@ func (c *Curve[B, S]) jointScalarMul(p1, p2 *AffinePoint[B], s1, s2 *emulated.El
}
}

// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify the inputs.
// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify p1, p2 nor s1, s2.
//
// ⚠️ p1, p2 must not be (0,0) and s1, s2 must not be 0, unless [algopts.WithCompleteArithmetic] option is set.
// ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set.
func (c *Curve[B, S]) jointScalarMulGeneric(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
res1 := c.scalarMulGeneric(p1, s1, opts...)
res2 := c.scalarMulGeneric(p2, s2, opts...)
return c.Add(res1, res2)
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
if cfg.CompleteArithmetic {
// TODO @yelhousni: optimize
res1 := c.scalarMulGeneric(p1, s1, opts...)
res2 := c.scalarMulGeneric(p2, s2, opts...)
return c.AddUnified(res1, res2)
} else {
return c.jointScalarMulGenericUnsafe(p1, p2, s1, s2)
}
}

// jointScalarMulGenericUnsafe computes [s1]p1 + [s2]p2 using Shamir's trick and returns it. It doesn't modify p1, p2 nor s1, s2.
// ⚠️ The scalars must be nonzero and the points different from (0,0).
func (c *Curve[B, S]) jointScalarMulGenericUnsafe(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S]) *AffinePoint[B] {
var Acc, B1, p1Neg, p2Neg *AffinePoint[B]
p1Neg = c.Neg(p1)
p2Neg = c.Neg(p2)

// Acc = P1 + P2
Acc = c.Add(p1, p2)

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)

var st S
nbits := st.Modulus().BitLen()

for i := nbits - 1; i > 0; i-- {
B1 = &AffinePoint[B]{
X: p1Neg.X,
Y: *c.baseApi.Select(s1bits[i], &p1.Y, &p1Neg.Y),
}
Acc = c.doubleAndAdd(Acc, B1)
B1 = &AffinePoint[B]{
X: p2Neg.X,
Y: *c.baseApi.Select(s2bits[i], &p2.Y, &p2Neg.Y),
}
Acc = c.Add(Acc, B1)

}

// i = 0
p1Neg = c.Add(p1Neg, Acc)
Acc = c.Select(s1bits[0], Acc, p1Neg)
p2Neg = c.Add(p2Neg, Acc)
Acc = c.Select(s2bits[0], Acc, p2Neg)

return Acc
}

// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify P, Q nor s.
// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify p1, p2 nor s1, s2.
//
// ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set.
func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
Expand All @@ -691,8 +739,8 @@ func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated
}
}

// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify P, Q nor s.
// ⚠️ The scalar s must be nonzero and the point Q different from (0,0), unless [algopts.WithCompleteArithmetic] option is set.
// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify Q, R nor s, t.
// ⚠️ The scalars must be nonzero and the points different from (0,0).
func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] {
var st S
frModulus := c.scalarApi.Modulus()
Expand Down
171 changes: 169 additions & 2 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,40 @@ func TestJointScalarMul6(t *testing.T) {
assert.NoError(err)
}

func TestJointScalarMul4(t *testing.T) {
assert := test.NewAssert(t)
p256 := elliptic.P256()
s1, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
s2, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
p1x, p1y := p256.ScalarBaseMult(s1.Bytes())
p2x, p2y := p256.ScalarBaseMult(s2.Bytes())
resx, resy := p256.ScalarMult(p1x, p1y, s1.Bytes())
tmpx, tmpy := p256.ScalarMult(p2x, p2y, s2.Bytes())
resx, resy = p256.Add(resx, resy, tmpx, tmpy)

circuit := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{}
witness := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](resx),
Y: emulated.ValueOf[emulated.P256Fp](resy),
},
}
err = test.IsSolved(&circuit, &witness, testCurve.ScalarField())
assert.NoError(err)
}

type JointScalarMulEdgeCasesTest[T, S emulated.FieldParams] struct {
P1, P2, Q AffinePoint[T]
S1, S2 emulated.Element[S]
Expand All @@ -1415,12 +1449,11 @@ func TestJointScalarMulEdgeCases6(t *testing.T) {
s2 := new(big.Int)
r1.BigInt(s1)
r2.BigInt(s2)
var res, res1, res2, gen2, infinity bw6761.G1Affine
var res1, res2, gen2, infinity bw6761.G1Affine
_, _, gen1, _ := bw6761.Generators()
gen2.Double(&gen1)
res1.ScalarMultiplication(&gen1, s1)
res2.ScalarMultiplication(&gen2, s2)
res.Add(&res1, &res2)

circuit := JointScalarMulEdgeCasesTest[emulated.BW6761Fp, emulated.BW6761Fr]{}
// s1*(0,0) + s2*(0,0) == (0,0)
Expand Down Expand Up @@ -1544,6 +1577,140 @@ func TestJointScalarMulEdgeCases6(t *testing.T) {
assert.NoError(err)
}

func TestJointScalarMulEdgeCases4(t *testing.T) {
assert := test.NewAssert(t)
p256 := elliptic.P256()
s1, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
s2, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
p1x, p1y := p256.ScalarBaseMult(s1.Bytes())
p2x, p2y := p256.ScalarBaseMult(s2.Bytes())
res1x, res1y := p256.ScalarMult(p1x, p1y, s1.Bytes())
res2x, res2y := p256.ScalarMult(p2x, p2y, s2.Bytes())

circuit := JointScalarMulEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{}
// s1*(0,0) + s2*(0,0) == (0,0)
witness1 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
}
err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField())
assert.NoError(err)

// s1*P + s2*(0,0) == s1*P
witness2 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res1x),
Y: emulated.ValueOf[emulated.P256Fp](res1y),
},
}
err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField())
assert.NoError(err)

// s1*(0,0) + s2*Q == s2*Q
witness3 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res2x),
Y: emulated.ValueOf[emulated.P256Fp](res2y),
},
}
err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField())
assert.NoError(err)

// 0*P + 0*Q == (0,0)
witness4 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](0),
S2: emulated.ValueOf[emulated.P256Fr](0),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
}
err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField())
assert.NoError(err)

// 0*P + s2*Q == s2*Q
witness5 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](0),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res2x),
Y: emulated.ValueOf[emulated.P256Fp](res2y),
},
}
err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField())
assert.NoError(err)

// s1*P + 0*Q == s1*P
witness6 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](0),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res1x),
Y: emulated.ValueOf[emulated.P256Fp](res1y),
},
}
err = test.IsSolved(&circuit, &witness6, testCurve.ScalarField())
assert.NoError(err)
}

type MuxCircuitTest[T, S emulated.FieldParams] struct {
Selector frontend.Variable
Inputs [8]AffinePoint[T]
Expand Down

0 comments on commit 4c3ef85

Please sign in to comment.