Skip to content

Commit

Permalink
feat: add c api for proving sumcheck (PROOF-913) (#245)
Browse files Browse the repository at this point in the history
* add stub for sumcheck c api

* fill in c api

* fill in cbindings

* fill in sumcheck cbindings

* fill in sumcheck c-bindings

* fill in sumcheck transcript

* fill in computational backend

* fill in sumcheck api

* fill in sumcheck api

* fill in sumcheck api

* fill in sumcheck api

* fill in gpu backend

* fill in gpu backend

* fill in sumcheck api

* fill in sumcheck api

* fill in gpu backend

* drop dead code

* fill in sumcheck api

* fill in sumcheck api

* fill in sumcheck api

* fill in sumcheck api

* fill in sumcheck

* add assertion check

* fill in sumcheck test

* test cbindings

* fill in c api

* reformat

* fix include

* fill in sumcheck test

* fill in cpu backend

* drop dead code

* add assertion check

* fill in docs

* doc

* reword

* add note

* fix rust error

* fill in test descriptor

* add test
  • Loading branch information
rnburn authored Feb 14, 2025
1 parent 5ef238b commit 0e59204
Show file tree
Hide file tree
Showing 22 changed files with 659 additions and 1 deletion.
20 changes: 20 additions & 0 deletions cbindings/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cuda_library(
":get_one_commit",
":inner_product_proof",
":pedersen",
":sumcheck",
],
alwayslink = 1,
)
Expand Down Expand Up @@ -214,3 +215,22 @@ sxt_cc_component(
],
alwayslink = 1,
)

