Skip to content

Commit

Permalink
feat: make sumcheck generic (PROOF-913) (#247)
Browse files Browse the repository at this point in the history
* add field concept

* fill in field concept

* fill in field concept

* rework

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumchecK

* rework sumcheck

* rework sumcheck

* rework reference transcript

* fill in reference transcript

* rework sumcheck

* add tests

* rework sumcheck

* add stub for proof computation

* rework sumcheck

* rework sumcheck

* rework cpu driver

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* fill in polynomial utility tests

* fill in cpu driver

* add driver test

* fill in cpu driver

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* add field accumulator

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* fill in chunked driver

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* reformat

* update benchmark

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* rework sumcheck

* refactor sumcheck

* refactor sumcheck

* refactor
  • Loading branch information
rnburn authored Feb 26, 2025
1 parent 0e59204 commit 137f1e8
Show file tree
Hide file tree
Showing 74 changed files with 1,788 additions and 1,899 deletions.
2 changes: 1 addition & 1 deletion benchmark/sumcheck/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ sxt_cc_benchmark(
"//sxt/proof/sumcheck:reference_transcript",
"//sxt/proof/transcript",
"//sxt/scalar25/random:element",
"//sxt/scalar25/type:element",
"//sxt/scalar25/realization:field",
],
)
14 changes: 7 additions & 7 deletions benchmark/sumcheck/benchmark.m.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "sxt/proof/sumcheck/reference_transcript.h"
#include "sxt/proof/transcript/transcript.h"
#include "sxt/scalar25/random/element.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/realization/field.h"

using namespace sxt;

Expand Down Expand Up @@ -111,22 +111,22 @@ int main(int argc, char* argv[]) {
memmg::managed_array<s25t::element> polynomials((p.degree + 1u) * num_rounds);
memmg::managed_array<s25t::element> evaluation_point(num_rounds);
prft::transcript base_transcript{"abc123"};
prfsk::reference_transcript transcript{base_transcript};
prfsk::gpu_driver drv;
prfsk::reference_transcript<s25t::element> transcript{base_transcript};
prfsk::gpu_driver<s25t::element> drv;

// initial run
{
auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, p.n);
auto fut = prfsk::prove_sum<s25t::element>(polynomials, evaluation_point, transcript, drv, mles,
product_table, product_terms, p.n);
xens::get_scheduler().run();
}

