Skip to content

Commit

Permalink
feat: direct multivariate polynomial evaluation in non-native (#1299)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivokub authored Nov 22, 2024
1 parent 7512178 commit 32e3404
Show file tree
Hide file tree
Showing 5 changed files with 722 additions and 56 deletions.
11 changes: 11 additions & 0 deletions std/math/emulated/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ identity at a random point:
where e(X) is a polynomial used for carrying the overflows of the left- and
right-hand side of the above equation.
This approach can be extended to the case when the left hand side is not a
simple multiplication, but rather any evaluation of a multivariate polynomial.
So in essence we can check the correctness of any polynomial evaluation modulo
r:
F(x_1, x_2, ..., x_n) = c + z*r
through the following identity:
F(x_1(X), x_2(X), ..., x_n(X)) = c(X) + z(X) * r(X) + (2^w' - X) e(X).
# Subtraction
We perform subtraction limb-wise between the elements x and y. However, we have
Expand Down
148 changes: 148 additions & 0 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1277,3 +1277,151 @@ func TestIsZeroEdgeCases(t *testing.T) {
testIsZeroEdgeCases[BN254Fr](t)
testIsZeroEdgeCases[emparams.Mod1e512](t)
}

type PolyEvalCircuit[T FieldParams] struct {
Inputs []Element[T]
TermsByIndices [][]int
Coeffs []int
Expected Element[T]
}

func (c *PolyEvalCircuit[T]) Define(api frontend.API) error {
// withEval
f, err := NewField[T](api)
if err != nil {
return err
}
// reconstruct the terms from the inputs and the indices
terms := make([][]*Element[T], len(c.TermsByIndices))
for i := range terms {
terms[i] = make([]*Element[T], len(c.TermsByIndices[i]))
for j := range terms[i] {
terms[i][j] = &c.Inputs[c.TermsByIndices[i][j]]
}
}
resEval := f.Eval(terms, c.Coeffs)

// withSum
addTerms := make([]*Element[T], len(c.TermsByIndices))
for i, term := range c.TermsByIndices {
termVal := f.One()
for j := range term {
termVal = f.Mul(termVal, &c.Inputs[term[j]])
}
addTerms[i] = f.MulConst(termVal, big.NewInt(int64(c.Coeffs[i])))
}
resSum := f.Sum(addTerms...)

// mul no reduce
addTerms2 := make([]*Element[T], len(c.TermsByIndices))
for i, term := range c.TermsByIndices {
termVal := f.One()
for j := range term {
termVal = f.MulNoReduce(termVal, &c.Inputs[term[j]])
}
addTerms2[i] = f.MulConst(termVal, big.NewInt(int64(c.Coeffs[i])))
}
resNoReduce := f.Sum(addTerms2...)
resReduced := f.Reduce(resNoReduce)

// assertions
f.AssertIsEqual(resEval, &c.Expected)
f.AssertIsEqual(resSum, &c.Expected)
f.AssertIsEqual(resNoReduce, &c.Expected)
f.AssertIsEqual(resReduced, &c.Expected)

return nil
}

func TestPolyEval(t *testing.T) {
testPolyEval[Goldilocks](t)
testPolyEval[BN254Fr](t)
testPolyEval[emparams.Mod1e512](t)
}

func testPolyEval[T FieldParams](t *testing.T) {
const nbInputs = 2
assert := test.NewAssert(t)
var fp T
var err error
// 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 assuming we have inputs w=[x, y], then
// we can represent by the indices of the inputs:
// 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 -> 2*x*x*x + 3*x*x*y + 4*x*y*y + 5*y*y*y -> 2*w[0]*w[0]*w[0] + 3*w[0]*w[0]*w[1] + 4*w[0]*w[1]*w[1] + 5*w[1]*w[1]*w[1]
// the following variable gives the indices of the inputs. For givin the
// circuit this is better as then we can easily reference to the inputs by
// index.
toMulByIndex := [][]int{{0, 0, 0}, {0, 0, 1}, {0, 1, 1}, {1, 1, 1}}
coefficients := []int{2, 3, 4, 5}
inputs := make([]*big.Int, nbInputs)
assignmentInput := make([]Element[T], nbInputs)
for i := range inputs {
inputs[i], err = rand.Int(rand.Reader, fp.Modulus())
assert.NoError(err)
}
for i := range inputs {
assignmentInput[i] = ValueOf[T](inputs[i])
}
expected := new(big.Int)
for i, term := range toMulByIndex {
termVal := new(big.Int).SetInt64(int64(coefficients[i]))
for j := range term {
termVal.Mul(termVal, inputs[term[j]])
}
expected.Add(expected, termVal)
}
expected.Mod(expected, fp.Modulus())

assignment := &PolyEvalCircuit[T]{
Inputs: assignmentInput,
Expected: ValueOf[T](expected),
}
assert.CheckCircuit(&PolyEvalCircuit[T]{Inputs: make([]Element[T], nbInputs), TermsByIndices: toMulByIndex, Coeffs: coefficients}, test.WithValidAssignment(assignment))
}

type PolyEvalNegativeCoefficient[T FieldParams] struct {
Inputs []Element[T]
Res Element[T]
}

func (c *PolyEvalNegativeCoefficient[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return err
}
// x - y
coefficients := []int{1, -1}
res := f.Eval([][]*Element[T]{{&c.Inputs[0]}, {&c.Inputs[1]}}, coefficients)
f.AssertIsEqual(res, &c.Res)
return nil
}

func TestPolyEvalNegativeCoefficient(t *testing.T) {
testPolyEvalNegativeCoefficient[Goldilocks](t)
testPolyEvalNegativeCoefficient[BN254Fr](t)
testPolyEvalNegativeCoefficient[emparams.Mod1e512](t)
}

func testPolyEvalNegativeCoefficient[T FieldParams](t *testing.T) {
t.Skip("not implemented yet")
assert := test.NewAssert(t)
var fp T
fmt.Println("modulus", fp.Modulus())
var err error
const nbInputs = 2
inputs := make([]*big.Int, nbInputs)
assignmentInput := make([]Element[T], nbInputs)
for i := range inputs {
inputs[i], err = rand.Int(rand.Reader, fp.Modulus())
assert.NoError(err)
}
for i := range inputs {
fmt.Println("input", i, inputs[i])
assignmentInput[i] = ValueOf[T](inputs[i])
}
expected := new(big.Int).Sub(inputs[0], inputs[1])
expected.Mod(expected, fp.Modulus())
fmt.Println("expected", expected)
assignment := &PolyEvalNegativeCoefficient[T]{Inputs: assignmentInput, Res: ValueOf[T](expected)}
err = test.IsSolved(&PolyEvalNegativeCoefficient[T]{Inputs: make([]Element[T], nbInputs)}, assignment, testCurve.ScalarField())
assert.NoError(err)
}
16 changes: 14 additions & 2 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Field[T FieldParams] struct {
constrainedLimbs map[[16]byte]struct{}
checker frontend.Rangechecker

mulChecks []mulCheck[T]
deferredChecks []deferredChecker
}

type ctxKey[T FieldParams] struct{}
Expand Down Expand Up @@ -103,7 +103,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb())
}

native.Compiler().Defer(f.performMulChecks)
native.Compiler().Defer(f.performDeferredChecks)
if storer, ok := native.(kvstore.Store); ok {
storer.SetKeyValue(ctxKey[T]{}, f)
}
Expand Down Expand Up @@ -282,3 +282,15 @@ func max[T constraints.Ordered](a ...T) T {
}
return m
}

func sum[T constraints.Ordered](a ...T) T {
if len(a) == 0 {
var f T
return f
}
m := a[0]
for _, v := range a[1:] {
m += v
}
return m
}
Loading

0 comments on commit 32e3404

Please sign in to comment.