sxt_cc_component(
name = "sumcheck",
impl_deps = [
":backend",
],
test_deps = [
":backend",
"//sxt/base/test:unit_test",
"//sxt/proof/sumcheck:reference_transcript",
"//sxt/scalar25/operation:overload",
"//sxt/scalar25/type:element",
"//sxt/scalar25/type:literal",
],
deps = [
":blitzar_api",
],
alwayslink = 1,
)
77 changes: 77 additions & 0 deletions cbindings/blitzar_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ extern "C" {
#define SXT_CURVE_BN_254 2
#define SXT_CURVE_GRUMPKIN 3

#define SXT_FIELD_SCALAR255 0

/** config struct to hold the chosen backend */
struct sxt_config {
int backend;
Expand Down Expand Up @@ -127,6 +129,56 @@ struct sxt_sequence_descriptor {
int is_signed;
};

/** Describe inputs to a sumcheck proof.
*
* The sumcheck proof is constructed using a polynomial of the form
*
* sum_i^num_products {mult_i x prod_j^product_length_i f_j(X1, ..., Xr)}
*
* where f_j(X1, ..., Xr) denotes a multilinear extension of r variables.
*
* Pointer types are dependent on the field type as specified
* by the field id.
*
* We will let FIELD denote the field type when describing fields.
*/
struct sumcheck_descriptor {
// multilinear extensions referenced in a sumcheck proof
//
// mles should point to a n x num_mles column major matrix of
// type FIELD
const void* mles;

// Describe each product of the sumcheck polynomial. product_table should
// point to an num_products array with entries of type
//
// struct {
// FIELD multiplier
// unsigned product_length
// }
const void* product_table;

// MLE indices for the entries in product_table
const unsigned* product_terms;

// The length of each MLE
unsigned n;

// The number of distinct MLEs
unsigned num_mles;

// The number of products in the sumcheck polynomial
unsigned num_products;

/// The total number of total product terms in the sumcheck polynomial
// sum_i product_length_i
unsigned num_product_terms;

// The degree of the round polynomial for sumcheck
// max_i product_length_i
unsigned round_degree;
};

/** resources for multiexponentiations with pre-specified generators */
struct sxt_multiexp_handle;

Expand Down Expand Up @@ -689,6 +741,31 @@ void sxt_fixed_vlen_multiexponentiation(void* res, const struct sxt_multiexp_han
const unsigned* output_bit_table,
const unsigned* output_lengths, unsigned num_outputs,
const uint8_t* scalars);

/**
* Construct a sumcheck proof for a polynomial
*
* sum_i^num_products {mult_i x prod_j^product_length_i f_j(X1, ..., Xr)}
*
* input:
* field_id identifies the field of the sumcheck polynomial
* descriptor describes the sumcheck polynomial
* transcript_callback points to a function with signature
* void (FIELD* r, void* context, FIELD* polynomial, unsigned polynomial_length)
* and will be invoked each sumcheck round to draw a random FIELD entry.
*
* output:
* polynomials points to an (round_degree + 1) x (num_variables) column major matrix of type FIELD
* and will be filled with the sumcheck round polynomials upon completion.
*
* evaluation_point points to a num_variables array of type FIELD and will contain
* the evaluation point for sumcheck upon completion as set by the transcript_callback.
*
*/
void sxt_prove_sumcheck(void* polynomials, void* evaluation_point, unsigned field_id,
const struct sumcheck_descriptor* descriptor, void* transcript_callback,
void* transcript_context);

#ifdef __cplusplus
} // extern "C"
#endif
35 changes: 35 additions & 0 deletions cbindings/sumcheck.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cbindings/sumcheck.h"

#include "cbindings/backend.h"

using namespace sxt;

//--------------------------------------------------------------------------------------------------
// sxt_prove_sumcheck
//--------------------------------------------------------------------------------------------------
void sxt_prove_sumcheck(void* polynomials, void* evaluation_point, unsigned field_id,
const sumcheck_descriptor* descriptor, void* transcript_callback,
void* transcript_context) {
auto backend = cbn::get_backend();
static_assert(sizeof(sumcheck_descriptor) == sizeof(cbnb::sumcheck_descriptor),
"sumcheck descriptors must be binary compatible");
backend->prove_sumcheck(polynomials, evaluation_point, field_id,
*reinterpret_cast<const cbnb::sumcheck_descriptor*>(descriptor),
transcript_callback, transcript_context);
}
19 changes: 19 additions & 0 deletions cbindings/sumcheck.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "cbindings/blitzar_api.h"
97 changes: 97 additions & 0 deletions cbindings/sumcheck.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cbindings/sumcheck.h"

#include <vector>

#include "cbindings/backend.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/proof/sumcheck/reference_transcript.h"
#include "sxt/scalar25/operation/overload.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using s25t::operator""_s25;

TEST_CASE("we can create sumcheck proofs") {
prft::transcript base_transcript{"abc"};
prfsk::reference_transcript transcript{base_transcript};

std::vector<s25t::element> polynomials(2);
std::vector<s25t::element> evaluation_point(1);
std::vector<s25t::element> mles = {
0x8_s25,
0x3_s25,
};
std::vector<std::pair<s25t::element, unsigned>> product_table = {
{0x1_s25, 1},
};
std::vector<unsigned> product_terms = {0};
sumcheck_descriptor descriptor{
.mles = mles.data(),
.product_table = product_table.data(),
.product_terms = product_terms.data(),
.n = 2,
.num_mles = 1,
.num_products = 1,
.num_product_terms = 1,
.round_degree = 1,
};

auto f = [](s25t::element* r, void* context, const s25t::element* polynomial,
unsigned polynomial_len) noexcept {
static_cast<prfsk::reference_transcript*>(context)->round_challenge(
*r, {polynomial, polynomial_len});
};

SECTION("we can prove a sum with n=2 on GPU") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_GPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

sxt_prove_sumcheck(polynomials.data(), evaluation_point.data(), SXT_FIELD_SCALAR255,
&descriptor, reinterpret_cast<void*>(+f), &transcript);
REQUIRE(polynomials[0] == mles[0]);
REQUIRE(polynomials[1] == mles[1] - mles[0]);
{
prft::transcript base_transcript_p{"abc"};
prfsk::reference_transcript transcript_p{base_transcript_p};
s25t::element r;
transcript_p.round_challenge(r, polynomials);
REQUIRE(evaluation_point[0] == r);
}
}

SECTION("we can prove a sum with n=2 on CPU") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_CPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

sxt_prove_sumcheck(polynomials.data(), evaluation_point.data(), SXT_FIELD_SCALAR255,
&descriptor, reinterpret_cast<void*>(+f), &transcript);
REQUIRE(polynomials[0] == mles[0]);
REQUIRE(polynomials[1] == mles[1] - mles[0]);
{
prft::transcript base_transcript_p{"abc"};
prfsk::reference_transcript transcript_p{base_transcript_p};
s25t::element r;
transcript_p.round_challenge(r, polynomials);
REQUIRE(evaluation_point[0] == r);
}
}
}
23 changes: 22 additions & 1 deletion sxt/cbindings/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ load(
"sxt_cc_component",
)