// sample
double elapse = 0;
for (unsigned i = 0; i < (p.num_samples + 1u); ++i) {
auto t1 = std::chrono::steady_clock::now();
auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, p.n);
auto fut = prfsk::prove_sum<s25t::element>(polynomials, evaluation_point, transcript, drv, mles,
product_table, product_terms, p.n);
xens::get_scheduler().run();
auto t2 = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1);
Expand Down
2 changes: 1 addition & 1 deletion cbindings/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ sxt_cc_component(
"//sxt/base/test:unit_test",
"//sxt/proof/sumcheck:reference_transcript",
"//sxt/scalar25/operation:overload",
"//sxt/scalar25/type:element",
"//sxt/scalar25/realization:field",
"//sxt/scalar25/type:literal",
],
deps = [
Expand Down
10 changes: 5 additions & 5 deletions cbindings/sumcheck.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
#include "sxt/base/test/unit_test.h"
#include "sxt/proof/sumcheck/reference_transcript.h"
#include "sxt/scalar25/operation/overload.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/realization/field.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using s25t::operator""_s25;

TEST_CASE("we can create sumcheck proofs") {
prft::transcript base_transcript{"abc"};
prfsk::reference_transcript transcript{base_transcript};
prfsk::reference_transcript<s25t::element> transcript{base_transcript};

std::vector<s25t::element> polynomials(2);
std::vector<s25t::element> evaluation_point(1);
Expand All @@ -55,7 +55,7 @@ TEST_CASE("we can create sumcheck proofs") {

auto f = [](s25t::element* r, void* context, const s25t::element* polynomial,
unsigned polynomial_len) noexcept {
static_cast<prfsk::reference_transcript*>(context)->round_challenge(
static_cast<prfsk::reference_transcript<s25t::element>*>(context)->round_challenge(
*r, {polynomial, polynomial_len});
};

Expand All @@ -70,7 +70,7 @@ TEST_CASE("we can create sumcheck proofs") {
REQUIRE(polynomials[1] == mles[1] - mles[0]);
{
prft::transcript base_transcript_p{"abc"};
prfsk::reference_transcript transcript_p{base_transcript_p};
prfsk::reference_transcript<s25t::element> transcript_p{base_transcript_p};
s25t::element r;
transcript_p.round_challenge(r, polynomials);
REQUIRE(evaluation_point[0] == r);
Expand All @@ -88,7 +88,7 @@ TEST_CASE("we can create sumcheck proofs") {
REQUIRE(polynomials[1] == mles[1] - mles[0]);
{
prft::transcript base_transcript_p{"abc"};
prfsk::reference_transcript transcript_p{base_transcript_p};
prfsk::reference_transcript<s25t::element> transcript_p{base_transcript_p};
s25t::element r;
transcript_p.round_challenge(r, polynomials);
REQUIRE(evaluation_point[0] == r);
Expand Down
5 changes: 5 additions & 0 deletions sxt/base/concept/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ sxt_cc_component(
"//sxt/base/test:unit_test",
],
)

sxt_cc_component(
name = "field",
with_test = False,
)
17 changes: 17 additions & 0 deletions sxt/base/concept/field.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/base/concept/field.h"
31 changes: 31 additions & 0 deletions sxt/base/concept/field.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

namespace sxt::bascpt {
//--------------------------------------------------------------------------------------------------
// field
//--------------------------------------------------------------------------------------------------
template <class T>
concept field = requires(T& res, const T& e) {
neg(res, e);
add(res, e, e);
sub(res, e, e);
mul(res, e, e);
muladd(res, e, e, e);
};
} // namespace sxt::bascpt
14 changes: 14 additions & 0 deletions sxt/base/field/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,17 @@ sxt_cc_component(
"//sxt/base/type:narrow_cast",
],
)

sxt_cc_component(
name = "element",
with_test = False,
)

sxt_cc_component(
name = "accumulator",
with_test = False,
deps = [
":element",
"//sxt/base/macro:cuda_callable",
],
)
17 changes: 17 additions & 0 deletions sxt/base/field/accumulator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/base/field/accumulator.h"
31 changes: 31 additions & 0 deletions sxt/base/field/accumulator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "sxt/base/field/element.h"
#include "sxt/base/macro/cuda_callable.h"

namespace sxt::basfld {
//--------------------------------------------------------------------------------------------------
// accumulator
//--------------------------------------------------------------------------------------------------
template <basfld::element T> struct accumulator {
using value_type = T;

CUDA_CALLABLE static void accumulate_inplace(T& res, T& e) noexcept { add(res, res, e); }
};
} // namespace sxt::basfld
17 changes: 17 additions & 0 deletions sxt/base/field/element.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/base/field/element.h"
35 changes: 35 additions & 0 deletions sxt/base/field/element.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <concepts>

namespace sxt::basfld {
//--------------------------------------------------------------------------------------------------
// element
//--------------------------------------------------------------------------------------------------
template <class T>
concept element = requires(T& res, const T& e) {
neg(res, e);
add(res, e, e);
sub(res, e, e);
mul(res, e, e);
muladd(res, e, e, e);
{ T::identity() } noexcept -> std::same_as<T>;
{ T::one() } noexcept -> std::same_as<T>;
};
} // namespace sxt::basfld
8 changes: 4 additions & 4 deletions sxt/cbindings/backend/callback_sumcheck_transcript.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ namespace sxt::cbnbck {
//--------------------------------------------------------------------------------------------------
// callback_sumcheck_transcript
//--------------------------------------------------------------------------------------------------
class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript {
template <basfld::element T>
class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript<T> {
public:
using callback_t = void (*)(s25t::element* r, void* context, const s25t::element* polynomial,
unsigned polynomial_len);
using callback_t = void (*)(T* r, void* context, const T* polynomial, unsigned polynomial_len);

callback_sumcheck_transcript(callback_t f, void* context) noexcept : f_{f}, context_{context} {}

void init(size_t /*num_variables*/, size_t /*round_degree*/) noexcept override {}

void round_challenge(s25t::element& r, basct::cspan<s25t::element> polynomial) noexcept override {
void round_challenge(T& r, basct::cspan<T> polynomial) noexcept override {
f_(&r, context_, polynomial.data(), static_cast<unsigned>(polynomial.size()));
}

Expand Down
16 changes: 5 additions & 11 deletions sxt/cbindings/backend/cpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,12 @@
#include "sxt/proof/transcript/transcript.h"
#include "sxt/ristretto/operation/compression.h"
#include "sxt/ristretto/type/compressed_element.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/seqcommit/generator/precomputed_generators.h"

namespace sxt::cbnbck {
//--------------------------------------------------------------------------------------------------
// prove_sumcheck
//--------------------------------------------------------------------------------------------------
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-parameter"
void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsigned field_id,
const cbnb::sumcheck_descriptor& descriptor,
void* transcript_callback, void* transcript_context) noexcept {
Expand All @@ -83,8 +78,8 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi
static_cast<cbnb::field_id_t>(field_id), [&]<class T>(std::type_identity<T>) noexcept {
static_assert(std::same_as<T, s25t::element>, "only support curve-255 right now");
// transcript
callback_sumcheck_transcript transcript{
reinterpret_cast<callback_sumcheck_transcript::callback_t>(
callback_sumcheck_transcript<T> transcript{
reinterpret_cast<callback_sumcheck_transcript<T>::callback_t>(
const_cast<void*>(transcript_callback)),
transcript_context};

Expand All @@ -109,14 +104,13 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi
descriptor.product_terms,
descriptor.num_product_terms,
};
prfsk::cpu_driver drv;
prfsk::cpu_driver<T> drv;
auto fut =
prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span,
product_table_span, product_terms_span, descriptor.n);
prfsk::prove_sum<T>(polynomials_span, evaluation_point_span, transcript, drv, mles_span,
product_table_span, product_terms_span, descriptor.n);
SXT_RELEASE_ASSERT(fut.ready());
});
}
#pragma clang diagnostic pop

//--------------------------------------------------------------------------------------------------
// compute_commitments
Expand Down
11 changes: 5 additions & 6 deletions sxt/cbindings/backend/gpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
#include "sxt/ristretto/operation/compression.h"
#include "sxt/ristretto/type/compressed_element.h"
#include "sxt/ristretto/type/literal.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/seqcommit/generator/precomputed_generators.h"

using sxt::rstt::operator""_rs;
Expand Down Expand Up @@ -112,8 +111,8 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi
static_cast<cbnb::field_id_t>(field_id), [&]<class T>(std::type_identity<T>) noexcept {
static_assert(std::same_as<T, s25t::element>, "only support curve-255 right now");
// transcript
callback_sumcheck_transcript transcript{
reinterpret_cast<callback_sumcheck_transcript::callback_t>(
callback_sumcheck_transcript<T> transcript{
reinterpret_cast<callback_sumcheck_transcript<T>::callback_t>(
const_cast<void*>(transcript_callback)),
transcript_context};

Expand All @@ -138,10 +137,10 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi
descriptor.product_terms,
descriptor.num_product_terms,
};
prfsk::chunked_gpu_driver drv;
prfsk::chunked_gpu_driver<T> drv;
auto fut =
prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span,
product_table_span, product_terms_span, descriptor.n);
prfsk::prove_sum<T>(polynomials_span, evaluation_point_span, transcript, drv, mles_span,
product_table_span, product_terms_span, descriptor.n);
xens::get_scheduler().run();
});
}
Expand Down
2 changes: 1 addition & 1 deletion sxt/cbindings/base/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ sxt_cc_component(
deps = [
":field_id",
"//sxt/base/error:panic",
"//sxt/scalar25/type:element",
"//sxt/scalar25/realization:field",
],
)

Expand Down
2 changes: 1 addition & 1 deletion sxt/cbindings/base/field_id_utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "sxt/base/error/panic.h"
#include "sxt/cbindings/base/field_id.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/realization/field.h"

namespace sxt::cbnb {
//--------------------------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 137f1e8

Please sign in to comment.