sxt_cc_component(
name = "callback_sumcheck_transcript",
impl_deps = [
],
with_test = False,
deps = [
"//sxt/proof/sumcheck:sumcheck_transcript",
],
)

sxt_cc_component(
name = "computational_backend",
impl_deps = [
Expand All @@ -12,7 +22,8 @@ sxt_cc_component(
deps = [
":computational_backend_utility",
"//sxt/base/container:span",
"//sxt/cbindings/base:curve_id",
"//sxt/cbindings/base:curve_id_utility",
"//sxt/cbindings/base:sumcheck_descriptor",
"//sxt/curve21/type:element_p3",
"//sxt/multiexp/base:exponent_sequence",
"//sxt/multiexp/pippenger2:partition_table_accessor_base",
Expand All @@ -38,13 +49,16 @@ sxt_cc_component(
name = "gpu_backend",
impl_deps = [
":computational_backend_utility",
":callback_sumcheck_transcript",
"//sxt/base/error:assert",
"//sxt/base/system:directory_recorder",
"//sxt/base/system:file_io",
"//sxt/base/num:divide_up",
"//sxt/base/num:ceil_log2",
"//sxt/proof/transcript:transcript",
"//sxt/scalar25/type:element",
"//sxt/cbindings/base:curve_id_utility",
"//sxt/cbindings/base:field_id_utility",
"//sxt/curve_bng1/operation:add",
"//sxt/curve_bng1/operation:double",
"//sxt/curve_bng1/operation:neg",
Expand Down Expand Up @@ -83,6 +97,8 @@ sxt_cc_component(
"//sxt/proof/inner_product:proof_descriptor",
"//sxt/proof/inner_product:proof_computation",
"//sxt/proof/inner_product:gpu_driver",
"//sxt/proof/sumcheck:chunked_gpu_driver",
"//sxt/proof/sumcheck:proof_computation",
],
with_test = False,
deps = [
Expand All @@ -94,10 +110,13 @@ sxt_cc_component(
sxt_cc_component(
name = "cpu_backend",
impl_deps = [
":callback_sumcheck_transcript",
":computational_backend_utility",
"//sxt/base/error:panic",
"//sxt/base/num:ceil_log2",
"//sxt/base/num:round_up",
"//sxt/cbindings/base:curve_id_utility",
"//sxt/cbindings/base:field_id_utility",
"//sxt/proof/transcript:transcript",
"//sxt/scalar25/type:element",
"//sxt/curve_bng1/operation:add",
Expand Down Expand Up @@ -136,6 +155,8 @@ sxt_cc_component(
"//sxt/proof/inner_product:proof_descriptor",
"//sxt/proof/inner_product:proof_computation",
"//sxt/proof/inner_product:cpu_driver",
"//sxt/proof/sumcheck:cpu_driver",
"//sxt/proof/sumcheck:proof_computation",
],
with_test = False,
deps = [
Expand Down
17 changes: 17 additions & 0 deletions sxt/cbindings/backend/callback_sumcheck_transcript.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/cbindings/backend/callback_sumcheck_transcript.h"
Loading

0 comments on commit 0e59204

Please sign in to comment.