From 3eaf235aa0e885d279b38560f2e8dd96c5be30dc Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 17 Feb 2025 13:51:58 -0800 Subject: [PATCH 01/83] add field concept --- sxt/base/concept/BUILD | 5 +++++ sxt/base/concept/field.cc | 1 + sxt/base/concept/field.h | 15 +++++++++++++++ 3 files changed, 21 insertions(+) create mode 100644 sxt/base/concept/field.cc create mode 100644 sxt/base/concept/field.h diff --git a/sxt/base/concept/BUILD b/sxt/base/concept/BUILD index 1e3e001ce..a6c9e77bf 100644 --- a/sxt/base/concept/BUILD +++ b/sxt/base/concept/BUILD @@ -9,3 +9,8 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], ) + +sxt_cc_component( + name = "field", + with_test = False, +) diff --git a/sxt/base/concept/field.cc b/sxt/base/concept/field.cc new file mode 100644 index 000000000..9bb167899 --- /dev/null +++ b/sxt/base/concept/field.cc @@ -0,0 +1 @@ +#include "sxt/base/concept/field.h" diff --git a/sxt/base/concept/field.h b/sxt/base/concept/field.h new file mode 100644 index 000000000..3098cf394 --- /dev/null +++ b/sxt/base/concept/field.h @@ -0,0 +1,15 @@ +#pragma once + +namespace sxt::bascpt { +//-------------------------------------------------------------------------------------------------- +// field +//-------------------------------------------------------------------------------------------------- +template +concept field = requires(T& res, const T& e) { + neg(res, e); + add(res, e, e); + sub(res, e, e); + mul(res, e, e); + muladd(res, e, e, e); +}; +} // namespace sxt::bascpt From 0287bb90c2454bc0699e0e3ed8c9b37d37446e21 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 17 Feb 2025 14:18:07 -0800 Subject: [PATCH 02/83] fill in field concept --- sxt/scalar25/realization/BUILD | 9 +++++++++ sxt/scalar25/realization/field.cc | 1 + sxt/scalar25/realization/field.h | 4 ++++ sxt/scalar25/type/BUILD | 6 ++++++ sxt/scalar25/type/element.h | 3 ++- sxt/scalar25/type/operation_adl_stub.cc | 1 + sxt/scalar25/type/operation_adl_stub.h | 13 +++++++++++++ 7 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 sxt/scalar25/realization/BUILD create mode 100644 sxt/scalar25/realization/field.cc create mode 100644 sxt/scalar25/realization/field.h create mode 100644 sxt/scalar25/type/operation_adl_stub.cc create mode 100644 sxt/scalar25/type/operation_adl_stub.h diff --git a/sxt/scalar25/realization/BUILD b/sxt/scalar25/realization/BUILD new file mode 100644 index 000000000..1a6c82622 --- /dev/null +++ b/sxt/scalar25/realization/BUILD @@ -0,0 +1,9 @@ +load( + "//bazel:sxt_build_system.bzl", + "sxt_cc_component", +) + +sxt_cc_component( + name = "field", + with_test = False, +) diff --git a/sxt/scalar25/realization/field.cc b/sxt/scalar25/realization/field.cc new file mode 100644 index 000000000..6a4616b8e --- /dev/null +++ b/sxt/scalar25/realization/field.cc @@ -0,0 +1 @@ +#include "sxt/scalar25/realization/field.h" diff --git a/sxt/scalar25/realization/field.h b/sxt/scalar25/realization/field.h new file mode 100644 index 000000000..78ddb2b84 --- /dev/null +++ b/sxt/scalar25/realization/field.h @@ -0,0 +1,4 @@ +#pragma once + +#include "sxt/scalar25/type/element.h" +#include "sxt/base/concept/field.h" diff --git a/sxt/scalar25/type/BUILD b/sxt/scalar25/type/BUILD index 5a463e5ce..22a845243 100644 --- a/sxt/scalar25/type/BUILD +++ b/sxt/scalar25/type/BUILD @@ -3,6 +3,11 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "operation_adl_stub", + with_test = False, +) + sxt_cc_component( name = "element", impl_deps = [ @@ -12,6 +17,7 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], deps = [ + ":operation_adl_stub", "//sxt/base/macro:cuda_callable", ], ) diff --git a/sxt/scalar25/type/element.h b/sxt/scalar25/type/element.h index 3882b5c89..478b22752 100644 --- a/sxt/scalar25/type/element.h +++ b/sxt/scalar25/type/element.h @@ -23,6 +23,7 @@ #include #include "sxt/base/macro/cuda_callable.h" +#include "sxt/scalar25/type/operation_adl_stub.h" namespace sxt::s25t { //-------------------------------------------------------------------------------------------------- @@ -33,7 +34,7 @@ namespace sxt::s25t { * L being the order of the main subgroup * (L = 2^252 + 27742317777372353535851937790883648493). */ -class element { +class element : public s25o::operation_adl_stub { public: element() noexcept = default; diff --git a/sxt/scalar25/type/operation_adl_stub.cc b/sxt/scalar25/type/operation_adl_stub.cc new file mode 100644 index 000000000..82f1feee6 --- /dev/null +++ b/sxt/scalar25/type/operation_adl_stub.cc @@ -0,0 +1 @@ +#include "sxt/scalar25/type/operation_adl_stub.h" diff --git a/sxt/scalar25/type/operation_adl_stub.h b/sxt/scalar25/type/operation_adl_stub.h new file mode 100644 index 000000000..134fafbf6 --- /dev/null +++ b/sxt/scalar25/type/operation_adl_stub.h @@ -0,0 +1,13 @@ +#pragma once + + +namespace sxt::s25o { +//-------------------------------------------------------------------------------------------------- +// operation_adl_stub +//-------------------------------------------------------------------------------------------------- +/** + * A stub class that can be inherited so that functions in the s25o namespace + * will participate in ADL. + */ +struct operation_adl_stub {}; +}; // namespace sxt::s25o From 5ac444bc7f07f5ea56be20a2f99061c6f4c14ca1 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 17 Feb 2025 14:21:29 -0800 Subject: [PATCH 03/83] fill in field concept --- sxt/scalar25/realization/BUILD | 9 +++++++++ sxt/scalar25/realization/field.h | 10 +++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sxt/scalar25/realization/BUILD b/sxt/scalar25/realization/BUILD index 1a6c82622..000207fcb 100644 --- a/sxt/scalar25/realization/BUILD +++ b/sxt/scalar25/realization/BUILD @@ -6,4 +6,13 @@ load( sxt_cc_component( name = "field", with_test = False, + deps = [ + "//sxt/base/concept:field", + "//sxt/scalar25/type:element", + "//sxt/scalar25/operation:add", + "//sxt/scalar25/operation:sub", + "//sxt/scalar25/operation:neg", + "//sxt/scalar25/operation:mul", + "//sxt/scalar25/operation:muladd", + ], ) diff --git a/sxt/scalar25/realization/field.h b/sxt/scalar25/realization/field.h index 78ddb2b84..d8b47181f 100644 --- a/sxt/scalar25/realization/field.h +++ b/sxt/scalar25/realization/field.h @@ -1,4 +1,12 @@ #pragma once -#include "sxt/scalar25/type/element.h" #include "sxt/base/concept/field.h" +#include "sxt/scalar25/operation/add.h" +#include "sxt/scalar25/operation/mul.h" +#include "sxt/scalar25/operation/muladd.h" +#include "sxt/scalar25/operation/neg.h" +#include "sxt/scalar25/operation/sub.h" +#include "sxt/scalar25/type/element.h" + +static_assert( + sxt::bascpt::field); From 111f0605221c2e04a4d491e9d640d91084425fdd Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 17 Feb 2025 21:05:15 -0800 Subject: [PATCH 04/83] rework --- sxt/base/field/BUILD | 5 +++++ sxt/base/field/element.cc | 1 + sxt/base/field/element.h | 15 +++++++++++++++ sxt/scalar25/realization/BUILD | 2 +- sxt/scalar25/realization/field.h | 4 ++-- 5 files changed, 24 insertions(+), 3 deletions(-) create mode 100644 sxt/base/field/element.cc create mode 100644 sxt/base/field/element.h diff --git a/sxt/base/field/BUILD b/sxt/base/field/BUILD index 68838f458..b68abf0bf 100644 --- a/sxt/base/field/BUILD +++ b/sxt/base/field/BUILD @@ -14,3 +14,8 @@ sxt_cc_component( "//sxt/base/type:narrow_cast", ], ) + +sxt_cc_component( + name = "element", + with_test = False, +) diff --git a/sxt/base/field/element.cc b/sxt/base/field/element.cc new file mode 100644 index 000000000..f95ddeab8 --- /dev/null +++ b/sxt/base/field/element.cc @@ -0,0 +1 @@ +#include "sxt/base/field/element.h" diff --git a/sxt/base/field/element.h b/sxt/base/field/element.h new file mode 100644 index 000000000..c91c6ff15 --- /dev/null +++ b/sxt/base/field/element.h @@ -0,0 +1,15 @@ +#pragma once + +namespace sxt::basfld { +//-------------------------------------------------------------------------------------------------- +// element +//-------------------------------------------------------------------------------------------------- +template +concept element = requires(T& res, const T& e) { + neg(res, e); + add(res, e, e); + sub(res, e, e); + mul(res, e, e); + muladd(res, e, e, e); +}; +} // namespace sxt::basfld diff --git a/sxt/scalar25/realization/BUILD b/sxt/scalar25/realization/BUILD index 000207fcb..aa4e2ed87 100644 --- a/sxt/scalar25/realization/BUILD +++ b/sxt/scalar25/realization/BUILD @@ -7,7 +7,7 @@ sxt_cc_component( name = "field", with_test = False, deps = [ - "//sxt/base/concept:field", + "//sxt/base/field:element", "//sxt/scalar25/type:element", "//sxt/scalar25/operation:add", "//sxt/scalar25/operation:sub", diff --git a/sxt/scalar25/realization/field.h b/sxt/scalar25/realization/field.h index d8b47181f..c4f5501cc 100644 --- a/sxt/scalar25/realization/field.h +++ b/sxt/scalar25/realization/field.h @@ -1,6 +1,6 @@ #pragma once -#include "sxt/base/concept/field.h" +#include "sxt/base/field/element.h" #include "sxt/scalar25/operation/add.h" #include "sxt/scalar25/operation/mul.h" #include "sxt/scalar25/operation/muladd.h" @@ -9,4 +9,4 @@ #include "sxt/scalar25/type/element.h" static_assert( - sxt::bascpt::field); + sxt::basfld::element); From 8b1d3c19efd67e2665df7723739cea905d3a888e Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 17 Feb 2025 21:10:46 -0800 Subject: [PATCH 05/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 25 +++++++++++++++++++++++++ sxt/proof/sumcheck2/verification.cc | 1 + sxt/proof/sumcheck2/verification.h | 16 ++++++++++++++++ sxt/proof/sumcheck2/verification.t.cc | 5 +++++ 4 files changed, 47 insertions(+) create mode 100644 sxt/proof/sumcheck2/BUILD create mode 100644 sxt/proof/sumcheck2/verification.cc create mode 100644 sxt/proof/sumcheck2/verification.h create mode 100644 sxt/proof/sumcheck2/verification.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD new file mode 100644 index 000000000..518694d7a --- /dev/null +++ b/sxt/proof/sumcheck2/BUILD @@ -0,0 +1,25 @@ +load( + "//bazel:sxt_build_system.bzl", + "sxt_cc_component", +) + +sxt_cc_component( + name = "verification", + # impl_deps = [ + # ":polynomial_utility", + # ":sumcheck_transcript", + # "//sxt/base/error:assert", + # "//sxt/base/log:log", + # "//sxt/scalar25/operation:overload", + # "//sxt/scalar25/type:element", + # ], + test_deps = [ + # ":reference_transcript", + "//sxt/base/test:unit_test", + # "//sxt/scalar25/type:element", + # "//sxt/scalar25/type:literal", + ], + deps = [ + "//sxt/base/container:span", + ], +) diff --git a/sxt/proof/sumcheck2/verification.cc b/sxt/proof/sumcheck2/verification.cc new file mode 100644 index 000000000..0746fc7e9 --- /dev/null +++ b/sxt/proof/sumcheck2/verification.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/verification.h" diff --git a/sxt/proof/sumcheck2/verification.h b/sxt/proof/sumcheck2/verification.h new file mode 100644 index 000000000..5b6a5e301 --- /dev/null +++ b/sxt/proof/sumcheck2/verification.h @@ -0,0 +1,16 @@ +#pragma once + +#include "sxt/base/container/span.h" + +namespace sxt::prfsk2 { +class sumcheck_transcript; + +//-------------------------------------------------------------------------------------------------- +// verify_sumcheck_no_evaluation +//-------------------------------------------------------------------------------------------------- +/* bool verify_sumcheck_no_evaluation(s25t::element& expected_sum, */ +/* basct::span evaluation_point, */ +/* sumcheck_transcript& transcript, */ +/* basct::cspan round_polynomials, */ +/* unsigned round_degree) noexcept; */ +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/verification.t.cc b/sxt/proof/sumcheck2/verification.t.cc new file mode 100644 index 000000000..d7d98eaec --- /dev/null +++ b/sxt/proof/sumcheck2/verification.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/verification.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From fb85a676c6e3d3253dfe36809c426b91cd208fe6 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 13:46:02 -0800 Subject: [PATCH 06/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 11 ++++ sxt/proof/sumcheck2/sumcheck_transcript.cc | 1 + sxt/proof/sumcheck2/sumcheck_transcript.h | 19 ++++++ sxt/proof/sumcheck2/verification.h | 67 +++++++++++++++++++--- 4 files changed, 91 insertions(+), 7 deletions(-) create mode 100644 sxt/proof/sumcheck2/sumcheck_transcript.cc create mode 100644 sxt/proof/sumcheck2/sumcheck_transcript.h diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 518694d7a..ddf169f9f 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -3,6 +3,15 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "sumcheck_transcript", + with_test = False, + deps = [ + "//sxt/base/container:span", + "//sxt/base/field:element", + ], +) + sxt_cc_component( name = "verification", # impl_deps = [ @@ -20,6 +29,8 @@ sxt_cc_component( # "//sxt/scalar25/type:literal", ], deps = [ + ":sumcheck_transcript", "//sxt/base/container:span", + "//sxt/base/log:log", ], ) diff --git a/sxt/proof/sumcheck2/sumcheck_transcript.cc b/sxt/proof/sumcheck2/sumcheck_transcript.cc new file mode 100644 index 000000000..ae78804db --- /dev/null +++ b/sxt/proof/sumcheck2/sumcheck_transcript.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/sumcheck_transcript.h" diff --git a/sxt/proof/sumcheck2/sumcheck_transcript.h b/sxt/proof/sumcheck2/sumcheck_transcript.h new file mode 100644 index 000000000..a23e98bb8 --- /dev/null +++ b/sxt/proof/sumcheck2/sumcheck_transcript.h @@ -0,0 +1,19 @@ +#pragma once + +#include "sxt/base/field/element.h" +#include "sxt/base/container/span.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// sumcheck_transcript +//-------------------------------------------------------------------------------------------------- +template +class sumcheck_transcript { +public: + virtual ~sumcheck_transcript() noexcept = default; + + virtual void init(size_t num_variables, size_t round_degree) noexcept = 0; + + virtual void round_challenge(T& r, basct::cspan polynomial) noexcept = 0; +}; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/verification.h b/sxt/proof/sumcheck2/verification.h index 5b6a5e301..a3a0081af 100644 --- a/sxt/proof/sumcheck2/verification.h +++ b/sxt/proof/sumcheck2/verification.h @@ -1,16 +1,69 @@ #pragma once #include "sxt/base/container/span.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/log/log.h" +#include "sxt/proof/sumcheck2/sumcheck_transcript.h" namespace sxt::prfsk2 { -class sumcheck_transcript; - //-------------------------------------------------------------------------------------------------- // verify_sumcheck_no_evaluation //-------------------------------------------------------------------------------------------------- -/* bool verify_sumcheck_no_evaluation(s25t::element& expected_sum, */ -/* basct::span evaluation_point, */ -/* sumcheck_transcript& transcript, */ -/* basct::cspan round_polynomials, */ -/* unsigned round_degree) noexcept; */ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-parameter" +template +bool verify_sumcheck_no_evaluation(T& expected_sum, + basct::span evaluation_point, + sumcheck_transcript& transcript, + basct::cspan round_polynomials, + unsigned round_degree) noexcept { + return true; + auto num_variables = evaluation_point.size(); + SXT_RELEASE_ASSERT( + // clang-format off + num_variables > 0 && round_degree > 0 + // clang-format on + ); + + basl::info("verifying sumcheck of {} variables and round degree {}", num_variables, round_degree); + + // check dimensions + if (auto expected_count = (round_degree + 1u) * num_variables; + round_polynomials.size() != expected_count) { + basl::info("sumcheck verification failed: expected {} scalars for round_polynomials but got {}", + expected_count, round_polynomials.size()); + return false; + } + + transcript.init(num_variables, round_degree); + + // go through sumcheck rounds + for (unsigned round_index = 0; round_index < num_variables; ++round_index) { + auto polynomial = + round_polynomials.subspan((round_degree + 1u) * round_index, round_degree + 1u); + +#if 0 + // check sum + s25t::element sum; + sum_polynomial_01(sum, polynomial); + if (expected_sum != sum) { + basl::info("sumcheck verification failed on round {}", round_index + 1); + return false; + } + + // draw a random scalar + s25t::element r; + transcript.round_challenge(r, polynomial); + evaluation_point[round_index] = r; + + // evaluate at random point + evaluate_polynomial(expected_sum, polynomial, r); +#endif + } + + return true; +} +#pragma clang diagnostic pop } // namespace sxt::prfsk2 From 6f81e62b821433e37e2e70764889beeee4309ee9 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 14:17:25 -0800 Subject: [PATCH 07/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 8 +++++++ sxt/proof/sumcheck2/polynomial_utility.cc | 1 + sxt/proof/sumcheck2/polynomial_utility.h | 25 +++++++++++++++++++++ sxt/proof/sumcheck2/polynomial_utility.t.cc | 5 +++++ 4 files changed, 39 insertions(+) create mode 100644 sxt/proof/sumcheck2/polynomial_utility.cc create mode 100644 sxt/proof/sumcheck2/polynomial_utility.h create mode 100644 sxt/proof/sumcheck2/polynomial_utility.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index ddf169f9f..92e95fa99 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -3,6 +3,14 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "polynomial_utility", + deps = [ + "//sxt/base/container:span", + "//sxt/base/field:element", + ], +) + sxt_cc_component( name = "sumcheck_transcript", with_test = False, diff --git a/sxt/proof/sumcheck2/polynomial_utility.cc b/sxt/proof/sumcheck2/polynomial_utility.cc new file mode 100644 index 000000000..f69c84beb --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_utility.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/polynomial_utility.h" diff --git a/sxt/proof/sumcheck2/polynomial_utility.h b/sxt/proof/sumcheck2/polynomial_utility.h new file mode 100644 index 000000000..6ff885ca6 --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_utility.h @@ -0,0 +1,25 @@ +#pragma once + +#include "sxt/base/container/span.h" +#include "sxt/base/field/element.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// sum_polynomial_01 +//-------------------------------------------------------------------------------------------------- +// Given a polynomial +// f_a(X) = a[0] + a[1] * X + a[2] * X^2 + ... +// compute the sum +// f_a(0) + f_a(1) +template +void sum_polynomial_01(T& e, basct::cspan polynomial) noexcept { + if (polynomial.empty()) { + e = T{}; + return; + } + e = polynomial[0]; + for (unsigned i = 0; i < polynomial.size(); ++i) { + add(e, e, polynomial[i]); + } +} +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/polynomial_utility.t.cc b/sxt/proof/sumcheck2/polynomial_utility.t.cc new file mode 100644 index 000000000..9c9345d0b --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_utility.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/polynomial_utility.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From 5b957b47cce703aeacb9d983bb512ae74718a5f6 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 14:22:50 -0800 Subject: [PATCH 08/83] rework sumchecK --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/verification.h | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 92e95fa99..85b5872ac 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -23,7 +23,6 @@ sxt_cc_component( sxt_cc_component( name = "verification", # impl_deps = [ - # ":polynomial_utility", # ":sumcheck_transcript", # "//sxt/base/error:assert", # "//sxt/base/log:log", @@ -37,6 +36,7 @@ sxt_cc_component( # "//sxt/scalar25/type:literal", ], deps = [ + ":polynomial_utility", ":sumcheck_transcript", "//sxt/base/container:span", "//sxt/base/log:log", diff --git a/sxt/proof/sumcheck2/verification.h b/sxt/proof/sumcheck2/verification.h index a3a0081af..895509507 100644 --- a/sxt/proof/sumcheck2/verification.h +++ b/sxt/proof/sumcheck2/verification.h @@ -4,6 +4,7 @@ #include "sxt/base/error/assert.h" #include "sxt/base/log/log.h" #include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/sumcheck2/polynomial_utility.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- @@ -44,9 +45,8 @@ bool verify_sumcheck_no_evaluation(T& expected_sum, auto polynomial = round_polynomials.subspan((round_degree + 1u) * round_index, round_degree + 1u); -#if 0 // check sum - s25t::element sum; + T sum; sum_polynomial_01(sum, polynomial); if (expected_sum != sum) { basl::info("sumcheck verification failed on round {}", round_index + 1); @@ -54,10 +54,11 @@ bool verify_sumcheck_no_evaluation(T& expected_sum, } // draw a random scalar - s25t::element r; + T r; transcript.round_challenge(r, polynomial); evaluation_point[round_index] = r; +#if 0 // evaluate at random point evaluate_polynomial(expected_sum, polynomial, r); #endif From 8d29a9d436e56d002f01a23126ca0d4d6da073c0 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 14:34:50 -0800 Subject: [PATCH 09/83] rework sumcheck --- sxt/proof/sumcheck2/polynomial_utility.h | 18 ++++++++++++++++++ sxt/proof/sumcheck2/verification.h | 8 -------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sxt/proof/sumcheck2/polynomial_utility.h b/sxt/proof/sumcheck2/polynomial_utility.h index 6ff885ca6..edd5efc98 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.h +++ b/sxt/proof/sumcheck2/polynomial_utility.h @@ -22,4 +22,22 @@ void sum_polynomial_01(T& e, basct::cspan polynomial) noexcept { add(e, e, polynomial[i]); } } + +//-------------------------------------------------------------------------------------------------- +// evaluate_polynomial +//-------------------------------------------------------------------------------------------------- +template +void evaluate_polynomial(T& e, basct::cspan polynomial, const T& x) noexcept { + if (polynomial.empty()) { + e = T{}; + return; + } + auto i = polynomial.size(); + --i; + e = polynomial[i]; + while (i > 0) { + --i; + muladd(e, e, x, polynomial[i]); + } +} } // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/verification.h b/sxt/proof/sumcheck2/verification.h index 895509507..6acac8863 100644 --- a/sxt/proof/sumcheck2/verification.h +++ b/sxt/proof/sumcheck2/verification.h @@ -10,17 +10,12 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // verify_sumcheck_no_evaluation //-------------------------------------------------------------------------------------------------- -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-parameter" template bool verify_sumcheck_no_evaluation(T& expected_sum, basct::span evaluation_point, sumcheck_transcript& transcript, basct::cspan round_polynomials, unsigned round_degree) noexcept { - return true; auto num_variables = evaluation_point.size(); SXT_RELEASE_ASSERT( // clang-format off @@ -58,13 +53,10 @@ bool verify_sumcheck_no_evaluation(T& expected_sum, transcript.round_challenge(r, polynomial); evaluation_point[round_index] = r; -#if 0 // evaluate at random point evaluate_polynomial(expected_sum, polynomial, r); -#endif } return true; } -#pragma clang diagnostic pop } // namespace sxt::prfsk2 From 19db19ade2afc721151b9ede4cf554c2cf8c1b22 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 17:09:50 -0800 Subject: [PATCH 10/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 6 +++--- sxt/proof/sumcheck2/verification.t.cc | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 85b5872ac..3f676a930 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -26,14 +26,14 @@ sxt_cc_component( # ":sumcheck_transcript", # "//sxt/base/error:assert", # "//sxt/base/log:log", - # "//sxt/scalar25/operation:overload", # "//sxt/scalar25/type:element", # ], test_deps = [ # ":reference_transcript", "//sxt/base/test:unit_test", - # "//sxt/scalar25/type:element", - # "//sxt/scalar25/type:literal", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", ], deps = [ ":polynomial_utility", diff --git a/sxt/proof/sumcheck2/verification.t.cc b/sxt/proof/sumcheck2/verification.t.cc index d7d98eaec..bef321c61 100644 --- a/sxt/proof/sumcheck2/verification.t.cc +++ b/sxt/proof/sumcheck2/verification.t.cc @@ -1,5 +1,16 @@ #include "sxt/proof/sumcheck2/verification.h" +#include + #include "sxt/base/test/unit_test.h" +/* #include "sxt/proof/sumcheck/reference_transcript.h" */ +#include "sxt/proof/transcript/transcript.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using sxt::s25t::operator""_s25; TEST_CASE("todo") {} From 6981f31df67773645ec69ec908ba2d152a04bb47 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 17:19:57 -0800 Subject: [PATCH 11/83] rework reference transcript --- sxt/proof/sumcheck2/BUILD | 9 ++++++ sxt/proof/sumcheck2/reference_transcript.cc | 1 + sxt/proof/sumcheck2/reference_transcript.h | 28 +++++++++++++++++++ sxt/proof/sumcheck2/reference_transcript.t.cc | 5 ++++ 4 files changed, 43 insertions(+) create mode 100644 sxt/proof/sumcheck2/reference_transcript.cc create mode 100644 sxt/proof/sumcheck2/reference_transcript.h create mode 100644 sxt/proof/sumcheck2/reference_transcript.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 3f676a930..873a0274c 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -20,6 +20,15 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "reference_transcript", + deps = [ + ":sumcheck_transcript", + "//sxt/base/container:span", + "//sxt/base/field:element", + ], +) + sxt_cc_component( name = "verification", # impl_deps = [ diff --git a/sxt/proof/sumcheck2/reference_transcript.cc b/sxt/proof/sumcheck2/reference_transcript.cc new file mode 100644 index 000000000..44733b41e --- /dev/null +++ b/sxt/proof/sumcheck2/reference_transcript.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/reference_transcript.h" diff --git a/sxt/proof/sumcheck2/reference_transcript.h b/sxt/proof/sumcheck2/reference_transcript.h new file mode 100644 index 000000000..afa023f62 --- /dev/null +++ b/sxt/proof/sumcheck2/reference_transcript.h @@ -0,0 +1,28 @@ +#pragma once + +#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/transcript/transcript.h" + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// reference_transcript +//-------------------------------------------------------------------------------------------------- +template +class reference_transcript final : public sumcheck_transcript { +public: + explicit reference_transcript(prft::transcript& transcript) noexcept; + + void init(size_t num_variables, size_t round_degree) noexcept { + (void)num_variables; + (void)round_degree; + } + + void round_challenge(T& r, basct::cspan polynomial) noexcept { + (void)r; + (void)polynomial; + } + +private: + prft::transcript& transcript_; +}; +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck2/reference_transcript.t.cc b/sxt/proof/sumcheck2/reference_transcript.t.cc new file mode 100644 index 000000000..37030bddf --- /dev/null +++ b/sxt/proof/sumcheck2/reference_transcript.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/reference_transcript.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From 29a3fc4178e63cfe68e92e90d7533656c543c7c8 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 17:26:35 -0800 Subject: [PATCH 12/83] fill in reference transcript --- sxt/proof/sumcheck2/BUILD | 1 + sxt/proof/sumcheck2/reference_transcript.h | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 873a0274c..339034593 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -26,6 +26,7 @@ sxt_cc_component( ":sumcheck_transcript", "//sxt/base/container:span", "//sxt/base/field:element", + "//sxt/proof/transcript:transcript_utility", ], ) diff --git a/sxt/proof/sumcheck2/reference_transcript.h b/sxt/proof/sumcheck2/reference_transcript.h index afa023f62..4dfdc461a 100644 --- a/sxt/proof/sumcheck2/reference_transcript.h +++ b/sxt/proof/sumcheck2/reference_transcript.h @@ -2,8 +2,9 @@ #include "sxt/proof/sumcheck2/sumcheck_transcript.h" #include "sxt/proof/transcript/transcript.h" +#include "sxt/proof/transcript/transcript_utility.h" -namespace sxt::prfsk { +namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // reference_transcript //-------------------------------------------------------------------------------------------------- @@ -13,16 +14,17 @@ class reference_transcript final : public sumcheck_transcript { explicit reference_transcript(prft::transcript& transcript) noexcept; void init(size_t num_variables, size_t round_degree) noexcept { - (void)num_variables; - (void)round_degree; + prft::set_domain(transcript_, "sumcheck proof v1"); + prft::append_value(transcript_, "n", num_variables); + prft::append_value(transcript_, "k", round_degree); } void round_challenge(T& r, basct::cspan polynomial) noexcept { - (void)r; - (void)polynomial; + prft::append_values(transcript_, "P", polynomial); + prft::challenge_value(r, transcript_, "R"); } private: prft::transcript& transcript_; }; -} // namespace sxt::prfsk +} // namespace sxt::prfsk2 From 6b071b396e8ba612028ed0e8c375caf62980a0e5 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 18:37:34 -0800 Subject: [PATCH 13/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 4 +- sxt/proof/sumcheck2/reference_transcript.h | 3 +- sxt/proof/sumcheck2/verification.t.cc | 104 ++++++++++++++++++++- 3 files changed, 105 insertions(+), 6 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 339034593..9592e74cc 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -39,10 +39,10 @@ sxt_cc_component( # "//sxt/scalar25/type:element", # ], test_deps = [ - # ":reference_transcript", + ":reference_transcript", "//sxt/base/test:unit_test", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ diff --git a/sxt/proof/sumcheck2/reference_transcript.h b/sxt/proof/sumcheck2/reference_transcript.h index 4dfdc461a..3c75da067 100644 --- a/sxt/proof/sumcheck2/reference_transcript.h +++ b/sxt/proof/sumcheck2/reference_transcript.h @@ -11,7 +11,8 @@ namespace sxt::prfsk2 { template class reference_transcript final : public sumcheck_transcript { public: - explicit reference_transcript(prft::transcript& transcript) noexcept; + explicit reference_transcript(prft::transcript& transcript) noexcept + : transcript_{transcript} {} void init(size_t num_variables, size_t round_degree) noexcept { prft::set_domain(transcript_, "sumcheck proof v1"); diff --git a/sxt/proof/sumcheck2/verification.t.cc b/sxt/proof/sumcheck2/verification.t.cc index bef321c61..4d6502753 100644 --- a/sxt/proof/sumcheck2/verification.t.cc +++ b/sxt/proof/sumcheck2/verification.t.cc @@ -3,14 +3,112 @@ #include #include "sxt/base/test/unit_test.h" -/* #include "sxt/proof/sumcheck/reference_transcript.h" */ +#include "sxt/proof/sumcheck2/reference_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; using namespace sxt::prfsk2; using sxt::s25t::operator""_s25; -TEST_CASE("todo") {} +TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { + using T = s25t::element; + s25t::element expected_sum = 0x0_s25; + std::vector evaluation_point = {0x0_s25}; + prft::transcript base_transcript{"abc"}; + reference_transcript transcript{base_transcript}; + std::vector round_polynomials = {0x0_s25, 0x0_s25}; + + SECTION("verification fails if dimensions don't match") { + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 2); + REQUIRE(!res); + } + + SECTION("we can verify a single round") { + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); + REQUIRE(res); + REQUIRE(evaluation_point[0] != 0x0_s25); + } + + SECTION("verification fails if the round polynomial doesn't match the sum") { + round_polynomials[1] = 0x1_s25; + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); + REQUIRE(!res); + } + + SECTION("we can verify a sum with two rounds") { + // Use the MLE: + // 3(1-x1)(1-x2) + 5(1-x1)x2 -7x1(1-x2) -1x1x2 + round_polynomials.resize(4); + + // round 1 + round_polynomials[0] = 0x3_s25 + 0x5_s25; + round_polynomials[1] = -0x3_s25 - 0x7_s25 - 0x5_s25 - 0x1_s25; + + // draw scalar + s25t::element r; + { + prft::transcript base_transcript_p{"abc"}; + reference_transcript transcript_p{base_transcript_p}; + transcript_p.init(2, 1); + transcript_p.round_challenge(r, basct::span{round_polynomials}.subspan(0, 2)); + } + + // round 2 + round_polynomials[2] = 0x3_s25 * (0x1_s25 - r) - 0x7_s25 * r; + round_polynomials[3] = + -0x3_s25 * (0x1_s25 - r) + 0x5_s25 * (0x1_s25 - r) + 0x7_s25 * r - 0x1_s25 * r; + + // prove + evaluation_point.resize(2); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); + REQUIRE(evaluation_point[0] == r); + REQUIRE(res); + } + + SECTION("sumcheck verification fails if the random scalar used is wrong") { + // Use the MLE: + // 3(1-x1)(1-x2) + 5(1-x1)x2 -7x1(1-x2) -1x1x2 + round_polynomials.resize(4); + + // round 1 + round_polynomials[0] = 0x3_s25 + 0x5_s25; + round_polynomials[1] = -0x3_s25 - 0x7_s25 - 0x5_s25 - 0x1_s25; + + // draw scalar + s25t::element r = 0x112233_s25; + + // round 2 + round_polynomials[2] = 0x3_s25 * (0x1_s25 - r) - 0x7_s25 * r; + round_polynomials[3] = + -0x3_s25 * (0x1_s25 - r) + 0x5_s25 * (0x1_s25 - r) + 0x7_s25 * r - 0x1_s25 * r; + + // prove + evaluation_point.resize(2); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); + REQUIRE(!res); + } + + SECTION("we can verify a polynomial of degree 2 with one round") { + // Use the MLEs: + // f(x1) = 3(1-x1) -7x1 + // g(x1) = -2 (1 - x1) + 4 x1 + round_polynomials = { + 0x3_s25 * -0x2_s25, + (-0x3_s25 - 0x7_s25) * -0x2_s25 + 0x3_s25 * (0x2_s25 + 0x4_s25), + (-0x3_s25 - 0x7_s25) * (0x2_s25 + 0x4_s25), + }; + expected_sum = 0x3_s25 * -0x2_s25 - 0x7_s25 * 0x4_s25; + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 2); + REQUIRE(res); + REQUIRE(evaluation_point[0] != 0x0_s25); + } +} From 2bff522c42aebe4e6d89d2975b1a0b39f2230a78 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 18 Feb 2025 22:15:33 -0800 Subject: [PATCH 14/83] add tests --- sxt/proof/sumcheck2/BUILD | 6 +++ sxt/proof/sumcheck2/reference_transcript.t.cc | 40 ++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 9592e74cc..6eaece8e8 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -28,6 +28,12 @@ sxt_cc_component( "//sxt/base/field:element", "//sxt/proof/transcript:transcript_utility", ], + test_deps = [ + "//sxt/base/test:unit_test", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", + ], ) sxt_cc_component( diff --git a/sxt/proof/sumcheck2/reference_transcript.t.cc b/sxt/proof/sumcheck2/reference_transcript.t.cc index 37030bddf..307647110 100644 --- a/sxt/proof/sumcheck2/reference_transcript.t.cc +++ b/sxt/proof/sumcheck2/reference_transcript.t.cc @@ -1,5 +1,43 @@ #include "sxt/proof/sumcheck2/reference_transcript.h" #include "sxt/base/test/unit_test.h" +#include "sxt/proof/transcript/transcript.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" -TEST_CASE("todo") {} +using namespace sxt; +using namespace sxt::prfsk2; +using sxt::s25t::operator""_s25; + +TEST_CASE("we provide an implementation of sumcheck transcript") { + using T = s25t::element; + prft::transcript base_transcript{"abc"}; + reference_transcript transcript{base_transcript}; + std::vector p = {0x123_s25}; + s25t::element r, rp; + + SECTION("we don't draw the same challenge from a transcript") { + transcript.round_challenge(r, p); + transcript.round_challenge(rp, p); + REQUIRE(r != rp); + + prft::transcript base_transcript_p{"abc"}; + reference_transcript transcript_p{base_transcript_p}; + p[0] = 0x456_s25; + transcript_p.round_challenge(rp, p); + REQUIRE(r != rp); + } + + SECTION("init_transcript produces different results based on parameters") { + transcript.init(1, 2); + transcript.round_challenge(r, p); + + prft::transcript base_transcript_p{"abc"}; + reference_transcript transcript_p{base_transcript_p}; + transcript.init(2, 1); + transcript.round_challenge(rp, p); + + REQUIRE(r != rp); + } +} From 605b349325caadfd0bec9e122fefeefb883f075a Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 09:16:48 -0800 Subject: [PATCH 15/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 18 ++++++++++++++++++ sxt/proof/sumcheck2/driver.cc | 1 + sxt/proof/sumcheck2/driver.h | 29 +++++++++++++++++++++++++++++ sxt/proof/sumcheck2/workspace.cc | 1 + sxt/proof/sumcheck2/workspace.h | 11 +++++++++++ 5 files changed, 60 insertions(+) create mode 100644 sxt/proof/sumcheck2/driver.cc create mode 100644 sxt/proof/sumcheck2/driver.h create mode 100644 sxt/proof/sumcheck2/workspace.cc create mode 100644 sxt/proof/sumcheck2/workspace.h diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 6eaece8e8..65dc91f38 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -3,6 +3,24 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "workspace", + with_test = False, + deps = [ + ], +) + +sxt_cc_component( + name = "driver", + with_test = False, + deps = [ + ":workspace", + "//sxt/execution/async:future_fwd", + "//sxt/base/container:span", + "//sxt/base/field:element", + ], +) + sxt_cc_component( name = "polynomial_utility", deps = [ diff --git a/sxt/proof/sumcheck2/driver.cc b/sxt/proof/sumcheck2/driver.cc new file mode 100644 index 000000000..74f52ac47 --- /dev/null +++ b/sxt/proof/sumcheck2/driver.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/driver.h" diff --git a/sxt/proof/sumcheck2/driver.h b/sxt/proof/sumcheck2/driver.h new file mode 100644 index 000000000..bc857ae62 --- /dev/null +++ b/sxt/proof/sumcheck2/driver.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "sxt/base/container/span.h" +#include "sxt/base/field/element.h" +#include "sxt/execution/async/future_fwd.h" +#include "sxt/proof/sumcheck2/workspace.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// driver +//-------------------------------------------------------------------------------------------------- +template +class driver { +public: + virtual ~driver() noexcept = default; + + virtual xena::future> + make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept = 0; + + virtual xena::future<> sum(basct::span polynomial, + workspace& ws) const noexcept = 0; + + virtual xena::future<> fold(workspace& ws, const T& r) const noexcept = 0; +}; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/workspace.cc b/sxt/proof/sumcheck2/workspace.cc new file mode 100644 index 000000000..f5547993c --- /dev/null +++ b/sxt/proof/sumcheck2/workspace.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/workspace.h" diff --git a/sxt/proof/sumcheck2/workspace.h b/sxt/proof/sumcheck2/workspace.h new file mode 100644 index 000000000..a422cb703 --- /dev/null +++ b/sxt/proof/sumcheck2/workspace.h @@ -0,0 +1,11 @@ +#pragma once + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// workspace +//-------------------------------------------------------------------------------------------------- +class workspace { +public: + virtual ~workspace() noexcept = default; +}; +} // namespace sxt::prfsk2 From 94b94f2e98e59fc72a57f926b21b05b9ab1c14ca Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 13:54:24 -0800 Subject: [PATCH 16/83] add stub for proof computation --- sxt/proof/sumcheck2/BUILD | 14 ++++++++++++++ sxt/proof/sumcheck2/proof_computation.cc | 1 + sxt/proof/sumcheck2/proof_computation.h | 4 ++++ sxt/proof/sumcheck2/proof_computation.t.cc | 5 +++++ 4 files changed, 24 insertions(+) create mode 100644 sxt/proof/sumcheck2/proof_computation.cc create mode 100644 sxt/proof/sumcheck2/proof_computation.h create mode 100644 sxt/proof/sumcheck2/proof_computation.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 65dc91f38..680d0a04e 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -27,6 +27,20 @@ sxt_cc_component( "//sxt/base/container:span", "//sxt/base/field:element", ], + test_deps = [ + "//sxt/base/test:unit_test", + ], +) + +sxt_cc_component( + name = "proof_computation", + deps = [ + "//sxt/base/container:span", + "//sxt/base/field:element", + ], + test_deps = [ + "//sxt/base/test:unit_test", + ], ) sxt_cc_component( diff --git a/sxt/proof/sumcheck2/proof_computation.cc b/sxt/proof/sumcheck2/proof_computation.cc new file mode 100644 index 000000000..8d7efecf6 --- /dev/null +++ b/sxt/proof/sumcheck2/proof_computation.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/proof_computation.h" diff --git a/sxt/proof/sumcheck2/proof_computation.h b/sxt/proof/sumcheck2/proof_computation.h new file mode 100644 index 000000000..f444efdaa --- /dev/null +++ b/sxt/proof/sumcheck2/proof_computation.h @@ -0,0 +1,4 @@ +#pragma once + +namespace sxt::prfsk2 { +} // namespace sxt:prfsk2 diff --git a/sxt/proof/sumcheck2/proof_computation.t.cc b/sxt/proof/sumcheck2/proof_computation.t.cc new file mode 100644 index 000000000..608b8e4d8 --- /dev/null +++ b/sxt/proof/sumcheck2/proof_computation.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/proof_computation.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From 80921d8face537390281c5e5dba2d11781173c77 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 14:03:22 -0800 Subject: [PATCH 17/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 6 +++ sxt/proof/sumcheck2/proof_computation.h | 52 +++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 680d0a04e..527d98d47 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -35,8 +35,14 @@ sxt_cc_component( sxt_cc_component( name = "proof_computation", deps = [ + ":sumcheck_transcript", + ":driver", "//sxt/base/container:span", + "//sxt/base/error:assert", "//sxt/base/field:element", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:future", + "//sxt/execution/async:coroutine", ], test_deps = [ "//sxt/base/test:unit_test", diff --git a/sxt/proof/sumcheck2/proof_computation.h b/sxt/proof/sumcheck2/proof_computation.h index f444efdaa..a0dd847e5 100644 --- a/sxt/proof/sumcheck2/proof_computation.h +++ b/sxt/proof/sumcheck2/proof_computation.h @@ -1,4 +1,56 @@ #pragma once +#include "sxt/base/error/assert.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/base/field/element.h" +#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/execution/async/future.h" + namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// prove_sum +//-------------------------------------------------------------------------------------------------- +template +xena::future<> prove_sum(basct::span polynomials, + basct::span evaluation_point, + sumcheck_transcript& transcript, const driver& drv, + basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) noexcept { + SXT_RELEASE_ASSERT(0 < n); + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto polynomial_length = polynomials.size() / num_variables; + auto num_mles = mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + polynomial_length > 1 && + evaluation_point.size() == num_variables && + polynomials.size() == num_variables * polynomial_length && + mles.size() == n * num_mles + // clang-format on + ); + + transcript.init(num_variables, polynomial_length - 1); + + auto ws = co_await drv.make_workspace(mles, product_table, product_terms, n); + + for (unsigned round_index = 0; round_index < num_variables; ++round_index) { + auto polynomial = polynomials.subspan(round_index * polynomial_length, polynomial_length); + + // compute the round polynomial + co_await drv.sum(polynomial, *ws); + + // draw the next random challenge + T r; + transcript.round_challenge(r, polynomial); + evaluation_point[round_index] = r; + + // fold the polynomial + if (round_index < num_variables - 1u) { + co_await drv.fold(*ws, r); + } + } +} } // namespace sxt:prfsk2 From 704637ef723407723c3338f081fc1e929943e3f4 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 19:58:53 -0800 Subject: [PATCH 18/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 10 ++++++++++ sxt/proof/sumcheck2/cpu_driver.cc | 1 + sxt/proof/sumcheck2/cpu_driver.h | 4 ++++ sxt/proof/sumcheck2/cpu_driver.t.cc | 5 +++++ 4 files changed, 20 insertions(+) create mode 100644 sxt/proof/sumcheck2/cpu_driver.cc create mode 100644 sxt/proof/sumcheck2/cpu_driver.h create mode 100644 sxt/proof/sumcheck2/cpu_driver.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 527d98d47..5497fd5da 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -21,6 +21,16 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "cpu_driver", + deps = [ + ":driver", + ], + test_deps = [ + "//sxt/base/test:unit_test", + ], +) + sxt_cc_component( name = "polynomial_utility", deps = [ diff --git a/sxt/proof/sumcheck2/cpu_driver.cc b/sxt/proof/sumcheck2/cpu_driver.cc new file mode 100644 index 000000000..9abb18f9f --- /dev/null +++ b/sxt/proof/sumcheck2/cpu_driver.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/cpu_driver.h" diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h new file mode 100644 index 000000000..15d2df639 --- /dev/null +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -0,0 +1,4 @@ +#pragma once + +namespace sxt::prfsk2 { +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/cpu_driver.t.cc b/sxt/proof/sumcheck2/cpu_driver.t.cc new file mode 100644 index 000000000..9b07db8b6 --- /dev/null +++ b/sxt/proof/sumcheck2/cpu_driver.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/cpu_driver.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From 58763a41212253f97c99ee12bb17e97631b2c58b Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:04:35 -0800 Subject: [PATCH 19/83] rework cpu driver --- sxt/proof/sumcheck2/BUILD | 3 +++ sxt/proof/sumcheck2/cpu_driver.h | 41 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 5497fd5da..71a829e7a 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -25,6 +25,9 @@ sxt_cc_component( name = "cpu_driver", deps = [ ":driver", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/memory/management:managed_array", ], test_deps = [ "//sxt/base/test:unit_test", diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 15d2df639..86ed1867b 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -1,4 +1,45 @@ #pragma once +#include "sxt/base/num/ceil_log2.h" +#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/memory/management/managed_array.h" + namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// cpu_driver +//-------------------------------------------------------------------------------------------------- +template +class cpu_driver final : public driver { + struct cpu_workspace final : public workspace { + memmg::managed_array mles; + basct::cspan> product_table; + basct::cspan product_terms; + unsigned n; + unsigned num_variables; + }; + +public: + // driver + xena::future> + make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override { + auto res = std::make_unique(); + res->mles = memmg::managed_array{mles.begin(), mles.end()}; + res->product_table = product_table; + res->product_terms = product_terms; + res->n = n; + res->num_variables = std::max(basn::ceil_log2(n), 1); + return xena::make_ready_future>(std::move(res)); + } + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { + return {}; + } + + xena::future<> fold(workspace& ws, T& r) const noexcept override { + return {}; + } +}; } // namespace sxt::prfsk2 From 4ee310a801d2dee98a95d06e7d34eb9b329de152 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:07:07 -0800 Subject: [PATCH 20/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 1 + sxt/proof/sumcheck2/cpu_driver.h | 48 +++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 71a829e7a..d224d45d4 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -25,6 +25,7 @@ sxt_cc_component( name = "cpu_driver", deps = [ ":driver", + "//sxt/base/error:assert", "//sxt/base/num:ceil_log2", "//sxt/execution/async:coroutine", "//sxt/memory/management:managed_array", diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 86ed1867b..911dd2f9a 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -1,5 +1,6 @@ #pragma once +#include "sxt/base/error/assert.h" #include "sxt/base/num/ceil_log2.h" #include "sxt/proof/sumcheck2/driver.h" #include "sxt/execution/async/coroutine.h" @@ -35,7 +36,52 @@ class cpu_driver final : public driver { } xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { - return {}; + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + SXT_RELEASE_ASSERT(work.n >= mid); + + auto mles = work.mles.data(); + auto product_table = work.product_table; + auto product_terms = work.product_terms; + + for (auto& val : polynomial) { + val = {}; + } + + // expand paired terms +#if 0 + auto n1 = work.n - mid; + for (unsigned i = 0; i < n1; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + SXT_RELEASE_ASSERT(num_terms < polynomial.size()); + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, T); + expand_products(p, mles + i, n, mid, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; + } + } + + // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) + for (unsigned i = n1; i < mid; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, T); + partial_expand_products(p, mles + i, n, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; + } + } +#endif + + return xena::make_ready_future(); } xena::future<> fold(workspace& ws, T& r) const noexcept override { From d6186d4947a618e7569dcbff877fd3863a64edf7 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:09:53 -0800 Subject: [PATCH 21/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 1 + sxt/proof/sumcheck2/polynomial_utility.h | 44 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index d224d45d4..728bf2709 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -40,6 +40,7 @@ sxt_cc_component( deps = [ "//sxt/base/container:span", "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", ], test_deps = [ "//sxt/base/test:unit_test", diff --git a/sxt/proof/sumcheck2/polynomial_utility.h b/sxt/proof/sumcheck2/polynomial_utility.h index edd5efc98..9b9df5bd6 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.h +++ b/sxt/proof/sumcheck2/polynomial_utility.h @@ -1,7 +1,10 @@ #pragma once +#include + #include "sxt/base/container/span.h" #include "sxt/base/field/element.h" +#include "sxt/base/macro/cuda_callable.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- @@ -40,4 +43,45 @@ void evaluate_polynomial(T& e, basct::cspan polynomial, const T& x) noexcept muladd(e, e, x, polynomial[i]); } } + +//-------------------------------------------------------------------------------------------------- +// expand_products +//-------------------------------------------------------------------------------------------------- +template +CUDA_CALLABLE +void expand_products(basct::span p, const T* mles, unsigned n, + unsigned step, basct::cspan terms) noexcept { + auto num_terms = terms.size(); + assert( + // clang-format off + num_terms > 0 && + n > step && + p.size() == num_terms + 1u + // clang-format on + ); + T a, b; + auto mle_index = terms[0]; + a = *(mles + mle_index * n); + b = *(mles + mle_index * n + step); + sub(b, b, a); + p[0] = a; + p[1] = b; + + for (unsigned i = 1; i < num_terms; ++i) { + auto mle_index = terms[i]; + a = *(mles + mle_index * n); + b = *(mles + mle_index * n + step); + sub(b, b, a); + + auto c_prev = p[0]; + mul(p[0], c_prev, a); + for (unsigned pow = 1u; pow < i + 1u; ++pow) { + auto c = p[pow]; + mul(p[pow], c, a); + muladd(p[pow], c_prev, b, p[pow]); + c_prev = c; + } + mul(p[i + 1u], c_prev, b); + } +} } // namespace sxt::prfsk2 From 84a29a0bfe5f3a602c7eee069d1c99780b8ebfe6 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:12:30 -0800 Subject: [PATCH 22/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 ++ sxt/proof/sumcheck2/cpu_driver.h | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 728bf2709..bc4e51741 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -25,6 +25,8 @@ sxt_cc_component( name = "cpu_driver", deps = [ ":driver", + ":polynomial_utility", + "//sxt/base/container:stack_array", "//sxt/base/error:assert", "//sxt/base/num:ceil_log2", "//sxt/execution/async:coroutine", diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 911dd2f9a..66c9d686b 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -1,10 +1,12 @@ #pragma once +#include "sxt/base/container/stack_array.h" #include "sxt/base/error/assert.h" #include "sxt/base/num/ceil_log2.h" -#include "sxt/proof/sumcheck2/driver.h" #include "sxt/execution/async/coroutine.h" #include "sxt/memory/management/managed_array.h" +#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/proof/sumcheck2/polynomial_utility.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- @@ -50,7 +52,6 @@ class cpu_driver final : public driver { } // expand paired terms -#if 0 auto n1 = work.n - mid; for (unsigned i = 0; i < n1; ++i) { unsigned term_first = 0; @@ -66,6 +67,7 @@ class cpu_driver final : public driver { } } +#if 0 // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) for (unsigned i = n1; i < mid; ++i) { unsigned term_first = 0; From 21b4d9e347570379f061b24503c83b6fe20640dd Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:14:02 -0800 Subject: [PATCH 23/83] rework sumcheck --- sxt/proof/sumcheck2/polynomial_utility.h | 38 ++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/sxt/proof/sumcheck2/polynomial_utility.h b/sxt/proof/sumcheck2/polynomial_utility.h index 9b9df5bd6..1f9dc59d2 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.h +++ b/sxt/proof/sumcheck2/polynomial_utility.h @@ -84,4 +84,42 @@ void expand_products(basct::span p, const T* mles, unsigned n, mul(p[i + 1u], c_prev, b); } } + +//-------------------------------------------------------------------------------------------------- +// partial_expand_products +//-------------------------------------------------------------------------------------------------- +template +CUDA_CALLABLE +void partial_expand_products(basct::span p, const T* mles, unsigned n, + basct::cspan terms) noexcept { + auto num_terms = terms.size(); + assert( + // clang-format off + num_terms > 0 && + p.size() == num_terms + 1u + // clang-format on + ); + T a, b; + auto mle_index = terms[0]; + a = *(mles + mle_index * n); + neg(b, a); + p[0] = a; + p[1] = b; + + for (unsigned i = 1; i < num_terms; ++i) { + auto mle_index = terms[i]; + a = *(mles + mle_index * n); + neg(b, a); + + auto c_prev = p[0]; + mul(p[0], c_prev, a); + for (unsigned pow = 1u; pow < i + 1u; ++pow) { + auto c = p[pow]; + mul(p[pow], c, a); + muladd(p[pow], c_prev, b, p[pow]); + c_prev = c; + } + mul(p[i + 1u], c_prev, b); + } +} } // namespace sxt::prfsk2 From 22ebc2feb5e5236360e9f76a7e5797744048d3a5 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:14:54 -0800 Subject: [PATCH 24/83] rework sumcheck --- sxt/proof/sumcheck2/cpu_driver.h | 72 ++++++++++++++++---------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 66c9d686b..38f05a507 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -38,52 +38,50 @@ class cpu_driver final : public driver { } xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - SXT_RELEASE_ASSERT(work.n >= mid); + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + SXT_RELEASE_ASSERT(work.n >= mid); - auto mles = work.mles.data(); - auto product_table = work.product_table; - auto product_terms = work.product_terms; + auto mles = work.mles.data(); + auto product_table = work.product_table; + auto product_terms = work.product_terms; - for (auto& val : polynomial) { - val = {}; - } + for (auto& val : polynomial) { + val = {}; + } - // expand paired terms - auto n1 = work.n - mid; - for (unsigned i = 0; i < n1; ++i) { - unsigned term_first = 0; - for (auto [mult, num_terms] : product_table) { - SXT_RELEASE_ASSERT(num_terms < polynomial.size()); - auto terms = product_terms.subspan(term_first, num_terms); - SXT_STACK_ARRAY(p, num_terms + 1u, T); - expand_products(p, mles + i, n, mid, terms); - for (unsigned term_index = 0; term_index < p.size(); ++term_index) { - muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + // expand paired terms + auto n1 = work.n - mid; + for (unsigned i = 0; i < n1; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + SXT_RELEASE_ASSERT(num_terms < polynomial.size()); + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, T); + expand_products(p, mles + i, n, mid, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; } - term_first += num_terms; } - } -#if 0 - // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) - for (unsigned i = n1; i < mid; ++i) { - unsigned term_first = 0; - for (auto [mult, num_terms] : product_table) { - auto terms = product_terms.subspan(term_first, num_terms); - SXT_STACK_ARRAY(p, num_terms + 1u, T); - partial_expand_products(p, mles + i, n, terms); - for (unsigned term_index = 0; term_index < p.size(); ++term_index) { - muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) + for (unsigned i = n1; i < mid; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, T); + partial_expand_products(p, mles + i, n, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; } - term_first += num_terms; } - } -#endif - return xena::make_ready_future(); + return xena::make_ready_future(); } xena::future<> fold(workspace& ws, T& r) const noexcept override { From 15dc8fb7614c66a4d1314b666e83dedd724116d1 Mon Sep 17 00:00:00 2001 From: rnburn Date: Wed, 19 Feb 2025 20:17:30 -0800 Subject: [PATCH 25/83] rework sumcheck --- sxt/proof/sumcheck2/cpu_driver.h | 44 ++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 38f05a507..8ec935120 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -86,6 +86,50 @@ class cpu_driver final : public driver { xena::future<> fold(workspace& ws, T& r) const noexcept override { return {}; +#if 0 + using s25t::operator""_s25; + + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + auto mles = work.mles.data(); + memmg::managed_array mles_p(num_mles * mid); + + s25t::element one_m_r = 0x1_s25; + s25o::sub(one_m_r, one_m_r, r); + auto n1 = work.n - mid; + for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { + auto data = mles + n * mle_index; + auto data_p = mles_p.data() + mid * mle_index; + + // fold paired terms + for (unsigned i = 0; i < n1; ++i) { + auto val = data[i]; + s25o::mul(val, val, one_m_r); + s25o::muladd(val, r, data[mid + i], val); + data_p[i] = val; + } + + // fold terms paired with zero + for (unsigned i = n1; i < mid; ++i) { + auto val = data[i]; + s25o::mul(val, val, one_m_r); + data_p[i] = val; + } + } + + work.n = mid; + --work.num_variables; + work.mles = std::move(mles_p); + return xena::make_ready_future(); +#endif } }; } // namespace sxt::prfsk2 From 2efca8ce651a0109585e36e81bb0d513e2f4519d Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 13:51:09 -0800 Subject: [PATCH 26/83] fill in polynomial utility tests --- sxt/proof/sumcheck2/BUILD | 3 + sxt/proof/sumcheck2/polynomial_utility.t.cc | 115 +++++++++++++++++++- 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index bc4e51741..753714d6a 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -46,6 +46,9 @@ sxt_cc_component( ], test_deps = [ "//sxt/base/test:unit_test", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", ], ) diff --git a/sxt/proof/sumcheck2/polynomial_utility.t.cc b/sxt/proof/sumcheck2/polynomial_utility.t.cc index 9c9345d0b..f08dc9829 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.t.cc +++ b/sxt/proof/sumcheck2/polynomial_utility.t.cc @@ -1,5 +1,118 @@ #include "sxt/proof/sumcheck2/polynomial_utility.h" +#include + #include "sxt/base/test/unit_test.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +using T = s25t::element; + +TEST_CASE("we perform basic operations on polynomials") { + + s25t::element e; + + std::vector p; + + SECTION("we can compute the 0-1 sum of a zero polynomials") { + sum_polynomial_01(e, p); + REQUIRE(e == 0x0_s25); + } + + SECTION("we can compute the 0-1 sum of a constant polynomial") { + p = {0x123_s25}; + sum_polynomial_01(e, p); + REQUIRE(e == 0x246_s25); + } + + SECTION("we can compute the 0-1 sum of a 1 degree polynomial") { + p = {0x123_s25, 0x456_s25}; + sum_polynomial_01(e, p); + REQUIRE(e == 0x246_s25 + 0x456_s25); + } + + SECTION("we can evaluate the zero polynomial") { + evaluate_polynomial(e, p, 0x123_s25); + REQUIRE(e == 0x0_s25); + } + + SECTION("we can evaluate a constant polynomial") { + p = {0x123_s25}; + evaluate_polynomial(e, p, 0x321_s25); + REQUIRE(e == 0x123_s25); + } + + SECTION("we can evaluate a polynomial of degree 1") { + p = {0x123_s25, 0x456_s25}; + evaluate_polynomial(e, p, 0x321_s25); + REQUIRE(e == 0x123_s25 + 0x456_s25 * 0x321_s25); + } + + SECTION("we can evaluate a polynomial of degree 2") { + p = {0x123_s25, 0x456_s25, 0x789_s25}; + evaluate_polynomial(e, p, 0x321_s25); + REQUIRE(e == 0x123_s25 + 0x456_s25 * 0x321_s25 + 0x789_s25 * 0x321_s25 * 0x321_s25); + } +} + +TEST_CASE("we can expand a product of MLEs") { + std::vector p; + std::vector mles; + std::vector terms; + + SECTION("we can expand a single MLE") { + p.resize(2); + mles = {0x123_s25, 0x456_s25}; + terms = {0}; + expand_products(p, mles.data(), 2, 1, terms); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } + + SECTION("we can partially expand MLEs (where some terms are assumed to be zero)") { + mles = {0x123_s25, 0x0_s25}; + p.resize(2); + terms = {0}; + partial_expand_products(p, mles.data(), 1, terms); + + std::vector expected(2); + expand_products(expected, mles.data(), 2, 1, terms); + REQUIRE(p == expected); + } + + SECTION("we can expand two MLEs") { + p.resize(3); + mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25}; + terms = {0, 1}; + expand_products(p, mles.data(), 2, 1, terms); + auto a1 = mles[0]; + auto a2 = mles[1] - mles[0]; + auto b1 = mles[2]; + auto b2 = mles[3] - mles[2]; + REQUIRE(p[0] == a1 * b1); + REQUIRE(p[1] == a1 * b2 + a2 * b1); + REQUIRE(p[2] == a2 * b2); + } -TEST_CASE("todo") {} + SECTION("we can expand three MLEs") { + p.resize(4); + mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25, 0x2233_s25, 0x5566_s25}; + terms = {0, 1, 2}; + expand_products(p, mles.data(), 2, 1, terms); + auto a1 = mles[0]; + auto a2 = mles[1] - mles[0]; + auto b1 = mles[2]; + auto b2 = mles[3] - mles[2]; + auto c1 = mles[4]; + auto c2 = mles[5] - mles[4]; + REQUIRE(p[0] == a1 * b1 * c1); + REQUIRE(p[1] == a1 * b1 * c2 + a1 * b2 * c1 + a2 * b1 * c1); + REQUIRE(p[2] == a1 * b2 * c2 + a2 * b1 * c2 + a2 * b2 * c1); + REQUIRE(p[3] == a2 * b2 * c2); + } +} From 6a61ffe4e250b73e38a3a1d1cde73bdd187f2815 Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 14:03:51 -0800 Subject: [PATCH 27/83] fill in cpu driver --- sxt/base/field/element.h | 3 +++ sxt/proof/sumcheck2/cpu_driver.h | 17 ++++++----------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/sxt/base/field/element.h b/sxt/base/field/element.h index c91c6ff15..63a5708a8 100644 --- a/sxt/base/field/element.h +++ b/sxt/base/field/element.h @@ -1,5 +1,7 @@ #pragma once +#include + namespace sxt::basfld { //-------------------------------------------------------------------------------------------------- // element @@ -11,5 +13,6 @@ concept element = requires(T& res, const T& e) { sub(res, e, e); mul(res, e, e); muladd(res, e, e, e); + { T::identity() } noexcept -> std::same_as; }; } // namespace sxt::basfld diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 8ec935120..6ab3c6fed 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -85,10 +85,6 @@ class cpu_driver final : public driver { } xena::future<> fold(workspace& ws, T& r) const noexcept override { - return {}; -#if 0 - using s25t::operator""_s25; - auto& work = static_cast(ws); auto n = work.n; auto mid = 1u << (work.num_variables - 1u); @@ -100,10 +96,10 @@ class cpu_driver final : public driver { ); auto mles = work.mles.data(); - memmg::managed_array mles_p(num_mles * mid); + memmg::managed_array mles_p(num_mles * mid); - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); + T one_m_r = T::identity(); + sub(one_m_r, one_m_r, r); auto n1 = work.n - mid; for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { auto data = mles + n * mle_index; @@ -112,15 +108,15 @@ class cpu_driver final : public driver { // fold paired terms for (unsigned i = 0; i < n1; ++i) { auto val = data[i]; - s25o::mul(val, val, one_m_r); - s25o::muladd(val, r, data[mid + i], val); + mul(val, val, one_m_r); + muladd(val, r, data[mid + i], val); data_p[i] = val; } // fold terms paired with zero for (unsigned i = n1; i < mid; ++i) { auto val = data[i]; - s25o::mul(val, val, one_m_r); + mul(val, val, one_m_r); data_p[i] = val; } } @@ -129,7 +125,6 @@ class cpu_driver final : public driver { --work.num_variables; work.mles = std::move(mles_p); return xena::make_ready_future(); -#endif } }; } // namespace sxt::prfsk2 From d5574b02118ec0be63270bc71fb54b082e94257e Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 14:52:36 -0800 Subject: [PATCH 28/83] add driver test --- sxt/proof/sumcheck2/BUILD | 18 +++++ sxt/proof/sumcheck2/driver_test.cc | 117 +++++++++++++++++++++++++++++ sxt/proof/sumcheck2/driver_test.h | 11 +++ 3 files changed, 146 insertions(+) create mode 100644 sxt/proof/sumcheck2/driver_test.cc create mode 100644 sxt/proof/sumcheck2/driver_test.h diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 753714d6a..dafd05075 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -21,6 +21,24 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "driver_test", + deps = [ + ":driver", + "//sxt/scalar25/realization:field", + ], + impl_deps = [ + ":workspace", + "//sxt/base/test:unit_test", + "//sxt/execution/async:future", + "//sxt/execution/schedule:scheduler", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", + ], + with_test = False, +) + sxt_cc_component( name = "cpu_driver", deps = [ diff --git a/sxt/proof/sumcheck2/driver_test.cc b/sxt/proof/sumcheck2/driver_test.cc new file mode 100644 index 000000000..d3da244e5 --- /dev/null +++ b/sxt/proof/sumcheck2/driver_test.cc @@ -0,0 +1,117 @@ +#include "sxt/proof/sumcheck2/driver_test.h" + +#include + +#include "sxt/base/test/unit_test.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/schedule/scheduler.h" +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/workspace.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +namespace sxt::prfsk2 { +using s25t::operator""_s25; + +//-------------------------------------------------------------------------------------------------- +// exercise_driver +//-------------------------------------------------------------------------------------------------- +void exercise_driver(const driver& drv) { + std::vector mles; + std::vector> product_table{ + {0x1_s25, 1}, + }; + std::vector product_terms = {0}; + + std::vector p(2); + + SECTION("we can sum a polynomial with n = 1") { + std::vector mles = {0x123_s25}; + auto ws = drv.make_workspace(mles, product_table, product_terms, 1); + xens::get_scheduler().run(); + auto fut = drv.sum(p, *ws.value()); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == -mles[0]); + } + + SECTION("we can sum a polynomial with a non-unity multiplier") { + std::vector mles = {0x123_s25}; + product_table[0].first = 0x2_s25; + auto ws = drv.make_workspace(mles, product_table, product_terms, 1); + xens::get_scheduler().run(); + auto fut = drv.sum(p, *ws.value()); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == 0x2_s25 * mles[0]); + REQUIRE(p[1] == -0x2_s25 * mles[0]); + } + + SECTION("we can sum a polynomial with n = 2") { + std::vector mles = {0x123_s25, 0x456_s25}; + auto ws = drv.make_workspace(mles, product_table, product_terms, 2); + xens::get_scheduler().run(); + auto fut = drv.sum(p, *ws.value()); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } + + SECTION("we can sum a polynomial with two MLEs added together") { + std::vector mles = {0x123_s25, 0x456_s25}; + std::vector> product_table{ + {0x1_s25, 1}, + {0x1_s25, 1}, + }; + std::vector product_terms = {0, 1}; + + auto ws = drv.make_workspace(mles, product_table, product_terms, 1); + xens::get_scheduler().run(); + auto fut = drv.sum(p, *ws.value()); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] + mles[1]); + REQUIRE(p[1] == -mles[0] - mles[1]); + } + + SECTION("we can sum a polynomial with two MLEs multiplied together") { + std::vector mles = {0x123_s25, 0x456_s25}; + std::vector> product_table{ + {0x1_s25, 2}, + }; + std::vector product_terms = {0, 1}; + p.resize(3); + + auto ws = drv.make_workspace(mles, product_table, product_terms, 1); + xens::get_scheduler().run(); + auto fut = drv.sum(p, *ws.value()); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] * mles[1]); + REQUIRE(p[1] == -mles[0] * mles[1] - mles[1] * mles[0]); + REQUIRE(p[2] == mles[0] * mles[1]); + } + + SECTION("we can fold mles") { + std::vector mles = {0x123_s25, 0x456_s25, 0x789_s25}; + auto ws = drv.make_workspace(mles, product_table, product_terms, 3); + xens::get_scheduler().run(); + auto r = 0xabc123_s25; + auto fut = drv.fold(*ws.value(), r); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + fut = drv.sum(p, *ws.value()); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + + mles[0] = (0x1_s25 - r) * mles[0] + r * mles[2]; + mles[1] = (0x1_s25 - r) * mles[1]; + + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } +} +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/driver_test.h b/sxt/proof/sumcheck2/driver_test.h new file mode 100644 index 000000000..3fc355147 --- /dev/null +++ b/sxt/proof/sumcheck2/driver_test.h @@ -0,0 +1,11 @@ +#pragma once + +#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/scalar25/realization/field.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// exercise_driver +//-------------------------------------------------------------------------------------------------- +void exercise_driver(const driver& drv); +} // namespace sxt::prfsk2 From ce2abd07fd062a68cc042a48604165ef08aec262 Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 15:04:49 -0800 Subject: [PATCH 29/83] fill in cpu driver --- sxt/base/field/element.h | 1 + sxt/proof/sumcheck2/BUILD | 1 + sxt/proof/sumcheck2/cpu_driver.h | 68 ++++++++++++++--------------- sxt/proof/sumcheck2/cpu_driver.t.cc | 9 +++- sxt/scalar25/type/element.h | 6 +++ 5 files changed, 50 insertions(+), 35 deletions(-) diff --git a/sxt/base/field/element.h b/sxt/base/field/element.h index 63a5708a8..f5409323b 100644 --- a/sxt/base/field/element.h +++ b/sxt/base/field/element.h @@ -14,5 +14,6 @@ concept element = requires(T& res, const T& e) { mul(res, e, e); muladd(res, e, e, e); { T::identity() } noexcept -> std::same_as; + { T::one() } noexcept -> std::same_as; }; } // namespace sxt::basfld diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index dafd05075..68e7c1b1f 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -51,6 +51,7 @@ sxt_cc_component( "//sxt/memory/management:managed_array", ], test_deps = [ + ":driver_test", "//sxt/base/test:unit_test", ], ) diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 6ab3c6fed..45e1e497d 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -84,47 +84,47 @@ class cpu_driver final : public driver { return xena::make_ready_future(); } - xena::future<> fold(workspace& ws, T& r) const noexcept override { - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off + xena::future<> fold(workspace& ws, const T& r) const noexcept override { + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); + // clang-format on + ); - auto mles = work.mles.data(); - memmg::managed_array mles_p(num_mles * mid); + auto mles = work.mles.data(); + memmg::managed_array mles_p(num_mles * mid); - T one_m_r = T::identity(); - sub(one_m_r, one_m_r, r); - auto n1 = work.n - mid; - for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { - auto data = mles + n * mle_index; - auto data_p = mles_p.data() + mid * mle_index; + T one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + auto n1 = work.n - mid; + for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { + auto data = mles + n * mle_index; + auto data_p = mles_p.data() + mid * mle_index; - // fold paired terms - for (unsigned i = 0; i < n1; ++i) { - auto val = data[i]; - mul(val, val, one_m_r); - muladd(val, r, data[mid + i], val); - data_p[i] = val; - } + // fold paired terms + for (unsigned i = 0; i < n1; ++i) { + auto val = data[i]; + mul(val, val, one_m_r); + muladd(val, r, data[mid + i], val); + data_p[i] = val; + } - // fold terms paired with zero - for (unsigned i = n1; i < mid; ++i) { - auto val = data[i]; - mul(val, val, one_m_r); - data_p[i] = val; + // fold terms paired with zero + for (unsigned i = n1; i < mid; ++i) { + auto val = data[i]; + mul(val, val, one_m_r); + data_p[i] = val; + } } - } - work.n = mid; - --work.num_variables; - work.mles = std::move(mles_p); - return xena::make_ready_future(); + work.n = mid; + --work.num_variables; + work.mles = std::move(mles_p); + return xena::make_ready_future(); } }; } // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/cpu_driver.t.cc b/sxt/proof/sumcheck2/cpu_driver.t.cc index 9b07db8b6..4da8f688e 100644 --- a/sxt/proof/sumcheck2/cpu_driver.t.cc +++ b/sxt/proof/sumcheck2/cpu_driver.t.cc @@ -1,5 +1,12 @@ #include "sxt/proof/sumcheck2/cpu_driver.h" +#include "sxt/proof/sumcheck2/driver_test.h" #include "sxt/base/test/unit_test.h" -TEST_CASE("todo") {} +using namespace sxt; +using namespace sxt::prfsk2; + +TEST_CASE("we can perform the primitive operations for sumcheck proofs") { + cpu_driver drv; + exercise_driver(drv); +} diff --git a/sxt/scalar25/type/element.h b/sxt/scalar25/type/element.h index 478b22752..5e78cf638 100644 --- a/sxt/scalar25/type/element.h +++ b/sxt/scalar25/type/element.h @@ -54,6 +54,12 @@ class element : public s25o::operation_adl_stub { static constexpr element identity() noexcept { return element{}; }; + static constexpr element one() noexcept { + element res{}; + res.data_[0] = 1; + return res; + } + private: uint8_t data_[32]; }; From 09a0660f96871b593fb955a85f7c877bd6e08815 Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 18:26:58 -0800 Subject: [PATCH 30/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 17 +++++++++++++++++ sxt/proof/sumcheck2/gpu_driver.cc | 1 + sxt/proof/sumcheck2/gpu_driver.h | 4 ++++ sxt/proof/sumcheck2/gpu_driver.t.cc | 5 +++++ 4 files changed, 27 insertions(+) create mode 100644 sxt/proof/sumcheck2/gpu_driver.cc create mode 100644 sxt/proof/sumcheck2/gpu_driver.h create mode 100644 sxt/proof/sumcheck2/gpu_driver.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 68e7c1b1f..99be26916 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -56,6 +56,23 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "gpu_driver", + deps = [ + ":driver", + # ":polynomial_utility", + # "//sxt/base/container:stack_array", + # "//sxt/base/error:assert", + # "//sxt/base/num:ceil_log2", + # "//sxt/execution/async:coroutine", + # "//sxt/memory/management:managed_array", + ], + test_deps = [ + ":driver_test", + "//sxt/base/test:unit_test", + ], +) + sxt_cc_component( name = "polynomial_utility", deps = [ diff --git a/sxt/proof/sumcheck2/gpu_driver.cc b/sxt/proof/sumcheck2/gpu_driver.cc new file mode 100644 index 000000000..ab4d321ef --- /dev/null +++ b/sxt/proof/sumcheck2/gpu_driver.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/gpu_driver.h" diff --git a/sxt/proof/sumcheck2/gpu_driver.h b/sxt/proof/sumcheck2/gpu_driver.h new file mode 100644 index 000000000..15d2df639 --- /dev/null +++ b/sxt/proof/sumcheck2/gpu_driver.h @@ -0,0 +1,4 @@ +#pragma once + +namespace sxt::prfsk2 { +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/gpu_driver.t.cc b/sxt/proof/sumcheck2/gpu_driver.t.cc new file mode 100644 index 000000000..0fc2bb4d4 --- /dev/null +++ b/sxt/proof/sumcheck2/gpu_driver.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/gpu_driver.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From d73d9e8cfc3ffa9fbedff37cc71866029832dfb5 Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 18:46:05 -0800 Subject: [PATCH 31/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 11 +++-- sxt/proof/sumcheck2/gpu_driver.h | 79 ++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 99be26916..7e91cbb46 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -61,11 +61,14 @@ sxt_cc_component( deps = [ ":driver", # ":polynomial_utility", - # "//sxt/base/container:stack_array", # "//sxt/base/error:assert", - # "//sxt/base/num:ceil_log2", - # "//sxt/execution/async:coroutine", - # "//sxt/memory/management:managed_array", + "//sxt/base/device:memory_utility", + "//sxt/base/device:stream", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/execution/device:synchronization", + "//sxt/memory/management:managed_array", + "//sxt/memory/resource:device_resource", ], test_deps = [ ":driver_test", diff --git a/sxt/proof/sumcheck2/gpu_driver.h b/sxt/proof/sumcheck2/gpu_driver.h index 15d2df639..f4470e1a0 100644 --- a/sxt/proof/sumcheck2/gpu_driver.h +++ b/sxt/proof/sumcheck2/gpu_driver.h @@ -1,4 +1,83 @@ #pragma once +#include + +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/device/synchronization.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/device_resource.h" +#include "sxt/proof/sumcheck2/driver.h" + namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// gpu_driver +//-------------------------------------------------------------------------------------------------- +template +class gpu_driver final : public driver { +public: + struct gpu_workspace final : public workspace { + memmg::managed_array mles; + memmg::managed_array> product_table; + memmg::managed_array product_terms; + unsigned n; + unsigned num_variables; + + gpu_workspace() noexcept + : mles{memr::get_device_resource()}, product_table{memr::get_device_resource()}, + product_terms{memr::get_device_resource()} {} + }; + + // driver + xena::future> + make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override { + auto ws = std::make_unique(); + + // dimensions + ws->n = n; + ws->num_variables = std::max(basn::ceil_log2(n), 1); + + // mles + ws->mles = memmg::managed_array{ + mles.size(), + memr::get_device_resource(), + }; + basdv::stream mle_stream; + basdv::async_copy_host_to_device(ws->mles, mles, mle_stream); + + // product_table + ws->product_table = memmg::managed_array>{ + product_table.size(), + memr::get_device_resource(), + }; + basdv::stream product_table_stream; + basdv::async_copy_host_to_device(ws->product_table, product_table, product_table_stream); + + // product_terms + ws->product_terms = memmg::managed_array{ + product_terms.size(), + memr::get_device_resource(), + }; + basdv::stream product_terms_stream; + basdv::async_copy_host_to_device(ws->product_terms, product_terms, product_terms_stream); + + // await + co_await xendv::await_stream(mle_stream); + co_await xendv::await_stream(product_table_stream); + co_await xendv::await_stream(product_terms_stream); + co_return ws; + } + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { + return {}; + } + + xena::future<> fold(workspace& ws, const T& r) const noexcept override { + return {}; + } +}; } // namespace sxt::prfsk2 From 8e992f1ae40ce7c37f9dcfc52bfb11d42c7c4134 Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 18:49:28 -0800 Subject: [PATCH 32/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 3 +- sxt/proof/sumcheck2/gpu_driver.h | 47 +++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 7e91cbb46..fe2449c0a 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -61,7 +61,8 @@ sxt_cc_component( deps = [ ":driver", # ":polynomial_utility", - # "//sxt/base/error:assert", + "//sxt/base/error:assert", + "//sxt/algorithm/iteration:for_each", "//sxt/base/device:memory_utility", "//sxt/base/device:stream", "//sxt/base/num:ceil_log2", diff --git a/sxt/proof/sumcheck2/gpu_driver.h b/sxt/proof/sumcheck2/gpu_driver.h index f4470e1a0..6636e34b3 100644 --- a/sxt/proof/sumcheck2/gpu_driver.h +++ b/sxt/proof/sumcheck2/gpu_driver.h @@ -2,8 +2,10 @@ #include +#include "sxt/algorithm/iteration/for_each.h" #include "sxt/base/device/memory_utility.h" #include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" #include "sxt/base/num/ceil_log2.h" #include "sxt/execution/async/coroutine.h" #include "sxt/execution/device/synchronization.h" @@ -77,7 +79,50 @@ class gpu_driver final : public driver { } xena::future<> fold(workspace& ws, const T& r) const noexcept override { - return {}; + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + T one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + + memmg::managed_array mles_p{num_mles * mid, memr::get_device_resource()}; + + auto f = + [ + // clang-format off + mles_p = mles_p.data(), + mles = work.mles.data(), + n = n, + num_mles = num_mles, + r = r, + one_m_r = one_m_r + // clang-format on + ] __device__ + __host__(unsigned mid, unsigned i) noexcept { + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + auto val = mles[i + mle_index * n]; + mul(val, val, one_m_r); + if (mid + i < n) { + muladd(val, r, mles[mid + i + mle_index * n], val); + } + mles_p[i + mle_index * mid] = val; + } + }; + auto fut = algi::for_each(f, mid); + + // complete + co_await std::move(fut); + + work.n = mid; + --work.num_variables; + work.mles = std::move(mles_p); } }; } // namespace sxt::prfsk2 From c21efbe77b8e0e0499d3f70c5c46e4a7059abfda Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 21:31:31 -0800 Subject: [PATCH 33/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 40 ++++++++++++ sxt/proof/sumcheck2/sum_gpu.cc | 1 + sxt/proof/sumcheck2/sum_gpu.h | 108 +++++++++++++++++++++++++++++++ sxt/proof/sumcheck2/sum_gpu.t.cc | 5 ++ 4 files changed, 154 insertions(+) create mode 100644 sxt/proof/sumcheck2/sum_gpu.cc create mode 100644 sxt/proof/sumcheck2/sum_gpu.h create mode 100644 sxt/proof/sumcheck2/sum_gpu.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index fe2449c0a..82857070d 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -134,6 +134,46 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "sum_gpu", + impl_deps = [ + # ":constant", + # ":device_cache", + # ":mle_utility", + # ":polynomial_mapper", + # ":reduction_gpu", + "//sxt/algorithm/reduction:kernel_fit", + "//sxt/algorithm/reduction:thread_reduction", + "//sxt/base/device:memory_utility", + "//sxt/base/device:stream", + "//sxt/base/device:state", + "//sxt/base/field:element", + "//sxt/base/iterator:split", + "//sxt/base/num:ceil_log2", + "//sxt/base/num:constexpr_switch", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + "//sxt/execution/device:for_each", + "//sxt/execution/kernel:kernel_dims", + "//sxt/memory/resource:device_resource", + "//sxt/memory/resource:async_device_resource", + ], + test_deps = [ + # ":device_cache", + "//sxt/base/iterator:split", + "//sxt/base/test:unit_test", + "//sxt/execution/async:future", + "//sxt/execution/schedule:scheduler", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", + ], + deps = [ + "//sxt/base/container:span", + "//sxt/execution/async:future_fwd", + ], +) + sxt_cc_component( name = "verification", # impl_deps = [ diff --git a/sxt/proof/sumcheck2/sum_gpu.cc b/sxt/proof/sumcheck2/sum_gpu.cc new file mode 100644 index 000000000..89c02f6fd --- /dev/null +++ b/sxt/proof/sumcheck2/sum_gpu.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/sum_gpu.h" diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h new file mode 100644 index 000000000..561a83c99 --- /dev/null +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -0,0 +1,108 @@ +#pragma once + +#include + +#include "sxt/base/field/element.h" +#include "sxt/algorithm/reduction/kernel_fit.h" +#include "sxt/algorithm/reduction/thread_reduction.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/state.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/iterator/split.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/base/num/constexpr_switch.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/device/for_each.h" +#include "sxt/execution/kernel/kernel_dims.h" +#include "sxt/execution/kernel/launch.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/async_device_resource.h" +#include "sxt/memory/resource/device_resource.h" +/* #include "sxt/proof/sumcheck/constant.h" */ +/* #include "sxt/proof/sumcheck/device_cache.h" */ +/* #include "sxt/proof/sumcheck/mle_utility.h" */ +/* #include "sxt/proof/sumcheck/polynomial_mapper.h" */ +/* #include "sxt/proof/sumcheck/reduction_gpu.h" */ + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// sum_gpu +//-------------------------------------------------------------------------------------------------- +#if 0 +xena::future<> sum_gpu(basct::span p, device_cache& cache, + const basit::split_options& options, basct::cspan mles, + unsigned n) noexcept { + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + auto num_mles = mles.size() / n; + auto num_coefficients = p.size(); + + // split + auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, options); + + // sum + size_t counter = 0; + co_await xendv::concurrent_for_each( + chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { + basdv::stream stream; + memr::async_device_resource resource{stream}; + + // copy partial mles to device + memmg::managed_array partial_mles{&resource}; + copy_partial_mles(partial_mles, stream, mles, n, rng.a(), rng.b()); + auto split = rng.b() - rng.a(); + auto np = partial_mles.size() / num_mles; + + // lookup problem descriptor + basct::cspan> product_table; + basct::cspan product_terms; + cache.lookup(product_table, product_terms, stream); + + // compute + memmg::managed_array partial_p(num_coefficients); + co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, split, + np); + + // fill in the result + if (counter == 0) { + for (unsigned i = 0; i < num_coefficients; ++i) { + p[i] = partial_p[i]; + } + } else { + for (unsigned i = 0; i < num_coefficients; ++i) { + s25o::add(p[i], p[i], partial_p[i]); + } + } + ++counter; + }); +} + +xena::future<> sum_gpu(basct::span p, device_cache& cache, + basct::cspan mles, unsigned n) noexcept { + basit::split_options options{ + .min_chunk_size = 100'000u, + .max_chunk_size = 200'000u, + .split_factor = basdv::get_num_devices(), + }; + co_await sum_gpu(p, cache, options, mles, n); +} + +xena::future<> sum_gpu(basct::span p, basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) noexcept { + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + SXT_DEBUG_ASSERT( + // clang-format off + basdv::is_host_pointer(p.data()) && + basdv::is_active_device_pointer(mles.data()) && + basdv::is_active_device_pointer(product_table.data()) && + basdv::is_active_device_pointer(product_terms.data()) + // clang-format on + ); + basdv::stream stream; + co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); +} +#endif +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/sum_gpu.t.cc b/sxt/proof/sumcheck2/sum_gpu.t.cc new file mode 100644 index 000000000..a0550f4db --- /dev/null +++ b/sxt/proof/sumcheck2/sum_gpu.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/sum_gpu.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From d20097e2b2702cd84e2cbd2b5d11444da08abd9d Mon Sep 17 00:00:00 2001 From: rnburn Date: Thu, 20 Feb 2025 22:12:25 -0800 Subject: [PATCH 34/83] rework sumcheck --- sxt/proof/sumcheck2/sum_gpu.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 561a83c99..f0fa2ad9e 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -87,9 +87,11 @@ xena::future<> sum_gpu(basct::span p, device_cache& cache, }; co_await sum_gpu(p, cache, options, mles, n); } +#endif -xena::future<> sum_gpu(basct::span p, basct::cspan mles, - basct::cspan> product_table, +template +xena::future<> sum_gpu(basct::span p, basct::cspan mles, + basct::cspan> product_table, basct::cspan product_terms, unsigned n) noexcept { auto num_variables = std::max(basn::ceil_log2(n), 1); auto mid = 1u << (num_variables - 1u); @@ -102,7 +104,7 @@ xena::future<> sum_gpu(basct::span p, basct::cspan // clang-format on ); basdv::stream stream; - co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); + /* co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); */ + return {}; } -#endif } // namespace sxt::prfsk2 From 5990f59e917c4acbe49e65be49c241b262c77997 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 09:15:50 -0800 Subject: [PATCH 35/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 9 +++++++++ sxt/proof/sumcheck2/polynomial_reducer.cc | 1 + sxt/proof/sumcheck2/polynomial_reducer.h | 19 +++++++++++++++++++ sxt/proof/sumcheck2/sum_gpu.h | 9 +++++++++ 4 files changed, 38 insertions(+) create mode 100644 sxt/proof/sumcheck2/polynomial_reducer.cc create mode 100644 sxt/proof/sumcheck2/polynomial_reducer.h diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 82857070d..2e9c30f87 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -77,6 +77,15 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "polynomial_reducer", + deps = [ + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", + ], + with_test = False, +) + sxt_cc_component( name = "polynomial_utility", deps = [ diff --git a/sxt/proof/sumcheck2/polynomial_reducer.cc b/sxt/proof/sumcheck2/polynomial_reducer.cc new file mode 100644 index 000000000..42f9483b2 --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_reducer.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/polynomial_reducer.h" diff --git a/sxt/proof/sumcheck2/polynomial_reducer.h b/sxt/proof/sumcheck2/polynomial_reducer.h new file mode 100644 index 000000000..426be8716 --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_reducer.h @@ -0,0 +1,19 @@ +#pragma once + +#include "sxt/base/macro/cuda_callable.h" +#include "sxt/base/field/element.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// polynomial_reducer +//-------------------------------------------------------------------------------------------------- +template struct polynomial_reducer { + using value_type = std::array; + + CUDA_CALLABLE static void accumulate_inplace(value_type& res, const value_type& e) noexcept { + for (unsigned i = 0; i < res.size(); ++i) { + add(res[i], res[i], e[i]); + } + } +}; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index f0fa2ad9e..d05a9e5fc 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -26,6 +26,15 @@ /* #include "sxt/proof/sumcheck/reduction_gpu.h" */ namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// sum_options +//-------------------------------------------------------------------------------------------------- +struct sum_options { + unsigned min_chunk_size = 100'000u; + unsigned max_chunk_size = 250'000u; + unsigned split_factor = unsigned(basdv::get_num_devices()); +}; + //-------------------------------------------------------------------------------------------------- // sum_gpu //-------------------------------------------------------------------------------------------------- From d704599de71ad98577f76599b679fbacfe146cab Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 09:18:21 -0800 Subject: [PATCH 36/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 9 ++++++++- sxt/proof/sumcheck2/constant.cc | 1 + sxt/proof/sumcheck2/constant.h | 10 ++++++++++ sxt/proof/sumcheck2/sum_gpu.h | 3 ++- 4 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 sxt/proof/sumcheck2/constant.cc create mode 100644 sxt/proof/sumcheck2/constant.h diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 2e9c30f87..4983cac70 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -3,6 +3,13 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "constant", + with_test = False, + deps = [ + ], +) + sxt_cc_component( name = "workspace", with_test = False, @@ -146,7 +153,7 @@ sxt_cc_component( sxt_cc_component( name = "sum_gpu", impl_deps = [ - # ":constant", + ":constant", # ":device_cache", # ":mle_utility", # ":polynomial_mapper", diff --git a/sxt/proof/sumcheck2/constant.cc b/sxt/proof/sumcheck2/constant.cc new file mode 100644 index 000000000..5b59043b7 --- /dev/null +++ b/sxt/proof/sumcheck2/constant.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/constant.h" diff --git a/sxt/proof/sumcheck2/constant.h b/sxt/proof/sumcheck2/constant.h new file mode 100644 index 000000000..53a35ee97 --- /dev/null +++ b/sxt/proof/sumcheck2/constant.h @@ -0,0 +1,10 @@ +#pragma once + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// max_degree_v +//-------------------------------------------------------------------------------------------------- +// the maximum degree of the round polynomial +// used in sumcheck +constexpr unsigned max_degree_v = 5u; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index d05a9e5fc..a53fc4d0e 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -19,7 +19,8 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" -/* #include "sxt/proof/sumcheck/constant.h" */ +#include "sxt/proof/sumcheck2/polynomial_reducer.h" +#include "sxt/proof/sumcheck/constant.h" /* #include "sxt/proof/sumcheck/device_cache.h" */ /* #include "sxt/proof/sumcheck/mle_utility.h" */ /* #include "sxt/proof/sumcheck/polynomial_mapper.h" */ From d79c241a56991b6df6b40c6755b7e00c62245049 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 13:21:26 -0800 Subject: [PATCH 37/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 12 ++++++++ sxt/proof/sumcheck2/polynomial_mapper.cc | 1 + sxt/proof/sumcheck2/polynomial_mapper.h | 35 ++++++++++++++++++++++ sxt/proof/sumcheck2/polynomial_mapper.t.cc | 5 ++++ sxt/proof/sumcheck2/sum_gpu.h | 25 ++++++++++++++++ 5 files changed, 78 insertions(+) create mode 100644 sxt/proof/sumcheck2/polynomial_mapper.cc create mode 100644 sxt/proof/sumcheck2/polynomial_mapper.h create mode 100644 sxt/proof/sumcheck2/polynomial_mapper.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 4983cac70..9b8de1116 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -84,6 +84,18 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "polynomial_mapper", + deps = [ + ":polynomial_utility", + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", + ], + test_deps = [ + "//sxt/base/test:unit_test", + ], +) + sxt_cc_component( name = "polynomial_reducer", deps = [ diff --git a/sxt/proof/sumcheck2/polynomial_mapper.cc b/sxt/proof/sumcheck2/polynomial_mapper.cc new file mode 100644 index 000000000..6884b6798 --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_mapper.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/polynomial_mapper.h" diff --git a/sxt/proof/sumcheck2/polynomial_mapper.h b/sxt/proof/sumcheck2/polynomial_mapper.h new file mode 100644 index 000000000..932506fdc --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_mapper.h @@ -0,0 +1,35 @@ +#pragma once + +#include "sxt/base/field/element.h" +#include "sxt/base/macro/cuda_callable.h" +#include "sxt/proof/sumcheck2/polynomial_utility.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// polynomial_mapper +//-------------------------------------------------------------------------------------------------- +template struct polynomial_mapper { + using value_type = std::array; + + CUDA_CALLABLE + value_type map_index(unsigned index) const noexcept { + value_type res; + this->map_index(res, index); + return res; + } + + CUDA_CALLABLE + void map_index(value_type& p, unsigned index) const noexcept { + if (index + split < n) { + expand_products(p, mles + index, n, split, {product_terms, Degree}); + } else { + partial_expand_products(p, mles + index, n, {product_terms, Degree}); + } + } + + const T* __restrict__ mles; + const unsigned* __restrict__ product_terms; + unsigned split; + unsigned n; +}; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/polynomial_mapper.t.cc b/sxt/proof/sumcheck2/polynomial_mapper.t.cc new file mode 100644 index 000000000..38b38a767 --- /dev/null +++ b/sxt/proof/sumcheck2/polynomial_mapper.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/polynomial_mapper.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index a53fc4d0e..30d8b868d 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -36,6 +36,31 @@ struct sum_options { unsigned split_factor = unsigned(basdv::get_num_devices()); }; +//-------------------------------------------------------------------------------------------------- +// partial_sum_kernel_impl +//-------------------------------------------------------------------------------------------------- +#if 0 +template +__device__ static void partial_sum_kernel_impl(T* __restrict__ shared_data, + const T* __restrict__ mles, + const unsigned* __restrict__ product_terms, + unsigned split, unsigned n) noexcept { + using Mapper = polynomial_mapper; + using Reducer = polynomial_reducer; + using T = Mapper::value_type; + Mapper mapper{ + .mles = mles, + .product_terms = product_terms, + .split = split, + .n = n, + }; + auto index = blockIdx.x * (BlockSize * 2) + threadIdx.x; + auto step = BlockSize * 2 * gridDim.x; + algr::thread_reduce(reinterpret_cast(shared_data), mapper, split, step, + threadIdx.x, index); +} +#endif + //-------------------------------------------------------------------------------------------------- // sum_gpu //-------------------------------------------------------------------------------------------------- From a62975921c07cb56b98228f7f3f8874f752b3d18 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 13:23:58 -0800 Subject: [PATCH 38/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 3 ++- sxt/proof/sumcheck2/sum_gpu.h | 11 +++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 9b8de1116..2b43a688b 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -168,7 +168,8 @@ sxt_cc_component( ":constant", # ":device_cache", # ":mle_utility", - # ":polynomial_mapper", + ":polynomial_mapper", + ":polynomial_reducer", # ":reduction_gpu", "//sxt/algorithm/reduction:kernel_fit", "//sxt/algorithm/reduction:thread_reduction", diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 30d8b868d..44f481338 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -19,6 +19,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" +#include "sxt/proof/sumcheck2/polynomial_mapper.h" #include "sxt/proof/sumcheck2/polynomial_reducer.h" #include "sxt/proof/sumcheck/constant.h" /* #include "sxt/proof/sumcheck/device_cache.h" */ @@ -39,15 +40,14 @@ struct sum_options { //-------------------------------------------------------------------------------------------------- // partial_sum_kernel_impl //-------------------------------------------------------------------------------------------------- -#if 0 template __device__ static void partial_sum_kernel_impl(T* __restrict__ shared_data, const T* __restrict__ mles, const unsigned* __restrict__ product_terms, unsigned split, unsigned n) noexcept { - using Mapper = polynomial_mapper; - using Reducer = polynomial_reducer; - using T = Mapper::value_type; + using Mapper = polynomial_mapper; + using Reducer = polynomial_reducer; + using U = Mapper::value_type; Mapper mapper{ .mles = mles, .product_terms = product_terms, @@ -56,10 +56,9 @@ __device__ static void partial_sum_kernel_impl(T* __restrict__ shared_data, }; auto index = blockIdx.x * (BlockSize * 2) + threadIdx.x; auto step = BlockSize * 2 * gridDim.x; - algr::thread_reduce(reinterpret_cast(shared_data), mapper, split, step, + algr::thread_reduce(reinterpret_cast(shared_data), mapper, split, step, threadIdx.x, index); } -#endif //-------------------------------------------------------------------------------------------------- // sum_gpu From 3c3a56bc2e5aba3bf65ea0c65a5d13c2c01e39b2 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 13:27:31 -0800 Subject: [PATCH 39/83] rework sumcheck --- sxt/proof/sumcheck2/sum_gpu.h | 74 ++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 44f481338..fa5d81462 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -21,7 +21,7 @@ #include "sxt/memory/resource/device_resource.h" #include "sxt/proof/sumcheck2/polynomial_mapper.h" #include "sxt/proof/sumcheck2/polynomial_reducer.h" -#include "sxt/proof/sumcheck/constant.h" +#include "sxt/proof/sumcheck2/constant.h" /* #include "sxt/proof/sumcheck/device_cache.h" */ /* #include "sxt/proof/sumcheck/mle_utility.h" */ /* #include "sxt/proof/sumcheck/polynomial_mapper.h" */ @@ -60,6 +60,78 @@ __device__ static void partial_sum_kernel_impl(T* __restrict__ shared_data, threadIdx.x, index); } +//-------------------------------------------------------------------------------------------------- +// partial_sum_kernel +//-------------------------------------------------------------------------------------------------- +template +__global__ static void +partial_sum_kernel(T* __restrict__ out, const T* __restrict__ mles, + const std::pair* __restrict__ product_table, + const unsigned* __restrict__ product_terms, unsigned num_coefficients, + unsigned split, unsigned n) noexcept { + auto product_index = blockIdx.y; + auto num_terms = product_table[product_index].second; + auto thread_index = threadIdx.x; + auto output_step = gridDim.x * gridDim.y; + + // shared data for reduction + __shared__ T shared_data[2 * BlockSize * (max_degree_v + 1u)]; + + // adjust pointers + out += blockIdx.x; + out += gridDim.x * product_index; + for (unsigned i = 0; i < product_index; ++i) { + product_terms += product_table[i].second; + } + + // sum + basn::constexpr_switch<1, max_degree_v + 1u>( + num_terms, [&](std::integral_constant) noexcept { + partial_sum_kernel_impl(shared_data, mles, product_terms, split, n); + }); + + // write out result + auto mult = product_table[product_index].first; + for (unsigned i = thread_index; i < num_coefficients; i += BlockSize) { + auto output_index = output_step * i; + if (i < num_terms + 1u) { + mul(out[output_index], mult, shared_data[i]); + } else { + out[output_index] = T{}; + } + } +} + +//-------------------------------------------------------------------------------------------------- +// partial_sum +//-------------------------------------------------------------------------------------------------- +#if 0 +template +static xena::future<> partial_sum(basct::span p, basdv::stream& stream, + basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned split, + unsigned n) noexcept { + auto num_coefficients = p.size(); + auto num_products = product_table.size(); + auto dims = algr::fit_reduction_kernel(split); + memr::async_device_resource resource{stream}; + + // partials + memmg::managed_array partials{num_coefficients * dims.num_blocks * num_products, + &resource}; + xenk::launch_kernel(dims.block_size, [&]( + std::integral_constant) noexcept { + partial_sum_kernel<<>>( + partials.data(), mles.data(), product_table.data(), product_terms.data(), num_coefficients, + split, n); + }); + + // reduce partials + co_await reduce_sums(p, stream, partials); +} +#endif + //-------------------------------------------------------------------------------------------------- // sum_gpu //-------------------------------------------------------------------------------------------------- From 8fefe565c4739b3fa84dcefe17fe7db2b82caa48 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 13:28:09 -0800 Subject: [PATCH 40/83] rework sumcheck --- sxt/proof/sumcheck2/sum_gpu.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index fa5d81462..a7b11311f 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -105,7 +105,6 @@ partial_sum_kernel(T* __restrict__ out, const T* __restrict__ mles, //-------------------------------------------------------------------------------------------------- // partial_sum //-------------------------------------------------------------------------------------------------- -#if 0 template static xena::future<> partial_sum(basct::span p, basdv::stream& stream, basct::cspan mles, @@ -130,7 +129,6 @@ static xena::future<> partial_sum(basct::span p, basdv::stream& stream, // reduce partials co_await reduce_sums(p, stream, partials); } -#endif //-------------------------------------------------------------------------------------------------- // sum_gpu From 707b826c9915189d6a4a71e092e3f1ef97299ae5 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 15:09:22 -0800 Subject: [PATCH 41/83] rework sumcheck --- sxt/proof/sumcheck/device_cache.h | 16 ++++++++++ sxt/proof/sumcheck2/BUILD | 21 +++++++++++++ sxt/proof/sumcheck2/device_cache.cc | 1 + sxt/proof/sumcheck2/device_cache.h | 45 +++++++++++++++++++++++++++ sxt/proof/sumcheck2/device_cache.t.cc | 5 +++ 5 files changed, 88 insertions(+) create mode 100644 sxt/proof/sumcheck2/device_cache.cc create mode 100644 sxt/proof/sumcheck2/device_cache.h create mode 100644 sxt/proof/sumcheck2/device_cache.t.cc diff --git a/sxt/proof/sumcheck/device_cache.h b/sxt/proof/sumcheck/device_cache.h index c61038392..5b934d256 100644 --- a/sxt/proof/sumcheck/device_cache.h +++ b/sxt/proof/sumcheck/device_cache.h @@ -29,6 +29,22 @@ class stream; } namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// make_device_copy +//-------------------------------------------------------------------------------------------------- +template +std::unique_ptr> +make_device_copy(basct::cspan> product_table, + basct::cspan product_terms, basdv::stream& stream) noexcept { + device_cache_data res{ + .product_table{product_table.size(), memr::get_device_resource()}, + .product_terms{product_terms.size(), memr::get_device_resource()}, + }; + basdv::async_copy_host_to_device(res.product_table, product_table, stream); + basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); + return std::make_unique(std::move(res)); +} + //-------------------------------------------------------------------------------------------------- // device_cache_data //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 2b43a688b..4f3acb695 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -17,6 +17,27 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "device_cache", + test_deps = [ + "//sxt/base/device:memory_utility", + "//sxt/base/device:stream", + "//sxt/base/device:synchronization", + "//sxt/base/test:unit_test", + "//sxt/scalar25/type:literal", + ], + deps = [ + "//sxt/base/container:span", + "//sxt/base/device:memory_utility", + "//sxt/base/device:state", + "//sxt/base/device:stream", + "//sxt/base/device:device_map", + "//sxt/base/field:element", + "//sxt/memory/resource:device_resource", + "//sxt/memory/management:managed_array", + ], +) + sxt_cc_component( name = "driver", with_test = False, diff --git a/sxt/proof/sumcheck2/device_cache.cc b/sxt/proof/sumcheck2/device_cache.cc new file mode 100644 index 000000000..d5e2569e1 --- /dev/null +++ b/sxt/proof/sumcheck2/device_cache.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/device_cache.h" diff --git a/sxt/proof/sumcheck2/device_cache.h b/sxt/proof/sumcheck2/device_cache.h new file mode 100644 index 000000000..dd7f5a6af --- /dev/null +++ b/sxt/proof/sumcheck2/device_cache.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +#include "sxt/base/container/span.h" +#include "sxt/base/device/device_map.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/state.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/field/element.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/device_resource.h" +#include "sxt/scalar25/type/element.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// device_cache_data +//-------------------------------------------------------------------------------------------------- +template +struct device_cache_data { + memmg::managed_array> product_table; + memmg::managed_array product_terms; +}; + +//-------------------------------------------------------------------------------------------------- +// device_cache +//-------------------------------------------------------------------------------------------------- +template +class device_cache { +public: + device_cache(basct::cspan> product_table, + basct::cspan product_terms) noexcept; + + void lookup(basct::cspan>& product_table, + basct::cspan& product_terms, basdv::stream& stream) noexcept; + + std::unique_ptr> clear() noexcept; + +private: + basct::cspan> product_table_; + basct::cspan product_terms_; + basdv::device_map>> data_; +}; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/device_cache.t.cc b/sxt/proof/sumcheck2/device_cache.t.cc new file mode 100644 index 000000000..e67167748 --- /dev/null +++ b/sxt/proof/sumcheck2/device_cache.t.cc @@ -0,0 +1,5 @@ +#include "sxt/proof/sumcheck2/device_cache.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} From 0a163c0bad28db1154060b573aaa7d8949a86b06 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 15:11:33 -0800 Subject: [PATCH 42/83] rework sumcheck --- sxt/proof/sumcheck/device_cache.h | 16 ---------------- sxt/proof/sumcheck2/device_cache.h | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sxt/proof/sumcheck/device_cache.h b/sxt/proof/sumcheck/device_cache.h index 5b934d256..c61038392 100644 --- a/sxt/proof/sumcheck/device_cache.h +++ b/sxt/proof/sumcheck/device_cache.h @@ -29,22 +29,6 @@ class stream; } namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// make_device_copy -//-------------------------------------------------------------------------------------------------- -template -std::unique_ptr> -make_device_copy(basct::cspan> product_table, - basct::cspan product_terms, basdv::stream& stream) noexcept { - device_cache_data res{ - .product_table{product_table.size(), memr::get_device_resource()}, - .product_terms{product_terms.size(), memr::get_device_resource()}, - }; - basdv::async_copy_host_to_device(res.product_table, product_table, stream); - basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); - return std::make_unique(std::move(res)); -} - //-------------------------------------------------------------------------------------------------- // device_cache_data //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/device_cache.h b/sxt/proof/sumcheck2/device_cache.h index dd7f5a6af..70e9bab1e 100644 --- a/sxt/proof/sumcheck2/device_cache.h +++ b/sxt/proof/sumcheck2/device_cache.h @@ -23,6 +23,22 @@ struct device_cache_data { memmg::managed_array product_terms; }; +//-------------------------------------------------------------------------------------------------- +// make_device_copy +//-------------------------------------------------------------------------------------------------- +template +std::unique_ptr> +make_device_copy(basct::cspan> product_table, + basct::cspan product_terms, basdv::stream& stream) noexcept { + device_cache_data res{ + .product_table{product_table.size(), memr::get_device_resource()}, + .product_terms{product_terms.size(), memr::get_device_resource()}, + }; + basdv::async_copy_host_to_device(res.product_table, product_table, stream); + basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); + return std::make_unique(std::move(res)); +} + //-------------------------------------------------------------------------------------------------- // device_cache //-------------------------------------------------------------------------------------------------- From 8f397019ebf50e92ab001fb216ec4e18f5eeb812 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 15:17:01 -0800 Subject: [PATCH 43/83] rework sumcheck --- sxt/proof/sumcheck2/device_cache.h | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sxt/proof/sumcheck2/device_cache.h b/sxt/proof/sumcheck2/device_cache.h index 70e9bab1e..6f391ba5a 100644 --- a/sxt/proof/sumcheck2/device_cache.h +++ b/sxt/proof/sumcheck2/device_cache.h @@ -46,12 +46,26 @@ template class device_cache { public: device_cache(basct::cspan> product_table, - basct::cspan product_terms) noexcept; + basct::cspan product_terms) noexcept + : product_table_{product_table}, product_terms_{product_terms} {} void lookup(basct::cspan>& product_table, - basct::cspan& product_terms, basdv::stream& stream) noexcept; + basct::cspan& product_terms, basdv::stream& stream) noexcept { + auto& ptr = data_[basdv::get_device()]; + if (ptr == nullptr) { + ptr = make_device_copy(product_table_, product_terms_, stream); + } + product_table = ptr->product_table; + product_terms = ptr->product_terms; + } - std::unique_ptr> clear() noexcept; + std::unique_ptr> clear() noexcept { + auto res{std::move(data_[basdv::get_device()])}; + for (auto& ptr : data_) { + ptr.reset(); + } + return res; + } private: basct::cspan> product_table_; From 2462a49f6f0f867e2337e624feb13b4725d506c3 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 17:02:39 -0800 Subject: [PATCH 44/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 1 + sxt/proof/sumcheck2/device_cache.h | 4 +- sxt/proof/sumcheck2/device_cache.t.cc | 60 ++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 4f3acb695..de46232ac 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -24,6 +24,7 @@ sxt_cc_component( "//sxt/base/device:stream", "//sxt/base/device:synchronization", "//sxt/base/test:unit_test", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ diff --git a/sxt/proof/sumcheck2/device_cache.h b/sxt/proof/sumcheck2/device_cache.h index 6f391ba5a..b5fc54700 100644 --- a/sxt/proof/sumcheck2/device_cache.h +++ b/sxt/proof/sumcheck2/device_cache.h @@ -30,13 +30,13 @@ template std::unique_ptr> make_device_copy(basct::cspan> product_table, basct::cspan product_terms, basdv::stream& stream) noexcept { - device_cache_data res{ + device_cache_data res{ .product_table{product_table.size(), memr::get_device_resource()}, .product_terms{product_terms.size(), memr::get_device_resource()}, }; basdv::async_copy_host_to_device(res.product_table, product_table, stream); basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); - return std::make_unique(std::move(res)); + return std::make_unique>(std::move(res)); } //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/device_cache.t.cc b/sxt/proof/sumcheck2/device_cache.t.cc index e67167748..e2c3df658 100644 --- a/sxt/proof/sumcheck2/device_cache.t.cc +++ b/sxt/proof/sumcheck2/device_cache.t.cc @@ -1,5 +1,63 @@ #include "sxt/proof/sumcheck2/device_cache.h" +#include + +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/device/synchronization.h" #include "sxt/base/test/unit_test.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + + +TEST_CASE("we can cache device values that don't change as a proof is computed") { + using T = s25t::element; + std::vector> product_table; + std::vector product_terms; + + basdv::stream stream; + + basct::cspan> product_table_dev; + basct::cspan product_terms_dev; + + SECTION("we can access values from device memory") { + product_table = {{0x123_s25, 0}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + cache.lookup(product_table_dev, product_terms_dev, stream); + + std::vector> product_table_p(product_table.size()); + basdv::async_copy_device_to_host(product_table_p, product_table_dev, stream); + + std::vector product_terms_p(product_terms.size()); + basdv::async_copy_device_to_host(product_terms_p, product_terms_dev, stream); + + basdv::synchronize_stream(stream); + REQUIRE(product_table_p == product_table); + REQUIRE(product_terms_p == product_terms); + } + + SECTION("we can clear the device cache") { + product_table = {{0x123_s25, 0}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + cache.lookup(product_table_dev, product_terms_dev, stream); + + std::vector> product_table_p(product_table.size()); + basdv::async_copy_device_to_host(product_table_p, product_table_dev, stream); + + std::vector product_terms_p(product_terms.size()); + basdv::async_copy_device_to_host(product_terms_p, product_terms_dev, stream); -TEST_CASE("todo") {} + auto data = cache.clear(); + basdv::async_copy_device_to_host(product_table_p, data->product_table, stream); + basdv::async_copy_device_to_host(product_terms_p, data->product_terms, stream); + basdv::synchronize_stream(stream); + REQUIRE(product_table_p == product_table); + REQUIRE(product_terms_p == product_terms); + } +} From f0548c5fe32599a423345018b2bb04ebc8b9093c Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 17:06:08 -0800 Subject: [PATCH 45/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/sum_gpu.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index de46232ac..00d511c7d 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -188,7 +188,7 @@ sxt_cc_component( name = "sum_gpu", impl_deps = [ ":constant", - # ":device_cache", + ":device_cache", # ":mle_utility", ":polynomial_mapper", ":polynomial_reducer", diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index a7b11311f..9f06e5096 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -22,7 +22,7 @@ #include "sxt/proof/sumcheck2/polynomial_mapper.h" #include "sxt/proof/sumcheck2/polynomial_reducer.h" #include "sxt/proof/sumcheck2/constant.h" -/* #include "sxt/proof/sumcheck/device_cache.h" */ +#include "sxt/proof/sumcheck2/device_cache.h" /* #include "sxt/proof/sumcheck/mle_utility.h" */ /* #include "sxt/proof/sumcheck/polynomial_mapper.h" */ /* #include "sxt/proof/sumcheck/reduction_gpu.h" */ From cc6d3886dae4e461b362581869cebdd29a5b5e49 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 17:12:12 -0800 Subject: [PATCH 46/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 23 ++++++++++++ sxt/proof/sumcheck2/mle_utility.cc | 1 + sxt/proof/sumcheck2/mle_utility.h | 53 ++++++++++++++++++++++++++++ sxt/proof/sumcheck2/mle_utility.t.cc | 6 ++++ 4 files changed, 83 insertions(+) create mode 100644 sxt/proof/sumcheck2/mle_utility.cc create mode 100644 sxt/proof/sumcheck2/mle_utility.h create mode 100644 sxt/proof/sumcheck2/mle_utility.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 00d511c7d..78bf56fa2 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -106,6 +106,29 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "mle_utility", + impl_deps = [ + "//sxt/base/container:span_utility", + "//sxt/base/device:memory_utility", + "//sxt/base/device:property", + "//sxt/base/container:span", + "//sxt/base/field:element", + "//sxt/base/num:divide_up", + "//sxt/base/num:ceil_log2", + "//sxt/memory/management:managed_array", + ], + test_deps = [ + "//sxt/base/device:stream", + "//sxt/base/device:synchronization", + "//sxt/base/test:unit_test", + "//sxt/memory/management:managed_array", + "//sxt/memory/resource:managed_device_resource", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", + ], +) + sxt_cc_component( name = "polynomial_mapper", deps = [ diff --git a/sxt/proof/sumcheck2/mle_utility.cc b/sxt/proof/sumcheck2/mle_utility.cc new file mode 100644 index 000000000..df389d796 --- /dev/null +++ b/sxt/proof/sumcheck2/mle_utility.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/mle_utility.h" diff --git a/sxt/proof/sumcheck2/mle_utility.h b/sxt/proof/sumcheck2/mle_utility.h new file mode 100644 index 000000000..a893a4a04 --- /dev/null +++ b/sxt/proof/sumcheck2/mle_utility.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include "sxt/base/container/span.h" +#include "sxt/base/container/span_utility.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/property.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/base/num/divide_up.h" +#include "sxt/memory/management/managed_array.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// copy_partial_mles +//-------------------------------------------------------------------------------------------------- +template +void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, + basct::cspan mles, unsigned n, unsigned a, + unsigned b) noexcept { + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + auto num_mles = mles.size() / n; + auto part1_size = b - a; + SXT_DEBUG_ASSERT(a < b && b <= n); + auto ap = std::min(mid + a, n); + auto bp = std::min(mid + b, n); + auto part2_size = bp - ap; + + // resize array + auto partial_length = part1_size + part2_size; + partial_mles.resize(partial_length * num_mles); + + // copy data + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + // first part + auto src = mles.subspan(n * mle_index + a, part1_size); + auto dst = basct::subspan(partial_mles, partial_length * mle_index, part1_size); + basdv::async_copy_host_to_device(dst, src, stream); + + // second part + src = mles.subspan(n * mle_index + ap, part2_size); + dst = basct::subspan(partial_mles, partial_length * mle_index + part1_size, part2_size); + if (!src.empty()) { + basdv::async_copy_host_to_device(dst, src, stream); + } + } +} +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/mle_utility.t.cc b/sxt/proof/sumcheck2/mle_utility.t.cc new file mode 100644 index 000000000..d53b74aa8 --- /dev/null +++ b/sxt/proof/sumcheck2/mle_utility.t.cc @@ -0,0 +1,6 @@ +#include "sxt/proof/sumcheck2/mle_utility.h" + +#include "sxt/base/test/unit_test.h" + +TEST_CASE("todo") {} + From 0327047db2bd9d72bd0c42eca521d1970340e30b Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 17:14:34 -0800 Subject: [PATCH 47/83] rework sumcheck --- sxt/proof/sumcheck2/mle_utility.h | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sxt/proof/sumcheck2/mle_utility.h b/sxt/proof/sumcheck2/mle_utility.h index a893a4a04..ccaa550d9 100644 --- a/sxt/proof/sumcheck2/mle_utility.h +++ b/sxt/proof/sumcheck2/mle_utility.h @@ -50,4 +50,37 @@ void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& str } } } + +//-------------------------------------------------------------------------------------------------- +// copy_folded_mles +//-------------------------------------------------------------------------------------------------- +template +void copy_folded_mles(basct::span host_mles, basdv::stream& stream, + basct::cspan device_mles, unsigned np, unsigned a, + unsigned b) noexcept { + auto num_mles = host_mles.size() / np; + auto slice_n = device_mles.size() / num_mles; + auto slice_np = b - a; + SXT_DEBUG_ASSERT( + // clang-format off + host_mles.size() == num_mles * np && + device_mles.size() == num_mles * slice_n && + b <= np + // clang-format on + ); + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + auto src = device_mles.subspan(mle_index * slice_n, slice_np); + auto dst = host_mles.subspan(mle_index * np + a, slice_np); + basdv::async_copy_device_to_host(dst, src, stream); + } +} + +//-------------------------------------------------------------------------------------------------- +// get_gpu_memory_fraction +//-------------------------------------------------------------------------------------------------- +template +double get_gpu_memory_fraction(basct::cspan mles) noexcept { + auto total_memory = static_cast(basdv::get_total_device_memory()); + return static_cast(mles.size() * sizeof(T)) / total_memory; +} } // namespace sxt::prfsk2 From 589043c039cbbcf5326b399916fddea983071bb6 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 17:19:16 -0800 Subject: [PATCH 48/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/mle_utility.t.cc | 76 +++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 78bf56fa2..32921e1d8 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -124,7 +124,7 @@ sxt_cc_component( "//sxt/base/test:unit_test", "//sxt/memory/management:managed_array", "//sxt/memory/resource:managed_device_resource", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], ) diff --git a/sxt/proof/sumcheck2/mle_utility.t.cc b/sxt/proof/sumcheck2/mle_utility.t.cc index d53b74aa8..222d79a1c 100644 --- a/sxt/proof/sumcheck2/mle_utility.t.cc +++ b/sxt/proof/sumcheck2/mle_utility.t.cc @@ -1,6 +1,80 @@ #include "sxt/proof/sumcheck2/mle_utility.h" +#include + +#include "sxt/base/device/stream.h" +#include "sxt/base/device/synchronization.h" #include "sxt/base/test/unit_test.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/managed_device_resource.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +using T = s25t::element; + +TEST_CASE("we can copy a slice of mles to device memory") { + std::pmr::vector mles{memr::get_managed_device_resource()}; + memmg::managed_array partial_mles{memr::get_managed_device_resource()}; + + basdv::stream stream; + + SECTION("we can copy an mle with a single element") { + mles = {0x123_s25}; + copy_partial_mles(partial_mles, stream, mles, 1, 0, 1); + basdv::synchronize_stream(stream); + memmg::managed_array expected = {0x123_s25}; + REQUIRE(partial_mles == expected); + } + + SECTION("we can copy a slice of MLEs") { + mles = {0x1_s25, 0x2_s25, 0x3_s25, 0x4_s25, 0x5_s25, 0x6_s25}; + copy_partial_mles(partial_mles, stream, mles, 3, 0, 1); + basdv::synchronize_stream(stream); + memmg::managed_array expected = {0x1_s25, 0x3_s25, 0x4_s25, 0x6_s25}; + REQUIRE(partial_mles == expected); + } +} + +TEST_CASE("we can copy partially folded MLEs to the host") { + std::pmr::vector device_mles{memr::get_managed_device_resource()}; + std::vector host_mles; + + basdv::stream stream; + + SECTION("we can copy a single element") { + device_mles = {0x123_s25}; + host_mles.resize(1); + copy_folded_mles(host_mles, stream, device_mles, 1, 0, 1); + basdv::synchronize_stream(stream); + std::vector expected = {0x123_s25}; + REQUIRE(host_mles == expected); + } + + SECTION("we can copy partially folded MLEs") { + device_mles = {0x123_s25, 0x456_s25}; + host_mles.resize(4); + copy_folded_mles(host_mles, stream, device_mles, 2, 0, 1); + basdv::synchronize_stream(stream); + std::vector expected = {0x123_s25, 0x0_s25, 0x456_s25, 0x0_s25}; + REQUIRE(host_mles == expected); + } +} + +TEST_CASE("we can query the fraction of device memory taken by MLEs") { + std::vector mles; -TEST_CASE("todo") {} + SECTION("we handle the zero case") { REQUIRE(get_gpu_memory_fraction(mles) == 0.0); } + SECTION("the fractions doubles if the length of mles doubles") { + mles.resize(1); + auto f1 = get_gpu_memory_fraction(mles); + REQUIRE(f1 > 0); + mles.resize(2); + auto f2 = get_gpu_memory_fraction(mles); + REQUIRE(f2 == Catch::Approx(2 * f1)); + } +} From a24c2d9b5760c39eb6525c35a7fba0e342299d1f Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 19:26:55 -0800 Subject: [PATCH 49/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/sum_gpu.h | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 32921e1d8..a7fb41127 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -212,7 +212,7 @@ sxt_cc_component( impl_deps = [ ":constant", ":device_cache", - # ":mle_utility", + ":mle_utility", ":polynomial_mapper", ":polynomial_reducer", # ":reduction_gpu", diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 9f06e5096..3a5f85268 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -23,7 +23,7 @@ #include "sxt/proof/sumcheck2/polynomial_reducer.h" #include "sxt/proof/sumcheck2/constant.h" #include "sxt/proof/sumcheck2/device_cache.h" -/* #include "sxt/proof/sumcheck/mle_utility.h" */ +#include "sxt/proof/sumcheck2/mle_utility.h" /* #include "sxt/proof/sumcheck/polynomial_mapper.h" */ /* #include "sxt/proof/sumcheck/reduction_gpu.h" */ @@ -133,9 +133,9 @@ static xena::future<> partial_sum(basct::span p, basdv::stream& stream, //-------------------------------------------------------------------------------------------------- // sum_gpu //-------------------------------------------------------------------------------------------------- -#if 0 -xena::future<> sum_gpu(basct::span p, device_cache& cache, - const basit::split_options& options, basct::cspan mles, +template +xena::future<> sum_gpu(basct::span p, device_cache& cache, + const basit::split_options& options, basct::cspan mles, unsigned n) noexcept { auto num_variables = std::max(basn::ceil_log2(n), 1); auto mid = 1u << (num_variables - 1u); @@ -153,19 +153,19 @@ xena::future<> sum_gpu(basct::span p, device_cache& cache, memr::async_device_resource resource{stream}; // copy partial mles to device - memmg::managed_array partial_mles{&resource}; - copy_partial_mles(partial_mles, stream, mles, n, rng.a(), rng.b()); + memmg::managed_array partial_mles{&resource}; + copy_partial_mles(partial_mles, stream, mles, n, rng.a(), rng.b()); auto split = rng.b() - rng.a(); auto np = partial_mles.size() / num_mles; // lookup problem descriptor - basct::cspan> product_table; + basct::cspan> product_table; basct::cspan product_terms; cache.lookup(product_table, product_terms, stream); // compute - memmg::managed_array partial_p(num_coefficients); - co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, split, + memmg::managed_array partial_p(num_coefficients); + co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, split, np); // fill in the result @@ -175,13 +175,14 @@ xena::future<> sum_gpu(basct::span p, device_cache& cache, } } else { for (unsigned i = 0; i < num_coefficients; ++i) { - s25o::add(p[i], p[i], partial_p[i]); + add(p[i], p[i], partial_p[i]); } } ++counter; }); } +#if 0 xena::future<> sum_gpu(basct::span p, device_cache& cache, basct::cspan mles, unsigned n) noexcept { basit::split_options options{ From 3c35207e0ae2acf0997c6881cb155f76bb061b61 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 19:28:37 -0800 Subject: [PATCH 50/83] rework sumcheck --- sxt/proof/sumcheck2/sum_gpu.h | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 3a5f85268..510f14892 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -182,17 +182,16 @@ xena::future<> sum_gpu(basct::span p, device_cache& cache, }); } -#if 0 -xena::future<> sum_gpu(basct::span p, device_cache& cache, - basct::cspan mles, unsigned n) noexcept { +template +xena::future<> sum_gpu(basct::span p, device_cache& cache, + basct::cspan mles, unsigned n) noexcept { basit::split_options options{ .min_chunk_size = 100'000u, .max_chunk_size = 200'000u, .split_factor = basdv::get_num_devices(), }; - co_await sum_gpu(p, cache, options, mles, n); + co_await sum_gpu(p, cache, options, mles, n); } -#endif template xena::future<> sum_gpu(basct::span p, basct::cspan mles, @@ -209,7 +208,6 @@ xena::future<> sum_gpu(basct::span p, basct::cspan mles, // clang-format on ); basdv::stream stream; - /* co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); */ - return {}; + co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); } } // namespace sxt::prfsk2 From 80baa5f1527def05ba504070161cce38261fd5a7 Mon Sep 17 00:00:00 2001 From: rnburn Date: Fri, 21 Feb 2025 19:36:44 -0800 Subject: [PATCH 51/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/sum_gpu.h | 1 - sxt/proof/sumcheck2/sum_gpu.t.cc | 123 ++++++++++++++++++++++++++++++- 3 files changed, 123 insertions(+), 3 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index a7fb41127..c1464ef11 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -239,7 +239,7 @@ sxt_cc_component( "//sxt/execution/async:future", "//sxt/execution/schedule:scheduler", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 510f14892..9aeb97213 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -24,7 +24,6 @@ #include "sxt/proof/sumcheck2/constant.h" #include "sxt/proof/sumcheck2/device_cache.h" #include "sxt/proof/sumcheck2/mle_utility.h" -/* #include "sxt/proof/sumcheck/polynomial_mapper.h" */ /* #include "sxt/proof/sumcheck/reduction_gpu.h" */ namespace sxt::prfsk2 { diff --git a/sxt/proof/sumcheck2/sum_gpu.t.cc b/sxt/proof/sumcheck2/sum_gpu.t.cc index a0550f4db..56bfdaf5e 100644 --- a/sxt/proof/sumcheck2/sum_gpu.t.cc +++ b/sxt/proof/sumcheck2/sum_gpu.t.cc @@ -1,5 +1,126 @@ #include "sxt/proof/sumcheck2/sum_gpu.h" +#include + +#include "sxt/base/iterator/split.h" #include "sxt/base/test/unit_test.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/schedule/scheduler.h" +#include "sxt/proof/sumcheck/device_cache.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +#if 0 +TEST_CASE("we can sum MLEs") { + using T = s25t::element; + + std::vector> product_table; + std::vector product_terms; + std::vector mles; + std::vector p(2); + + SECTION("we can sum an MLE with a single term and n=1") { + product_table = {{0x1_s25, 1}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + mles = {0x123_s25}; + auto fut = sum_gpu(p, cache, mles, 1); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == -mles[0]); + } + + SECTION("we can sum an MLE with a single term, n=1, and a non-unity multiplier") { + product_table = {{0x2_s25, 1}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + mles = {0x123_s25}; + auto fut = sum_gpu(p, cache, mles, 1); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == product_table[0].first * mles[0]); + REQUIRE(p[1] == -product_table[0].first * mles[0]); + } + + SECTION("we can sum an MLE with a single term and n=2") { + product_table = {{0x1_s25, 1}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + mles = {0x123_s25, 0x456_s25}; + auto fut = sum_gpu(p, cache, mles, 2); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } + + SECTION("we can sum an MLE with multiple terms and n=1") { + p.resize(3); + product_table = {{0x1_s25, 2}}; + product_terms = {0, 1}; + device_cache cache{product_table, product_terms}; + mles = {0x123_s25, 0x456_s25}; + auto fut = sum_gpu(p, cache, mles, 1); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] * mles[1]); + REQUIRE(p[1] == -mles[0] * mles[1] - mles[1] * mles[0]); + REQUIRE(p[2] == mles[0] * mles[1]); + } + + SECTION("we can sum multiple mles") { + product_table = { + {0x1_s25, 1}, + {0x1_s25, 1}, + }; + product_terms = {0, 1}; + device_cache cache{product_table, product_terms}; + mles = {0x123_s25, 0x456_s25}; + auto fut = sum_gpu(p, cache, mles, 1); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] + mles[1]); + REQUIRE(p[1] == -mles[0] - mles[1]); + } + + SECTION("we can chunk sums with n=4") { + product_table = {{0x1_s25, 1}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x91011_s25}; + basit::split_options options{ + .min_chunk_size = 1, + .max_chunk_size = 1, + .split_factor = 2, + }; + auto fut = sum_gpu(p, cache, options, mles, 4); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] + mles[1]); + REQUIRE(p[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); + } -TEST_CASE("todo") {} + SECTION("we can chunk sums with n=4") { + product_table = {{0x1_s25, 1}}; + product_terms = {0}; + device_cache cache{product_table, product_terms}; + mles = {0x2_s25, 0x4_s25, 0x7_s25, 0x9_s25}; + basit::split_options options{ + .min_chunk_size = 16, + .max_chunk_size = 16, + .split_factor = 2, + }; + auto fut = sum_gpu(p, cache, options, mles, 4); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == mles[0] + mles[1]); + REQUIRE(p[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); + } +} +#endif From 265db5e41e8aa649236033dc0553aa0d2e464246 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:04:32 -0800 Subject: [PATCH 52/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 35 +++++++++++++++++++++++++- sxt/proof/sumcheck2/gpu_driver.h | 12 ++++++++- sxt/proof/sumcheck2/reduction_gpu.cc | 1 + sxt/proof/sumcheck2/reduction_gpu.h | 4 +++ sxt/proof/sumcheck2/reduction_gpu.t.cc | 19 ++++++++++++++ sxt/proof/sumcheck2/sum_gpu.t.cc | 4 +-- 6 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 sxt/proof/sumcheck2/reduction_gpu.cc create mode 100644 sxt/proof/sumcheck2/reduction_gpu.h create mode 100644 sxt/proof/sumcheck2/reduction_gpu.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index c1464ef11..4076e90f6 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -89,7 +89,7 @@ sxt_cc_component( name = "gpu_driver", deps = [ ":driver", - # ":polynomial_utility", + ":sum_gpu", "//sxt/base/error:assert", "//sxt/algorithm/iteration:for_each", "//sxt/base/device:memory_utility", @@ -207,6 +207,39 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "reduction_gpu", + test_deps = [ + "//sxt/base/device:property", + "//sxt/base/device:stream", + "//sxt/base/test:unit_test", + "//sxt/execution/async:future", + "//sxt/execution/schedule:scheduler", + "//sxt/memory/resource:managed_device_resource", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", + ], + deps = [ + "//sxt/algorithm/base:identity_mapper", + "//sxt/algorithm/reduction:kernel_fit", + "//sxt/algorithm/reduction:thread_reduction", + "//sxt/base/container:span", + "//sxt/base/device:memory_utility", + "//sxt/base/device:stream", + "//sxt/base/error:assert", + "//sxt/base/field:element", + "//sxt/execution/async:future_fwd", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + "//sxt/execution/device:synchronization", + "//sxt/execution/kernel:kernel_dims", + "//sxt/execution/kernel:launch", + "//sxt/memory/management:managed_array", + "//sxt/memory/resource:async_device_resource", + ], +) + sxt_cc_component( name = "sum_gpu", impl_deps = [ diff --git a/sxt/proof/sumcheck2/gpu_driver.h b/sxt/proof/sumcheck2/gpu_driver.h index 6636e34b3..e20ce32a9 100644 --- a/sxt/proof/sumcheck2/gpu_driver.h +++ b/sxt/proof/sumcheck2/gpu_driver.h @@ -12,6 +12,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/device_resource.h" #include "sxt/proof/sumcheck2/driver.h" +#include "sxt/proof/sumcheck2/sum_gpu.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- @@ -75,7 +76,16 @@ class gpu_driver final : public driver { } xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { - return {}; + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && + polynomial.size() - 1u <= max_degree_v + // clang-format on + ); + co_await sum_gpu(polynomial, work.mles, work.product_table, work.product_terms, n); } xena::future<> fold(workspace& ws, const T& r) const noexcept override { diff --git a/sxt/proof/sumcheck2/reduction_gpu.cc b/sxt/proof/sumcheck2/reduction_gpu.cc new file mode 100644 index 000000000..ff9abf6af --- /dev/null +++ b/sxt/proof/sumcheck2/reduction_gpu.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/reduction_gpu.h" diff --git a/sxt/proof/sumcheck2/reduction_gpu.h b/sxt/proof/sumcheck2/reduction_gpu.h new file mode 100644 index 000000000..0a8601c23 --- /dev/null +++ b/sxt/proof/sumcheck2/reduction_gpu.h @@ -0,0 +1,4 @@ +#pragma once + +namespace sxt::prfsk2 { +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/reduction_gpu.t.cc b/sxt/proof/sumcheck2/reduction_gpu.t.cc new file mode 100644 index 000000000..0a087425f --- /dev/null +++ b/sxt/proof/sumcheck2/reduction_gpu.t.cc @@ -0,0 +1,19 @@ +#include "sxt/proof/sumcheck2/reduction_gpu.h" + +#include + +#include "sxt/base/device/stream.h" +#include "sxt/base/test/unit_test.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/schedule/scheduler.h" +#include "sxt/memory/resource/managed_device_resource.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +TEST_CASE("we can reduce sumcheck polynomials") { +} diff --git a/sxt/proof/sumcheck2/sum_gpu.t.cc b/sxt/proof/sumcheck2/sum_gpu.t.cc index 56bfdaf5e..1d81371b6 100644 --- a/sxt/proof/sumcheck2/sum_gpu.t.cc +++ b/sxt/proof/sumcheck2/sum_gpu.t.cc @@ -15,7 +15,6 @@ using namespace sxt; using namespace sxt::prfsk2; using s25t::operator""_s25; -#if 0 TEST_CASE("we can sum MLEs") { using T = s25t::element; @@ -24,6 +23,7 @@ TEST_CASE("we can sum MLEs") { std::vector mles; std::vector p(2); +#if 0 SECTION("we can sum an MLE with a single term and n=1") { product_table = {{0x1_s25, 1}}; product_terms = {0}; @@ -122,5 +122,5 @@ TEST_CASE("we can sum MLEs") { REQUIRE(p[0] == mles[0] + mles[1]); REQUIRE(p[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); } -} #endif +} From ef1d2e77e5e997cf0f5b1ad6028d4804a7ef07b1 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:22:06 -0800 Subject: [PATCH 53/83] add field accumulator --- sxt/base/field/BUILD | 9 +++ sxt/base/field/accumulator.cc | 1 + sxt/base/field/accumulator.h | 18 ++++++ sxt/proof/sumcheck2/reduction_gpu.h | 94 +++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+) create mode 100644 sxt/base/field/accumulator.cc create mode 100644 sxt/base/field/accumulator.h diff --git a/sxt/base/field/BUILD b/sxt/base/field/BUILD index b68abf0bf..7a8f0c8ad 100644 --- a/sxt/base/field/BUILD +++ b/sxt/base/field/BUILD @@ -19,3 +19,12 @@ sxt_cc_component( name = "element", with_test = False, ) + +sxt_cc_component( + name = "accumulator", + with_test = False, + deps = [ + ":element", + "//sxt/base/macro:cuda_callable", + ] +) diff --git a/sxt/base/field/accumulator.cc b/sxt/base/field/accumulator.cc new file mode 100644 index 000000000..5ab79a772 --- /dev/null +++ b/sxt/base/field/accumulator.cc @@ -0,0 +1 @@ +#include "sxt/base/field/accumulator.h" diff --git a/sxt/base/field/accumulator.h b/sxt/base/field/accumulator.h new file mode 100644 index 000000000..53a14cb1a --- /dev/null +++ b/sxt/base/field/accumulator.h @@ -0,0 +1,18 @@ +#pragma once + +#include "sxt/base/field/element.h" +#include "sxt/base/macro/cuda_callable.h" + +namespace sxt::basfld { +//-------------------------------------------------------------------------------------------------- +// accumulator +//-------------------------------------------------------------------------------------------------- +template +struct accumulator { + using value_type = T; + + CUDA_CALLABLE static void accumulate_inplace(T& res, T& e) noexcept { + add(res, res, e); + } +}; +} // namespace sxt::basfld diff --git a/sxt/proof/sumcheck2/reduction_gpu.h b/sxt/proof/sumcheck2/reduction_gpu.h index 0a8601c23..a423519f7 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.h +++ b/sxt/proof/sumcheck2/reduction_gpu.h @@ -1,4 +1,98 @@ #pragma once +#include "sxt/algorithm/base/identity_mapper.h" +#include "sxt/algorithm/reduction/kernel_fit.h" +#include "sxt/algorithm/reduction/thread_reduction.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/device/synchronization.h" +#include "sxt/execution/kernel/kernel_dims.h" +#include "sxt/execution/kernel/launch.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/async_device_resource.h" + namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// reduction_kernel +//-------------------------------------------------------------------------------------------------- +template +__global__ static void reduction_kernel(T* __restrict__ out, + const T* __restrict__ partials, + unsigned n) noexcept { + auto thread_index = threadIdx.x; + auto block_index = blockIdx.x; + auto coefficient_index = blockIdx.y; + auto index = block_index * (BlockSize * 2) + thread_index; + auto step = BlockSize * 2 * gridDim.x; + __shared__ T shared_data[2 * BlockSize]; + + // coefficient adjustment + out += coefficient_index; + partials += coefficient_index * n; + + // mapper + algb::identity_mapper mapper{partials}; + + // reduce + /* algr::thread_reduce(out + block_index, shared_data, mapper, n, step, */ + /* thread_index, index); */ +} + +#if 0 +//-------------------------------------------------------------------------------------------------- +// reduce_sums +//-------------------------------------------------------------------------------------------------- +xena::future<> reduce_sums(basct::span p, basdv::stream& stream, + basct::cspan partial_terms) noexcept { + auto num_coefficients = p.size(); + auto n = partial_terms.size() / num_coefficients; + SXT_DEBUG_ASSERT( + // clang-format off + n > 0 && + partial_terms.size() == num_coefficients * n && + basdv::is_host_pointer(p.data()) && + basdv::is_active_device_pointer(partial_terms.data()) + // clang-format on + ); + auto dims = algr::fit_reduction_kernel(n); + + // p_dev + memr::async_device_resource resource{stream}; + memmg::managed_array p_dev{num_coefficients * dims.num_blocks, &resource}; + + // launch kernel + xenk::launch_kernel(dims.block_size, [&]( + std::integral_constant) noexcept { + reduction_kernel + <<>>( + p_dev.data(), partial_terms.data(), n); + }); + + // copy polynomial to host + memmg::managed_array p_host_data; + basct::span p_host = p; + if (dims.num_blocks > 1) { + p_host_data.resize(p_dev.size()); + p_host = p_host_data; + } + basdv::async_copy_device_to_host(p_host, p_dev, stream); + co_await xendv::await_stream(stream); + + // complete reduction on host if necessary + if (dims.num_blocks == 1) { + co_return; + } + for (unsigned coefficient_index = 0; coefficient_index < num_coefficients; ++coefficient_index) { + p[coefficient_index] = p_host[coefficient_index * dims.num_blocks]; + for (unsigned block_index = 1; block_index < dims.num_blocks; ++block_index) { + s25o::add(p[coefficient_index], p[coefficient_index], + p_host[coefficient_index * dims.num_blocks + block_index]); + } + } +} +#endif } // namespace sxt::prfsk2 From 43cdbae73ea19bfed7b2255058f604293707d9c4 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:23:40 -0800 Subject: [PATCH 54/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 1 + sxt/proof/sumcheck2/reduction_gpu.h | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 4076e90f6..903d69bf7 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -228,6 +228,7 @@ sxt_cc_component( "//sxt/base/device:memory_utility", "//sxt/base/device:stream", "//sxt/base/error:assert", + "//sxt/base/field:accumulator", "//sxt/base/field:element", "//sxt/execution/async:future_fwd", "//sxt/execution/async:coroutine", diff --git a/sxt/proof/sumcheck2/reduction_gpu.h b/sxt/proof/sumcheck2/reduction_gpu.h index a423519f7..7a53effa3 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.h +++ b/sxt/proof/sumcheck2/reduction_gpu.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include "sxt/algorithm/base/identity_mapper.h" #include "sxt/algorithm/reduction/kernel_fit.h" @@ -6,6 +6,7 @@ #include "sxt/base/device/memory_utility.h" #include "sxt/base/device/stream.h" #include "sxt/base/error/assert.h" +#include "sxt/base/field/accumulator.h" #include "sxt/base/field/element.h" #include "sxt/execution/async/coroutine.h" #include "sxt/execution/async/future.h" @@ -38,8 +39,8 @@ __global__ static void reduction_kernel(T* __restrict__ out, algb::identity_mapper mapper{partials}; // reduce - /* algr::thread_reduce(out + block_index, shared_data, mapper, n, step, */ - /* thread_index, index); */ + algr::thread_reduce, BlockSize>(out + block_index, shared_data, mapper, n, + step, thread_index, index); } #if 0 From ca68de6767a5ebb46298a3319f4b5ce84ba27eac Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:25:19 -0800 Subject: [PATCH 55/83] rework sumcheck --- sxt/proof/sumcheck2/reduction_gpu.h | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sxt/proof/sumcheck2/reduction_gpu.h b/sxt/proof/sumcheck2/reduction_gpu.h index 7a53effa3..3624e5bee 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.h +++ b/sxt/proof/sumcheck2/reduction_gpu.h @@ -43,12 +43,12 @@ __global__ static void reduction_kernel(T* __restrict__ out, step, thread_index, index); } -#if 0 //-------------------------------------------------------------------------------------------------- // reduce_sums //-------------------------------------------------------------------------------------------------- -xena::future<> reduce_sums(basct::span p, basdv::stream& stream, - basct::cspan partial_terms) noexcept { +template +xena::future<> reduce_sums(basct::span p, basdv::stream& stream, + basct::cspan partial_terms) noexcept { auto num_coefficients = p.size(); auto n = partial_terms.size() / num_coefficients; SXT_DEBUG_ASSERT( @@ -63,7 +63,7 @@ xena::future<> reduce_sums(basct::span p, basdv::stream& stream, // p_dev memr::async_device_resource resource{stream}; - memmg::managed_array p_dev{num_coefficients * dims.num_blocks, &resource}; + memmg::managed_array p_dev{num_coefficients * dims.num_blocks, &resource}; // launch kernel xenk::launch_kernel(dims.block_size, [&]( @@ -74,8 +74,8 @@ xena::future<> reduce_sums(basct::span p, basdv::stream& stream, }); // copy polynomial to host - memmg::managed_array p_host_data; - basct::span p_host = p; + memmg::managed_array p_host_data; + basct::span p_host = p; if (dims.num_blocks > 1) { p_host_data.resize(p_dev.size()); p_host = p_host_data; @@ -90,10 +90,9 @@ xena::future<> reduce_sums(basct::span p, basdv::stream& stream, for (unsigned coefficient_index = 0; coefficient_index < num_coefficients; ++coefficient_index) { p[coefficient_index] = p_host[coefficient_index * dims.num_blocks]; for (unsigned block_index = 1; block_index < dims.num_blocks; ++block_index) { - s25o::add(p[coefficient_index], p[coefficient_index], - p_host[coefficient_index * dims.num_blocks + block_index]); + add(p[coefficient_index], p[coefficient_index], + p_host[coefficient_index * dims.num_blocks + block_index]); } } } -#endif } // namespace sxt::prfsk2 From c6ebaf835f7ee1c64d5993c1c66c7241eef5db10 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:26:55 -0800 Subject: [PATCH 56/83] rework sumcheck --- sxt/proof/sumcheck2/reduction_gpu.t.cc | 35 +++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/sxt/proof/sumcheck2/reduction_gpu.t.cc b/sxt/proof/sumcheck2/reduction_gpu.t.cc index 0a087425f..6fd017a1f 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.t.cc +++ b/sxt/proof/sumcheck2/reduction_gpu.t.cc @@ -8,7 +8,7 @@ #include "sxt/execution/schedule/scheduler.h" #include "sxt/memory/resource/managed_device_resource.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -16,4 +16,37 @@ using namespace sxt::prfsk2; using s25t::operator""_s25; TEST_CASE("we can reduce sumcheck polynomials") { + using T = s25t::element; + std::vector p; + std::pmr::vector partial_terms{memr::get_managed_device_resource()}; + + basdv::stream stream; + + SECTION("we can reduce a sum with a single term") { + p.resize(1); + partial_terms = {0x123_s25}; + auto fut = reduce_sums(p, stream, partial_terms); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == 0x123_s25); + } + + SECTION("we can reduce two terms") { + p.resize(1); + partial_terms = {0x123_s25, 0x456_s25}; + auto fut = reduce_sums(p, stream, partial_terms); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == 0x123_s25 + 0x456_s25); + } + + SECTION("we can reduce multiple coefficients") { + p.resize(2); + partial_terms = {0x123_s25, 0x456_s25}; + auto fut = reduce_sums(p, stream, partial_terms); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(p[0] == 0x123_s25); + REQUIRE(p[1] == 0x456_s25); + } } From 2f0a9527b8832d21df24422400cfb711cc513043 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:50:19 -0800 Subject: [PATCH 57/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/sum_gpu.h | 10 +++++----- sxt/proof/sumcheck2/sum_gpu.t.cc | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 903d69bf7..48ad9447e 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -249,7 +249,7 @@ sxt_cc_component( ":mle_utility", ":polynomial_mapper", ":polynomial_reducer", - # ":reduction_gpu", + ":reduction_gpu", "//sxt/algorithm/reduction:kernel_fit", "//sxt/algorithm/reduction:thread_reduction", "//sxt/base/device:memory_utility", diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 9aeb97213..80e4bcc59 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -2,12 +2,12 @@ #include -#include "sxt/base/field/element.h" #include "sxt/algorithm/reduction/kernel_fit.h" #include "sxt/algorithm/reduction/thread_reduction.h" #include "sxt/base/device/memory_utility.h" #include "sxt/base/device/state.h" #include "sxt/base/device/stream.h" +#include "sxt/base/field/element.h" #include "sxt/base/iterator/split.h" #include "sxt/base/num/ceil_log2.h" #include "sxt/base/num/constexpr_switch.h" @@ -19,12 +19,12 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck2/polynomial_mapper.h" -#include "sxt/proof/sumcheck2/polynomial_reducer.h" #include "sxt/proof/sumcheck2/constant.h" #include "sxt/proof/sumcheck2/device_cache.h" #include "sxt/proof/sumcheck2/mle_utility.h" -/* #include "sxt/proof/sumcheck/reduction_gpu.h" */ +#include "sxt/proof/sumcheck2/polynomial_mapper.h" +#include "sxt/proof/sumcheck2/polynomial_reducer.h" +#include "sxt/proof/sumcheck2/reduction_gpu.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- @@ -126,7 +126,7 @@ static xena::future<> partial_sum(basct::span p, basdv::stream& stream, }); // reduce partials - co_await reduce_sums(p, stream, partials); + co_await reduce_sums(p, stream, partials); } //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/sum_gpu.t.cc b/sxt/proof/sumcheck2/sum_gpu.t.cc index 1d81371b6..67740bc26 100644 --- a/sxt/proof/sumcheck2/sum_gpu.t.cc +++ b/sxt/proof/sumcheck2/sum_gpu.t.cc @@ -23,7 +23,6 @@ TEST_CASE("we can sum MLEs") { std::vector mles; std::vector p(2); -#if 0 SECTION("we can sum an MLE with a single term and n=1") { product_table = {{0x1_s25, 1}}; product_terms = {0}; @@ -122,5 +121,4 @@ TEST_CASE("we can sum MLEs") { REQUIRE(p[0] == mles[0] + mles[1]); REQUIRE(p[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); } -#endif } From f09aab1f2f250ec96f78bce4ae7e0b56cd5b8b55 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 09:51:39 -0800 Subject: [PATCH 58/83] rework sumcheck --- sxt/proof/sumcheck2/gpu_driver.t.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sxt/proof/sumcheck2/gpu_driver.t.cc b/sxt/proof/sumcheck2/gpu_driver.t.cc index 0fc2bb4d4..b6168a7a7 100644 --- a/sxt/proof/sumcheck2/gpu_driver.t.cc +++ b/sxt/proof/sumcheck2/gpu_driver.t.cc @@ -1,5 +1,12 @@ #include "sxt/proof/sumcheck2/gpu_driver.h" +#include "sxt/proof/sumcheck2/driver_test.h" #include "sxt/base/test/unit_test.h" -TEST_CASE("todo") {} +using namespace sxt; +using namespace sxt::prfsk2; + +TEST_CASE("we can perform the primitive operations for sumcheck proofs") { + gpu_driver drv; + exercise_driver(drv); +} From e38643d02261efd3e94a52eebe7b27cee7d3830b Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 10:22:05 -0800 Subject: [PATCH 59/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 37 +++++++++++++++++ sxt/proof/sumcheck2/fold_gpu.cc | 1 + sxt/proof/sumcheck2/fold_gpu.h | 69 +++++++++++++++++++++++++++++++ sxt/proof/sumcheck2/fold_gpu.t.cc | 67 ++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+) create mode 100644 sxt/proof/sumcheck2/fold_gpu.cc create mode 100644 sxt/proof/sumcheck2/fold_gpu.h create mode 100644 sxt/proof/sumcheck2/fold_gpu.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 48ad9447e..5041268ae 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -85,6 +85,43 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "fold_gpu", + test_deps = [ + "//sxt/base/iterator:split", + "//sxt/base/test:unit_test", + "//sxt/execution/async:future", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", + ], + deps = [ + ":mle_utility", + "//sxt/algorithm/iteration:kernel_fit", + "//sxt/base/container:span", + "//sxt/base/error:assert", + "//sxt/base/device:property", + "//sxt/base/device:memory_utility", + "//sxt/base/device:stream", + "//sxt/base/field:element", + "//sxt/base/iterator:split", + "//sxt/base/num:ceil_log2", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", + "//sxt/scalar25/operation:mul", + "//sxt/scalar25/operation:muladd", + "//sxt/scalar25/operation:sub", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + "//sxt/execution/device:for_each", + "//sxt/execution/device:synchronization", + "//sxt/execution/kernel:kernel_dims", + "//sxt/memory/management:managed_array", + "//sxt/memory/resource:async_device_resource", + "//sxt/memory/resource:device_resource", + ], +) + sxt_cc_component( name = "gpu_driver", deps = [ diff --git a/sxt/proof/sumcheck2/fold_gpu.cc b/sxt/proof/sumcheck2/fold_gpu.cc new file mode 100644 index 000000000..5b542af66 --- /dev/null +++ b/sxt/proof/sumcheck2/fold_gpu.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/fold_gpu.h" diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck2/fold_gpu.h new file mode 100644 index 000000000..2e98c986f --- /dev/null +++ b/sxt/proof/sumcheck2/fold_gpu.h @@ -0,0 +1,69 @@ +#pragma once + +#include + +#include "sxt/algorithm/iteration/kernel_fit.h" +#include "sxt/base/container/span.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/property.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" +#include "sxt/base/iterator/split.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/device/for_each.h" +#include "sxt/execution/device/synchronization.h" +#include "sxt/execution/kernel/kernel_dims.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/async_device_resource.h" +#include "sxt/memory/resource/device_resource.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/scalar25/operation/mul.h" +#include "sxt/scalar25/operation/muladd.h" +#include "sxt/scalar25/operation/sub.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// fold_kernel +//-------------------------------------------------------------------------------------------------- +template +__global__ void fold_kernel(T* __restrict__ mles, unsigned np, unsigned split, T r, + T one_m_r) noexcept { + auto thread_index = threadIdx.x; + auto block_index = blockIdx.x; + auto block_size = blockDim.x; + auto k = basn::divide_up(split, gridDim.x * block_size) * block_size; + auto block_first = block_index * k; + assert(block_first < split && "every block should be active"); + auto m = umin(block_first + k, split); + + // adjust mles + mles += np * blockIdx.y; + + // fold + auto index = block_first + thread_index; + for (; index < m; index += block_size) { + auto x = mles[index]; + mul(x, x, one_m_r); + auto index_p = split + index; + if (index_p < np) { + muladd(x, mles[index_p], r, x); + } + mles[index] = x; + } +} + +//-------------------------------------------------------------------------------------------------- +// fold_gpu +//-------------------------------------------------------------------------------------------------- +/* xena::future<> fold_gpu(basct::span mles_p, */ +/* const basit::split_options& split_options, basct::cspan mles, */ +/* unsigned n, const s25t::element& r) noexcept; */ +/* */ +/* xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, */ +/* unsigned n, const s25t::element& r) noexcept; */ +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/fold_gpu.t.cc b/sxt/proof/sumcheck2/fold_gpu.t.cc new file mode 100644 index 000000000..bd515aab5 --- /dev/null +++ b/sxt/proof/sumcheck2/fold_gpu.t.cc @@ -0,0 +1,67 @@ +#include "sxt/proof/sumcheck2/fold_gpu.h" + +#include + +#include "sxt/base/iterator/split.h" +#include "sxt/base/test/unit_test.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/schedule/scheduler.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +TEST_CASE("we can fold scalars using the gpu") { +#if 0 + std::vector mles, mles_p, expected; + + auto r = 0xabc123_s25; + auto one_m_r = 0x1_s25 - r; + + SECTION("we can fold a single mle with n=2") { + mles = {0x1_s25, 0x2_s25}; + mles_p.resize(1); + auto fut = fold_gpu(mles_p, mles, 2, r); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + expected = { + one_m_r * mles[0] + r * mles[1], + }; + REQUIRE(mles_p == expected); + } + + SECTION("we can fold a single mle with n=3") { + mles = {0x123_s25, 0x456_s25, 0x789_s25}; + mles_p.resize(2); + auto fut = fold_gpu(mles_p, mles, 3, r); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + expected = { + one_m_r * mles[0] + r * mles[2], + one_m_r * mles[1], + }; + REQUIRE(mles_p == expected); + } + + SECTION("we can split folds") { + basit::split_options split_options{ + .min_chunk_size = 1, + .max_chunk_size = 1, + .split_factor = 2, + }; + mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x101112_s25}; + mles_p.resize(2); + auto fut = fold_gpu(mles_p, split_options, mles, 4, r); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + expected = { + one_m_r * mles[0] + r * mles[2], + one_m_r * mles[1] + r * mles[3], + }; + REQUIRE(mles_p == expected); + } +#endif +} From 826682a520d25d6a272a17b9c2ec9d2b0ab1dc67 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 10:47:02 -0800 Subject: [PATCH 60/83] rework sumcheck --- sxt/proof/sumcheck2/fold_gpu.h | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck2/fold_gpu.h index 2e98c986f..932a57331 100644 --- a/sxt/proof/sumcheck2/fold_gpu.h +++ b/sxt/proof/sumcheck2/fold_gpu.h @@ -57,6 +57,33 @@ __global__ void fold_kernel(T* __restrict__ mles, unsigned np, unsigned split, T } } +//-------------------------------------------------------------------------------------------------- +// fold_impl +//-------------------------------------------------------------------------------------------------- +template +xena::future<> fold_impl(basct::span mles_p, basct::cspan mles, unsigned n, unsigned mid, + unsigned a, unsigned b, const T& r, const T& one_m_r) noexcept { + auto num_mles = mles.size() / n; + auto split = b - a; + + // copy MLEs to device + basdv::stream stream; + memr::async_device_resource resource{stream}; + memmg::managed_array mles_dev{&resource}; + copy_partial_mles(mles_dev, stream, mles, n, a, b); + + // fold + auto np = mles_dev.size() / num_mles; + auto dims = algi::fit_iteration_kernel(split); + fold_kernel<<(dims.block_size), 0, + stream>>>(mles_dev.data(), np, split, r, one_m_r); + + // copy results back + copy_folded_mles(mles_p, stream, mles_dev, mid, a, b); + + co_await xendv::await_stream(stream); +} + //-------------------------------------------------------------------------------------------------- // fold_gpu //-------------------------------------------------------------------------------------------------- From 5ca3e73f46004ce52b70fdb959d47b2d4118c1eb Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 10:57:19 -0800 Subject: [PATCH 61/83] rework sumcheck --- sxt/proof/sumcheck2/fold_gpu.h | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck2/fold_gpu.h index 932a57331..fcaf7b7c1 100644 --- a/sxt/proof/sumcheck2/fold_gpu.h +++ b/sxt/proof/sumcheck2/fold_gpu.h @@ -84,6 +84,34 @@ xena::future<> fold_impl(basct::span mles_p, basct::cspan mles, unsigned n co_await xendv::await_stream(stream); } +//-------------------------------------------------------------------------------------------------- +// fold_gpu +//-------------------------------------------------------------------------------------------------- +template +xena::future<> fold_gpu(basct::span mles_p, + const basit::split_options& split_options, basct::cspan mles, + unsigned n, const T& r) noexcept { + auto num_mles = mles.size() / n; + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + SXT_DEBUG_ASSERT( + // clang-format off + n > 1 && mles.size() == num_mles * n + // clang-format on + ); + auto one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + + // split + auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, split_options); + + // fold + co_await xendv::concurrent_for_each( + chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { + co_await fold_impl(mles_p, mles, n, mid, rng.a(), rng.b(), r, one_m_r); + }); +} + //-------------------------------------------------------------------------------------------------- // fold_gpu //-------------------------------------------------------------------------------------------------- From 771235859929c0e1e8e4bd8e1deb24ccad56acbc Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 11:00:00 -0800 Subject: [PATCH 62/83] rework sumcheck --- sxt/proof/sumcheck2/fold_gpu.h | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck2/fold_gpu.h index fcaf7b7c1..aef75009b 100644 --- a/sxt/proof/sumcheck2/fold_gpu.h +++ b/sxt/proof/sumcheck2/fold_gpu.h @@ -112,13 +112,14 @@ xena::future<> fold_gpu(basct::span mles_p, }); } -//-------------------------------------------------------------------------------------------------- -// fold_gpu -//-------------------------------------------------------------------------------------------------- -/* xena::future<> fold_gpu(basct::span mles_p, */ -/* const basit::split_options& split_options, basct::cspan mles, */ -/* unsigned n, const s25t::element& r) noexcept; */ -/* */ -/* xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, */ -/* unsigned n, const s25t::element& r) noexcept; */ +template +xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, + unsigned n, const T& r) noexcept { + basit::split_options split_options{ + .min_chunk_size = 1024u * 128u, + .max_chunk_size = 1024u * 256u, + .split_factor = basdv::get_num_devices(), + }; + co_await fold_gpu(mles_p, split_options, mles, n, r); +} } // namespace sxt::prfsk2 From 7cf50bc245fa1452c66e81e8dfc67d334ef5c1da Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 11:32:15 -0800 Subject: [PATCH 63/83] rework sumcheck --- sxt/proof/sumcheck2/fold_gpu.h | 2 +- sxt/proof/sumcheck2/fold_gpu.t.cc | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck2/fold_gpu.h index aef75009b..757b7dba2 100644 --- a/sxt/proof/sumcheck2/fold_gpu.h +++ b/sxt/proof/sumcheck2/fold_gpu.h @@ -19,7 +19,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/proof/sumcheck2/mle_utility.h" #include "sxt/scalar25/operation/mul.h" #include "sxt/scalar25/operation/muladd.h" #include "sxt/scalar25/operation/sub.h" diff --git a/sxt/proof/sumcheck2/fold_gpu.t.cc b/sxt/proof/sumcheck2/fold_gpu.t.cc index bd515aab5..ecda894be 100644 --- a/sxt/proof/sumcheck2/fold_gpu.t.cc +++ b/sxt/proof/sumcheck2/fold_gpu.t.cc @@ -15,7 +15,7 @@ using namespace sxt::prfsk2; using s25t::operator""_s25; TEST_CASE("we can fold scalars using the gpu") { -#if 0 + using T = s25t::element; std::vector mles, mles_p, expected; auto r = 0xabc123_s25; @@ -24,7 +24,7 @@ TEST_CASE("we can fold scalars using the gpu") { SECTION("we can fold a single mle with n=2") { mles = {0x1_s25, 0x2_s25}; mles_p.resize(1); - auto fut = fold_gpu(mles_p, mles, 2, r); + auto fut = fold_gpu(mles_p, mles, 2, r); xens::get_scheduler().run(); REQUIRE(fut.ready()); expected = { @@ -36,7 +36,7 @@ TEST_CASE("we can fold scalars using the gpu") { SECTION("we can fold a single mle with n=3") { mles = {0x123_s25, 0x456_s25, 0x789_s25}; mles_p.resize(2); - auto fut = fold_gpu(mles_p, mles, 3, r); + auto fut = fold_gpu(mles_p, mles, 3, r); xens::get_scheduler().run(); REQUIRE(fut.ready()); expected = { @@ -54,7 +54,7 @@ TEST_CASE("we can fold scalars using the gpu") { }; mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x101112_s25}; mles_p.resize(2); - auto fut = fold_gpu(mles_p, split_options, mles, 4, r); + auto fut = fold_gpu(mles_p, split_options, mles, 4, r); xens::get_scheduler().run(); REQUIRE(fut.ready()); expected = { @@ -63,5 +63,4 @@ TEST_CASE("we can fold scalars using the gpu") { }; REQUIRE(mles_p == expected); } -#endif } From 6f0b01411efa11dfaac8b44285edf38fb2d8725d Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 13:48:48 -0800 Subject: [PATCH 64/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 22 ++++++++++ sxt/proof/sumcheck2/chunked_gpu_driver.cc | 1 + sxt/proof/sumcheck2/chunked_gpu_driver.h | 47 +++++++++++++++++++++ sxt/proof/sumcheck2/chunked_gpu_driver.t.cc | 2 + 4 files changed, 72 insertions(+) create mode 100644 sxt/proof/sumcheck2/chunked_gpu_driver.cc create mode 100644 sxt/proof/sumcheck2/chunked_gpu_driver.h create mode 100644 sxt/proof/sumcheck2/chunked_gpu_driver.t.cc diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 5041268ae..925dd1eb6 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -143,6 +143,28 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "chunked_gpu_driver", + test_deps = [ + ":driver_test", + "//sxt/base/test:unit_test", + ], + deps = [ + ":driver", + ":device_cache", + ":fold_gpu", + ":gpu_driver", + ":mle_utility", + ":sum_gpu", + "//sxt/algorithm/iteration:transform", + "//sxt/base/error:assert", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + "//sxt/memory/management:managed_array", + ], +) + sxt_cc_component( name = "mle_utility", impl_deps = [ diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.cc b/sxt/proof/sumcheck2/chunked_gpu_driver.cc new file mode 100644 index 000000000..7cda45a3c --- /dev/null +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.cc @@ -0,0 +1 @@ +#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.h b/sxt/proof/sumcheck2/chunked_gpu_driver.h new file mode 100644 index 000000000..f89c07fc5 --- /dev/null +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.h @@ -0,0 +1,47 @@ +#pragma once + +#include + +#include "sxt/algorithm/iteration/transform.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/proof/sumcheck2/device_cache.h" +#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/proof/sumcheck2/fold_gpu.h" +#include "sxt/proof/sumcheck2/gpu_driver.h" +#include "sxt/proof/sumcheck2/mle_utility.h" +#include "sxt/proof/sumcheck2/sum_gpu.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// chunked_gpu_driver +//-------------------------------------------------------------------------------------------------- +template +class chunked_gpu_driver final : public driver { +public: + explicit chunked_gpu_driver(double no_chunk_cutoff = 0.5) noexcept + : no_chunk_cutoff_{no_chunk_cutoff} {} + + // driver + xena::future> + make_workspace(basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override { + return {}; + } + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { + return {}; + } + + xena::future<> fold(workspace& ws, const T& r) const noexcept override { + return {}; + } + +private: + double no_chunk_cutoff_; +}; +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc new file mode 100644 index 000000000..bee129fa7 --- /dev/null +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc @@ -0,0 +1,2 @@ +#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" + From adc44eb397203538754e231b436aa63e9e922a36 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 13:53:22 -0800 Subject: [PATCH 65/83] fill in chunked driver --- sxt/proof/sumcheck2/chunked_gpu_driver.h | 29 +++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.h b/sxt/proof/sumcheck2/chunked_gpu_driver.h index f89c07fc5..3b4c8908c 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.h +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.h @@ -21,16 +21,39 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- template class chunked_gpu_driver final : public driver { + struct chunked_gpu_workspace final : public workspace { + std::unique_ptr single_gpu_workspace; + + device_cache cache; + memmg::managed_array mles_data; + basct::cspan mles; + unsigned n; + unsigned num_variables; + + chunked_gpu_workspace(basct::cspan> product_table, + basct::cspan product_terms) noexcept + : cache{product_table, product_terms} {} + }; + public: explicit chunked_gpu_driver(double no_chunk_cutoff = 0.5) noexcept : no_chunk_cutoff_{no_chunk_cutoff} {} // driver xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, + make_workspace(basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned n) const noexcept override { - return {}; + auto res = std::make_unique(product_table, product_terms); + res->mles = mles; + res->n = n; + res->num_variables = std::max(basn::ceil_log2(n), 1); + auto gpu_memory_fraction = get_gpu_memory_fraction(mles); + if (gpu_memory_fraction <= no_chunk_cutoff_) { + gpu_driver drv; + res->single_gpu_workspace = + co_await drv.make_workspace(mles, product_table, product_terms, n); + } + co_return std::unique_ptr(std::move(res)); } xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { From 8c7958cf6a7f08ccc8dae1b6d5fc476647c5c6cc Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 14:01:47 -0800 Subject: [PATCH 66/83] rework sumcheck --- sxt/proof/sumcheck2/chunked_gpu_driver.h | 56 ++++++++++++++++++++- sxt/proof/sumcheck2/chunked_gpu_driver.t.cc | 17 +++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.h b/sxt/proof/sumcheck2/chunked_gpu_driver.h index 3b4c8908c..6f0d28d87 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.h +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.h @@ -35,6 +35,25 @@ class chunked_gpu_driver final : public driver { : cache{product_table, product_terms} {} }; + static xena::future<> try_make_single_gpu_workspace(chunked_gpu_workspace& work, + double no_chunk_cutoff) noexcept { + auto gpu_memory_fraction = get_gpu_memory_fraction(work.mles); + if (gpu_memory_fraction > no_chunk_cutoff) { + co_return; + } + + // construct single gpu workspace + auto cache_data = work.cache.clear(); + gpu_driver drv; + work.single_gpu_workspace = + co_await drv.make_workspace(work.mles, std::move(cache_data->product_table), + std::move(cache_data->product_terms), work.n); + + // free data we no longer need + work.mles_data.reset(); + work.mles = {}; + } + public: explicit chunked_gpu_driver(double no_chunk_cutoff = 0.5) noexcept : no_chunk_cutoff_{no_chunk_cutoff} {} @@ -57,11 +76,44 @@ class chunked_gpu_driver final : public driver { } xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { - return {}; + auto& work = static_cast(ws); + if (work.single_gpu_workspace) { + gpu_driver drv; + co_return co_await drv.sum(polynomial, *work.single_gpu_workspace); + } + co_await sum_gpu(polynomial, work.cache, work.mles, work.n); } xena::future<> fold(workspace& ws, const T& r) const noexcept override { - return {}; + auto& work = static_cast(ws); + if (work.single_gpu_workspace) { + gpu_driver drv; + co_return co_await drv.fold(*work.single_gpu_workspace, r); + } + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + auto one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + + // fold + memmg::managed_array mles_p(num_mles * mid); + co_await fold_gpu(mles_p, work.mles, n, r); + + // update + work.n = mid; + --work.num_variables; + work.mles_data = std::move(mles_p); + work.mles = work.mles_data; + + // check if we should fall back to single gpu workspace + co_await try_make_single_gpu_workspace(work, no_chunk_cutoff_); } private: diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc index bee129fa7..751b722e9 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc @@ -1,2 +1,19 @@ #include "sxt/proof/sumcheck2/chunked_gpu_driver.h" +#include "sxt/base/test/unit_test.h" +#include "sxt/proof/sumcheck2/driver_test.h" + +using namespace sxt; +using namespace sxt::prfsk2; + +TEST_CASE("we can perform the primitive operations for sumcheck proofs") { + SECTION("we handle the case when only chunking is used") { + chunked_gpu_driver drv{0.0}; + exercise_driver(drv); + } + + SECTION("we handle the case when the chunked driver falls back to the single gpu driver") { + chunked_gpu_driver drv{1.0}; + exercise_driver(drv); + } +} From 641a6ea2c9973561f8e3195aa67dc54de8fea293 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 18:39:24 -0800 Subject: [PATCH 67/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 14 +- sxt/proof/sumcheck2/proof_computation.t.cc | 219 ++++++++++++++++++++- 2 files changed, 231 insertions(+), 2 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 925dd1eb6..e2c91f1bc 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -237,7 +237,19 @@ sxt_cc_component( "//sxt/execution/async:coroutine", ], test_deps = [ - "//sxt/base/test:unit_test", + ":chunked_gpu_driver", + ":cpu_driver", + ":gpu_driver", + ":mle_utility", + ":reference_transcript", + # ":sumcheck_random", + ":verification", + "//sxt/base/test:unit_test", + "//sxt/execution/async:future", + "//sxt/execution/schedule:scheduler", + "//sxt/proof/transcript", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", ], ) diff --git a/sxt/proof/sumcheck2/proof_computation.t.cc b/sxt/proof/sumcheck2/proof_computation.t.cc index 608b8e4d8..e5622a335 100644 --- a/sxt/proof/sumcheck2/proof_computation.t.cc +++ b/sxt/proof/sumcheck2/proof_computation.t.cc @@ -1,5 +1,222 @@ #include "sxt/proof/sumcheck2/proof_computation.h" +#include +#include + +#include "sxt/base/container/span_utility.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/base/num/fast_random_number_generator.h" #include "sxt/base/test/unit_test.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/schedule/scheduler.h" +#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" +#include "sxt/proof/sumcheck2/cpu_driver.h" +#include "sxt/proof/sumcheck2/gpu_driver.h" +#include "sxt/proof/sumcheck2/mle_utility.h" +#include "sxt/proof/sumcheck2/polynomial_utility.h" +#include "sxt/proof/sumcheck2/reference_transcript.h" +/* #include "sxt/proof/sumcheck2/sumcheck_random.h" */ +#include "sxt/proof/sumcheck2/verification.h" +#include "sxt/proof/transcript/transcript.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +using T = s25t::element; + +static void test_proof(const driver& drv) noexcept { + prft::transcript base_transcript{"abc"}; + reference_transcript transcript{base_transcript}; + std::vector polynomials(2); + std::vector evaluation_point(1); + std::vector mles = { + 0x8_s25, + 0x3_s25, + }; + std::vector> product_table = { + {0x1_s25, 1}, + }; + std::vector product_terms = {0}; + + SECTION("we can prove a sum with n=1") { + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 1); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == mles[0]); + REQUIRE(polynomials[1] == -mles[0]); + } + + SECTION("we can prove a sum with a single variable") { + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == mles[0]); + REQUIRE(polynomials[1] == mles[1] - mles[0]); + } + + SECTION("we can prove a sum degree greater than 1") { + product_table = { + {0x1_s25, 2}, + }; + product_terms = {0, 0}; + polynomials.resize(3); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == mles[0] * mles[0]); + REQUIRE(polynomials[1] == 0x2_s25 * (mles[1] - mles[0]) * mles[0]); + REQUIRE(polynomials[2] == (mles[1] - mles[0]) * (mles[1] - mles[0])); + } + + SECTION("we can prove a sum with multiple MLEs") { + product_table = { + {0x1_s25, 2}, + }; + product_terms = {0, 1}; + polynomials.resize(3); + mles.push_back(0x7_s25); + mles.push_back(0x10_s25); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == mles[0] * mles[2]); + REQUIRE(polynomials[1] == (mles[1] - mles[0]) * mles[2] + (mles[3] - mles[2]) * mles[0]); + REQUIRE(polynomials[2] == (mles[1] - mles[0]) * (mles[3] - mles[2])); + } + + SECTION("we can prove a sum where the term multiplier is different from one") { + product_table[0].first = 0x2_s25; + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == 0x2_s25 * mles[0]); + REQUIRE(polynomials[1] == 0x2_s25 * (mles[1] - mles[0])); + } + + SECTION("we can prove a sum with two variables") { + mles.push_back(0x4_s25); + mles.push_back(0x7_s25); + polynomials.resize(4); + evaluation_point.resize(2); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 4); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == mles[0] + mles[1]); + REQUIRE(polynomials[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); + + auto r = evaluation_point[0]; + mles[0] = mles[0] * (0x1_s25 - r) + mles[2] * r; + mles[1] = mles[1] * (0x1_s25 - r) + mles[3] * r; + + REQUIRE(polynomials[2] == mles[0]); + REQUIRE(polynomials[3] == mles[1] - mles[0]); + } + + SECTION("we can prove a sum with n=3") { + mles.push_back(0x4_s25); + polynomials.resize(4); + evaluation_point.resize(2); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 3); + xens::get_scheduler().run(); + REQUIRE(fut.ready()); + REQUIRE(polynomials[0] == mles[0] + mles[1]); + REQUIRE(polynomials[1] == (mles[2] - mles[0]) - mles[1]); + + auto r = evaluation_point[0]; + mles[0] = mles[0] * (0x1_s25 - r) + mles[2] * r; + mles[1] = mles[1] * (0x1_s25 - r); + + REQUIRE(polynomials[2] == mles[0]); + REQUIRE(polynomials[3] == mles[1] - mles[0]); + } + +#if 0 + SECTION("we can verify random sumcheck problems") { + basn::fast_random_number_generator rng{1, 2}; + + for (unsigned i = 0; i < 10; ++i) { + random_sumcheck_descriptor descriptor; + unsigned n; + generate_random_sumcheck_problem(mles, product_table, product_terms, n, rng, descriptor); + + unsigned polynomial_length = 0; + for (auto [_, len] : product_table) { + polynomial_length = std::max(polynomial_length, len + 1u); + } + + auto num_variables = n == 1 ? 1 : basn::ceil_log2(n); + evaluation_point.resize(num_variables); + polynomials.resize(polynomial_length * num_variables); + + // prove + { + prft::transcript base_transcript{"abc"}; + reference_transcript transcript{base_transcript}; + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, n); + xens::get_scheduler().run(); + } + + // we can verify + { + prft::transcript base_transcript{"abc"}; + reference_transcript transcript{base_transcript}; + s25t::element expected_sum; + sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); + auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + polynomials, polynomial_length - 1u); + REQUIRE(valid); + } + + // verification fails if we break the proof + { + prft::transcript base_transcript{"abc"}; + reference_transcript transcript{base_transcript}; + s25t::element expected_sum; + sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); + polynomials[polynomials.size() - 1] = polynomials[0] + polynomials[1]; + auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + polynomials, polynomial_length - 1u); + REQUIRE(!valid); + } + } + } +#endif +} + +TEST_CASE("we can create a sumcheck proof") { +#if 0 + SECTION("we can prove with the cpu driver") { + cpu_driver drv; + test_proof(drv); + } + + SECTION("we can prove with the gpu driver") { + gpu_driver drv; + test_proof(drv); + } + + SECTION("we can prove with the chunked gpu driver") { + chunked_gpu_driver drv{0.0}; + test_proof(drv); + } -TEST_CASE("todo") {} + SECTION("we can prove with a chunked driver that switches over to the single gpu driver") { + std::vector mles(4); + auto fraction = get_gpu_memory_fraction(mles); + chunked_gpu_driver drv{fraction}; + test_proof(drv); + } +#endif +} From 96e775c72c9a27051d977150b0516b39f654ca5e Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 18:42:45 -0800 Subject: [PATCH 68/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 14 ++++++ sxt/proof/sumcheck2/sumcheck_random.cc | 61 ++++++++++++++++++++++++++ sxt/proof/sumcheck2/sumcheck_random.h | 41 +++++++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 sxt/proof/sumcheck2/sumcheck_random.cc create mode 100644 sxt/proof/sumcheck2/sumcheck_random.h diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index e2c91f1bc..a073f144a 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -262,6 +262,20 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "sumcheck_random", + impl_deps = [ + "//sxt/base/error:assert", + "//sxt/base/num:fast_random_number_generator", + "//sxt/scalar25/random:element", + "//sxt/scalar25/type:element", + ], + with_test = False, + deps = [ + ":constant", + ], +) + sxt_cc_component( name = "reference_transcript", deps = [ diff --git a/sxt/proof/sumcheck2/sumcheck_random.cc b/sxt/proof/sumcheck2/sumcheck_random.cc new file mode 100644 index 000000000..19dd9bba3 --- /dev/null +++ b/sxt/proof/sumcheck2/sumcheck_random.cc @@ -0,0 +1,61 @@ +#include "sxt/proof/sumcheck2/sumcheck_random.h" + +#include + +#include "sxt/base/error/assert.h" +#include "sxt/base/num/fast_random_number_generator.h" +#include "sxt/scalar25/random/element.h" +#include "sxt/scalar25/type/element.h" + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// generate_random_sumcheck_problem +//-------------------------------------------------------------------------------------------------- +void generate_random_sumcheck_problem( + std::vector& mles, + std::vector>& product_table, + std::vector& product_terms, unsigned& n, basn::fast_random_number_generator& rng, + const random_sumcheck_descriptor& descriptor) noexcept { + std::mt19937 rng_p{rng()}; + + // n + SXT_RELEASE_ASSERT(descriptor.min_length <= descriptor.max_length); + std::uniform_int_distribution n_dist{descriptor.min_length, descriptor.max_length}; + n = n_dist(rng_p); + + // num_mles + SXT_RELEASE_ASSERT(descriptor.min_num_mles <= descriptor.max_num_mles); + std::uniform_int_distribution num_mles_dist{descriptor.min_num_mles, + descriptor.max_num_mles}; + auto num_mles = num_mles_dist(rng_p); + + // num_products + SXT_RELEASE_ASSERT(descriptor.min_num_products <= descriptor.max_num_products); + std::uniform_int_distribution num_products_dist{descriptor.min_num_products, + descriptor.max_num_products}; + auto num_products = num_products_dist(rng_p); + + // mles + mles.resize(n * num_mles); + s25rn::generate_random_elements(mles, rng); + + // product_table + unsigned num_terms = 0; + product_table.resize(num_products); + SXT_RELEASE_ASSERT(descriptor.min_product_length <= descriptor.max_product_length); + std::uniform_int_distribution product_length_dist{descriptor.min_product_length, + descriptor.max_product_length}; + for (auto& [s, len] : product_table) { + s25rn::generate_random_element(s, rng); + len = product_length_dist(rng_p); + num_terms += len; + } + + // product_terms + product_terms.resize(num_terms); + std::uniform_int_distribution mle_dist{0, num_mles - 1}; + for (auto& term : product_terms) { + term = mle_dist(rng_p); + } +} +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/sumcheck_random.h b/sxt/proof/sumcheck2/sumcheck_random.h new file mode 100644 index 000000000..780c7ec4d --- /dev/null +++ b/sxt/proof/sumcheck2/sumcheck_random.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include "sxt/proof/sumcheck2/constant.h" + +namespace sxt::s25t { +class element; +} +namespace sxt::basn { +class fast_random_number_generator; +} + +namespace sxt::prfsk2 { +//-------------------------------------------------------------------------------------------------- +// random_sumcheck_descriptor +//-------------------------------------------------------------------------------------------------- +struct random_sumcheck_descriptor { + unsigned min_length = 1; + unsigned max_length = 10; + + unsigned min_num_products = 1; + unsigned max_num_products = 5; + + unsigned min_product_length = 2; + unsigned max_product_length = max_degree_v; + + unsigned min_num_mles = 1; + unsigned max_num_mles = 5; +}; + +//-------------------------------------------------------------------------------------------------- +// generate_random_sumcheck_problem +//-------------------------------------------------------------------------------------------------- +void generate_random_sumcheck_problem( + std::vector& mles, + std::vector>& product_table, + std::vector& product_terms, unsigned& n, basn::fast_random_number_generator& rng, + const random_sumcheck_descriptor& descriptor) noexcept; +} // namespace sxt::prfsk2 From 9223f8cd1809d96360e833713b18e8a323e46bbe Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 18:50:24 -0800 Subject: [PATCH 69/83] rework sumcheck --- sxt/proof/sumcheck2/BUILD | 2 +- sxt/proof/sumcheck2/proof_computation.t.cc | 32 ++++++++++------------ 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index a073f144a..2ea75152c 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -242,7 +242,7 @@ sxt_cc_component( ":gpu_driver", ":mle_utility", ":reference_transcript", - # ":sumcheck_random", + ":sumcheck_random", ":verification", "//sxt/base/test:unit_test", "//sxt/execution/async:future", diff --git a/sxt/proof/sumcheck2/proof_computation.t.cc b/sxt/proof/sumcheck2/proof_computation.t.cc index e5622a335..caa5f76eb 100644 --- a/sxt/proof/sumcheck2/proof_computation.t.cc +++ b/sxt/proof/sumcheck2/proof_computation.t.cc @@ -15,7 +15,7 @@ #include "sxt/proof/sumcheck2/mle_utility.h" #include "sxt/proof/sumcheck2/polynomial_utility.h" #include "sxt/proof/sumcheck2/reference_transcript.h" -/* #include "sxt/proof/sumcheck2/sumcheck_random.h" */ +#include "sxt/proof/sumcheck2/sumcheck_random.h" #include "sxt/proof/sumcheck2/verification.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/operation/overload.h" @@ -141,7 +141,6 @@ static void test_proof(const driver& drv) noexcept { REQUIRE(polynomials[3] == mles[1] - mles[0]); } -#if 0 SECTION("we can verify random sumcheck problems") { basn::fast_random_number_generator rng{1, 2}; @@ -162,8 +161,8 @@ static void test_proof(const driver& drv) noexcept { // prove { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + reference_transcript transcript{base_transcript}; + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, product_terms, n); xens::get_scheduler().run(); } @@ -171,10 +170,10 @@ static void test_proof(const driver& drv) noexcept { // we can verify { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; + reference_transcript transcript{base_transcript}; s25t::element expected_sum; - sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); - auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); + auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, polynomials, polynomial_length - 1u); REQUIRE(valid); } @@ -182,41 +181,38 @@ static void test_proof(const driver& drv) noexcept { // verification fails if we break the proof { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; + reference_transcript transcript{base_transcript}; s25t::element expected_sum; - sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); + sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); polynomials[polynomials.size() - 1] = polynomials[0] + polynomials[1]; - auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, polynomials, polynomial_length - 1u); REQUIRE(!valid); } } } -#endif } TEST_CASE("we can create a sumcheck proof") { -#if 0 SECTION("we can prove with the cpu driver") { - cpu_driver drv; + cpu_driver drv; test_proof(drv); } SECTION("we can prove with the gpu driver") { - gpu_driver drv; + gpu_driver drv; test_proof(drv); } SECTION("we can prove with the chunked gpu driver") { - chunked_gpu_driver drv{0.0}; + chunked_gpu_driver drv{0.0}; test_proof(drv); } SECTION("we can prove with a chunked driver that switches over to the single gpu driver") { std::vector mles(4); - auto fraction = get_gpu_memory_fraction(mles); - chunked_gpu_driver drv{fraction}; + auto fraction = get_gpu_memory_fraction(mles); + chunked_gpu_driver drv{fraction}; test_proof(drv); } -#endif } From dc2f966f7e7721c255c613ddbdfd95c5d7c2a56b Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 18:50:46 -0800 Subject: [PATCH 70/83] reformat --- sxt/base/concept/field.cc | 16 ++ sxt/base/concept/field.h | 16 ++ sxt/base/field/BUILD | 6 +- sxt/base/field/accumulator.cc | 16 ++ sxt/base/field/accumulator.h | 23 ++- sxt/base/field/element.cc | 16 ++ sxt/base/field/element.h | 16 ++ sxt/proof/sumcheck2/BUILD | 170 +++++++++--------- sxt/proof/sumcheck2/chunked_gpu_driver.cc | 16 ++ sxt/proof/sumcheck2/chunked_gpu_driver.h | 19 +- sxt/proof/sumcheck2/chunked_gpu_driver.t.cc | 16 ++ sxt/proof/sumcheck2/constant.cc | 16 ++ sxt/proof/sumcheck2/constant.h | 16 ++ sxt/proof/sumcheck2/cpu_driver.cc | 16 ++ sxt/proof/sumcheck2/cpu_driver.h | 22 ++- sxt/proof/sumcheck2/cpu_driver.t.cc | 18 +- sxt/proof/sumcheck2/device_cache.cc | 16 ++ sxt/proof/sumcheck2/device_cache.h | 22 ++- sxt/proof/sumcheck2/device_cache.t.cc | 17 +- sxt/proof/sumcheck2/driver.cc | 16 ++ sxt/proof/sumcheck2/driver.h | 29 ++- sxt/proof/sumcheck2/driver_test.cc | 16 ++ sxt/proof/sumcheck2/driver_test.h | 16 ++ sxt/proof/sumcheck2/fold_gpu.cc | 16 ++ sxt/proof/sumcheck2/fold_gpu.h | 25 ++- sxt/proof/sumcheck2/fold_gpu.t.cc | 16 ++ sxt/proof/sumcheck2/gpu_driver.cc | 16 ++ sxt/proof/sumcheck2/gpu_driver.h | 22 ++- sxt/proof/sumcheck2/gpu_driver.t.cc | 18 +- sxt/proof/sumcheck2/mle_utility.cc | 16 ++ sxt/proof/sumcheck2/mle_utility.h | 27 ++- sxt/proof/sumcheck2/mle_utility.t.cc | 16 ++ sxt/proof/sumcheck2/polynomial_mapper.cc | 16 ++ sxt/proof/sumcheck2/polynomial_mapper.h | 16 ++ sxt/proof/sumcheck2/polynomial_mapper.t.cc | 16 ++ sxt/proof/sumcheck2/polynomial_reducer.cc | 16 ++ sxt/proof/sumcheck2/polynomial_reducer.h | 18 +- sxt/proof/sumcheck2/polynomial_utility.cc | 16 ++ sxt/proof/sumcheck2/polynomial_utility.h | 29 ++- sxt/proof/sumcheck2/polynomial_utility.t.cc | 16 ++ sxt/proof/sumcheck2/proof_computation.cc | 16 ++ sxt/proof/sumcheck2/proof_computation.h | 28 ++- sxt/proof/sumcheck2/proof_computation.t.cc | 22 ++- sxt/proof/sumcheck2/reduction_gpu.cc | 16 ++ sxt/proof/sumcheck2/reduction_gpu.h | 19 +- sxt/proof/sumcheck2/reduction_gpu.t.cc | 16 ++ sxt/proof/sumcheck2/reference_transcript.cc | 16 ++ sxt/proof/sumcheck2/reference_transcript.h | 22 ++- sxt/proof/sumcheck2/reference_transcript.t.cc | 16 ++ sxt/proof/sumcheck2/sum_gpu.cc | 16 ++ sxt/proof/sumcheck2/sum_gpu.h | 40 +++-- sxt/proof/sumcheck2/sum_gpu.t.cc | 16 ++ sxt/proof/sumcheck2/sumcheck_random.cc | 16 ++ sxt/proof/sumcheck2/sumcheck_random.h | 16 ++ sxt/proof/sumcheck2/sumcheck_transcript.cc | 16 ++ sxt/proof/sumcheck2/sumcheck_transcript.h | 21 ++- sxt/proof/sumcheck2/verification.cc | 16 ++ sxt/proof/sumcheck2/verification.h | 21 ++- sxt/proof/sumcheck2/verification.t.cc | 28 ++- sxt/proof/sumcheck2/workspace.cc | 16 ++ sxt/proof/sumcheck2/workspace.h | 16 ++ sxt/scalar25/realization/BUILD | 14 +- sxt/scalar25/realization/field.cc | 16 ++ sxt/scalar25/realization/field.h | 19 +- sxt/scalar25/type/BUILD | 2 +- sxt/scalar25/type/operation_adl_stub.cc | 16 ++ sxt/scalar25/type/operation_adl_stub.h | 17 +- 67 files changed, 1167 insertions(+), 191 deletions(-) diff --git a/sxt/base/concept/field.cc b/sxt/base/concept/field.cc index 9bb167899..16a6731c6 100644 --- a/sxt/base/concept/field.cc +++ b/sxt/base/concept/field.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/base/concept/field.h" diff --git a/sxt/base/concept/field.h b/sxt/base/concept/field.h index 3098cf394..aee6c1224 100644 --- a/sxt/base/concept/field.h +++ b/sxt/base/concept/field.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once namespace sxt::bascpt { diff --git a/sxt/base/field/BUILD b/sxt/base/field/BUILD index 7a8f0c8ad..7a2a89d23 100644 --- a/sxt/base/field/BUILD +++ b/sxt/base/field/BUILD @@ -24,7 +24,7 @@ sxt_cc_component( name = "accumulator", with_test = False, deps = [ - ":element", - "//sxt/base/macro:cuda_callable", - ] + ":element", + "//sxt/base/macro:cuda_callable", + ], ) diff --git a/sxt/base/field/accumulator.cc b/sxt/base/field/accumulator.cc index 5ab79a772..fe0f80fa5 100644 --- a/sxt/base/field/accumulator.cc +++ b/sxt/base/field/accumulator.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/base/field/accumulator.h" diff --git a/sxt/base/field/accumulator.h b/sxt/base/field/accumulator.h index 53a14cb1a..a693cd08a 100644 --- a/sxt/base/field/accumulator.h +++ b/sxt/base/field/accumulator.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/base/field/element.h" @@ -7,12 +23,9 @@ namespace sxt::basfld { //-------------------------------------------------------------------------------------------------- // accumulator //-------------------------------------------------------------------------------------------------- -template -struct accumulator { +template struct accumulator { using value_type = T; - CUDA_CALLABLE static void accumulate_inplace(T& res, T& e) noexcept { - add(res, res, e); - } + CUDA_CALLABLE static void accumulate_inplace(T& res, T& e) noexcept { add(res, res, e); } }; } // namespace sxt::basfld diff --git a/sxt/base/field/element.cc b/sxt/base/field/element.cc index f95ddeab8..0c91592ea 100644 --- a/sxt/base/field/element.cc +++ b/sxt/base/field/element.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/base/field/element.h" diff --git a/sxt/base/field/element.h b/sxt/base/field/element.h index f5409323b..f6feb08d3 100644 --- a/sxt/base/field/element.h +++ b/sxt/base/field/element.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 2ea75152c..41937f96d 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -29,13 +29,13 @@ sxt_cc_component( ], deps = [ "//sxt/base/container:span", + "//sxt/base/device:device_map", "//sxt/base/device:memory_utility", "//sxt/base/device:state", "//sxt/base/device:stream", - "//sxt/base/device:device_map", "//sxt/base/field:element", - "//sxt/memory/resource:device_resource", "//sxt/memory/management:managed_array", + "//sxt/memory/resource:device_resource", ], ) @@ -43,19 +43,15 @@ sxt_cc_component( name = "driver", with_test = False, deps = [ - ":workspace", + ":workspace", + "//sxt/base/container:span", + "//sxt/base/field:element", "//sxt/execution/async:future_fwd", - "//sxt/base/container:span", - "//sxt/base/field:element", ], ) sxt_cc_component( name = "driver_test", - deps = [ - ":driver", - "//sxt/scalar25/realization:field", - ], impl_deps = [ ":workspace", "//sxt/base/test:unit_test", @@ -66,23 +62,27 @@ sxt_cc_component( "//sxt/scalar25/type:literal", ], with_test = False, + deps = [ + ":driver", + "//sxt/scalar25/realization:field", + ], ) sxt_cc_component( - name = "cpu_driver", - deps = [ - ":driver", - ":polynomial_utility", - "//sxt/base/container:stack_array", - "//sxt/base/error:assert", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:coroutine", - "//sxt/memory/management:managed_array", - ], - test_deps = [ - ":driver_test", - "//sxt/base/test:unit_test", - ], + name = "cpu_driver", + test_deps = [ + ":driver_test", + "//sxt/base/test:unit_test", + ], + deps = [ + ":driver", + ":polynomial_utility", + "//sxt/base/container:stack_array", + "//sxt/base/error:assert", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/memory/management:managed_array", + ], ) sxt_cc_component( @@ -99,18 +99,13 @@ sxt_cc_component( ":mle_utility", "//sxt/algorithm/iteration:kernel_fit", "//sxt/base/container:span", - "//sxt/base/error:assert", - "//sxt/base/device:property", "//sxt/base/device:memory_utility", + "//sxt/base/device:property", "//sxt/base/device:stream", + "//sxt/base/error:assert", "//sxt/base/field:element", "//sxt/base/iterator:split", "//sxt/base/num:ceil_log2", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/operation:sub", "//sxt/execution/async:coroutine", "//sxt/execution/async:future", "//sxt/execution/device:for_each", @@ -119,28 +114,33 @@ sxt_cc_component( "//sxt/memory/management:managed_array", "//sxt/memory/resource:async_device_resource", "//sxt/memory/resource:device_resource", + "//sxt/scalar25/operation:mul", + "//sxt/scalar25/operation:muladd", + "//sxt/scalar25/operation:sub", + "//sxt/scalar25/type:element", + "//sxt/scalar25/type:literal", ], ) sxt_cc_component( - name = "gpu_driver", - deps = [ - ":driver", - ":sum_gpu", - "//sxt/base/error:assert", - "//sxt/algorithm/iteration:for_each", - "//sxt/base/device:memory_utility", - "//sxt/base/device:stream", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:coroutine", - "//sxt/execution/device:synchronization", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:device_resource", - ], - test_deps = [ - ":driver_test", - "//sxt/base/test:unit_test", - ], + name = "gpu_driver", + test_deps = [ + ":driver_test", + "//sxt/base/test:unit_test", + ], + deps = [ + ":driver", + ":sum_gpu", + "//sxt/algorithm/iteration:for_each", + "//sxt/base/device:memory_utility", + "//sxt/base/device:stream", + "//sxt/base/error:assert", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/execution/device:synchronization", + "//sxt/memory/management:managed_array", + "//sxt/memory/resource:device_resource", + ], ) sxt_cc_component( @@ -150,8 +150,8 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], deps = [ - ":driver", ":device_cache", + ":driver", ":fold_gpu", ":gpu_driver", ":mle_utility", @@ -190,52 +190,42 @@ sxt_cc_component( sxt_cc_component( name = "polynomial_mapper", - deps = [ - ":polynomial_utility", - "//sxt/base/field:element", - "//sxt/base/macro:cuda_callable", - ], test_deps = [ - "//sxt/base/test:unit_test", + "//sxt/base/test:unit_test", + ], + deps = [ + ":polynomial_utility", + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", ], ) sxt_cc_component( name = "polynomial_reducer", + with_test = False, deps = [ - "//sxt/base/field:element", - "//sxt/base/macro:cuda_callable", + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", ], - with_test = False, ) sxt_cc_component( name = "polynomial_utility", - deps = [ - "//sxt/base/container:span", - "//sxt/base/field:element", - "//sxt/base/macro:cuda_callable", - ], test_deps = [ - "//sxt/base/test:unit_test", + "//sxt/base/test:unit_test", "//sxt/scalar25/operation:overload", "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], + deps = [ + "//sxt/base/container:span", + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", + ], ) sxt_cc_component( name = "proof_computation", - deps = [ - ":sumcheck_transcript", - ":driver", - "//sxt/base/container:span", - "//sxt/base/error:assert", - "//sxt/base/field:element", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:future", - "//sxt/execution/async:coroutine", - ], test_deps = [ ":chunked_gpu_driver", ":cpu_driver", @@ -251,14 +241,24 @@ sxt_cc_component( "//sxt/scalar25/operation:overload", "//sxt/scalar25/realization:field", ], + deps = [ + ":driver", + ":sumcheck_transcript", + "//sxt/base/container:span", + "//sxt/base/error:assert", + "//sxt/base/field:element", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + ], ) sxt_cc_component( name = "sumcheck_transcript", with_test = False, deps = [ - "//sxt/base/container:span", - "//sxt/base/field:element", + "//sxt/base/container:span", + "//sxt/base/field:element", ], ) @@ -278,18 +278,18 @@ sxt_cc_component( sxt_cc_component( name = "reference_transcript", - deps = [ - ":sumcheck_transcript", - "//sxt/base/container:span", - "//sxt/base/field:element", - "//sxt/proof/transcript:transcript_utility", - ], test_deps = [ "//sxt/base/test:unit_test", "//sxt/scalar25/operation:overload", "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], + deps = [ + ":sumcheck_transcript", + "//sxt/base/container:span", + "//sxt/base/field:element", + "//sxt/proof/transcript:transcript_utility", + ], ) sxt_cc_component( @@ -315,9 +315,9 @@ sxt_cc_component( "//sxt/base/error:assert", "//sxt/base/field:accumulator", "//sxt/base/field:element", - "//sxt/execution/async:future_fwd", "//sxt/execution/async:coroutine", "//sxt/execution/async:future", + "//sxt/execution/async:future_fwd", "//sxt/execution/device:synchronization", "//sxt/execution/kernel:kernel_dims", "//sxt/execution/kernel:launch", @@ -384,8 +384,8 @@ sxt_cc_component( ], deps = [ ":polynomial_utility", - ":sumcheck_transcript", + ":sumcheck_transcript", "//sxt/base/container:span", - "//sxt/base/log:log", + "//sxt/base/log", ], ) diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.cc b/sxt/proof/sumcheck2/chunked_gpu_driver.cc index 7cda45a3c..8f2afc388 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.cc +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/chunked_gpu_driver.h" diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.h b/sxt/proof/sumcheck2/chunked_gpu_driver.h index 6f0d28d87..1871fe920 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.h +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -19,8 +35,7 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // chunked_gpu_driver //-------------------------------------------------------------------------------------------------- -template -class chunked_gpu_driver final : public driver { +template class chunked_gpu_driver final : public driver { struct chunked_gpu_workspace final : public workspace { std::unique_ptr single_gpu_workspace; diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc index 751b722e9..da60a2e21 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc +++ b/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/chunked_gpu_driver.h" #include "sxt/base/test/unit_test.h" diff --git a/sxt/proof/sumcheck2/constant.cc b/sxt/proof/sumcheck2/constant.cc index 5b59043b7..babcea3f3 100644 --- a/sxt/proof/sumcheck2/constant.cc +++ b/sxt/proof/sumcheck2/constant.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/constant.h" diff --git a/sxt/proof/sumcheck2/constant.h b/sxt/proof/sumcheck2/constant.h index 53a35ee97..0b8fc5348 100644 --- a/sxt/proof/sumcheck2/constant.h +++ b/sxt/proof/sumcheck2/constant.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once namespace sxt::prfsk2 { diff --git a/sxt/proof/sumcheck2/cpu_driver.cc b/sxt/proof/sumcheck2/cpu_driver.cc index 9abb18f9f..d504430dc 100644 --- a/sxt/proof/sumcheck2/cpu_driver.cc +++ b/sxt/proof/sumcheck2/cpu_driver.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/cpu_driver.h" diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck2/cpu_driver.h index 45e1e497d..508124647 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck2/cpu_driver.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/base/container/stack_array.h" @@ -12,8 +28,7 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // cpu_driver //-------------------------------------------------------------------------------------------------- -template -class cpu_driver final : public driver { +template class cpu_driver final : public driver { struct cpu_workspace final : public workspace { memmg::managed_array mles; basct::cspan> product_table; @@ -25,8 +40,7 @@ class cpu_driver final : public driver { public: // driver xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, + make_workspace(basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned n) const noexcept override { auto res = std::make_unique(); res->mles = memmg::managed_array{mles.begin(), mles.end()}; diff --git a/sxt/proof/sumcheck2/cpu_driver.t.cc b/sxt/proof/sumcheck2/cpu_driver.t.cc index 4da8f688e..4775cb0fb 100644 --- a/sxt/proof/sumcheck2/cpu_driver.t.cc +++ b/sxt/proof/sumcheck2/cpu_driver.t.cc @@ -1,7 +1,23 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/cpu_driver.h" -#include "sxt/proof/sumcheck2/driver_test.h" #include "sxt/base/test/unit_test.h" +#include "sxt/proof/sumcheck2/driver_test.h" using namespace sxt; using namespace sxt::prfsk2; diff --git a/sxt/proof/sumcheck2/device_cache.cc b/sxt/proof/sumcheck2/device_cache.cc index d5e2569e1..5cafd3238 100644 --- a/sxt/proof/sumcheck2/device_cache.cc +++ b/sxt/proof/sumcheck2/device_cache.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/device_cache.h" diff --git a/sxt/proof/sumcheck2/device_cache.h b/sxt/proof/sumcheck2/device_cache.h index b5fc54700..951524f32 100644 --- a/sxt/proof/sumcheck2/device_cache.h +++ b/sxt/proof/sumcheck2/device_cache.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -17,8 +33,7 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // device_cache_data //-------------------------------------------------------------------------------------------------- -template -struct device_cache_data { +template struct device_cache_data { memmg::managed_array> product_table; memmg::managed_array product_terms; }; @@ -42,8 +57,7 @@ make_device_copy(basct::cspan> product_table, //-------------------------------------------------------------------------------------------------- // device_cache //-------------------------------------------------------------------------------------------------- -template -class device_cache { +template class device_cache { public: device_cache(basct::cspan> product_table, basct::cspan product_terms) noexcept diff --git a/sxt/proof/sumcheck2/device_cache.t.cc b/sxt/proof/sumcheck2/device_cache.t.cc index e2c3df658..63a15b89a 100644 --- a/sxt/proof/sumcheck2/device_cache.t.cc +++ b/sxt/proof/sumcheck2/device_cache.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/device_cache.h" #include @@ -13,7 +29,6 @@ using namespace sxt; using namespace sxt::prfsk2; using s25t::operator""_s25; - TEST_CASE("we can cache device values that don't change as a proof is computed") { using T = s25t::element; std::vector> product_table; diff --git a/sxt/proof/sumcheck2/driver.cc b/sxt/proof/sumcheck2/driver.cc index 74f52ac47..d28bdf310 100644 --- a/sxt/proof/sumcheck2/driver.cc +++ b/sxt/proof/sumcheck2/driver.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/driver.h" diff --git a/sxt/proof/sumcheck2/driver.h b/sxt/proof/sumcheck2/driver.h index bc857ae62..23aeec4d5 100644 --- a/sxt/proof/sumcheck2/driver.h +++ b/sxt/proof/sumcheck2/driver.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -11,19 +27,16 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // driver //-------------------------------------------------------------------------------------------------- -template -class driver { +template class driver { public: virtual ~driver() noexcept = default; virtual xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, + make_workspace(basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned n) const noexcept = 0; - - virtual xena::future<> sum(basct::span polynomial, - workspace& ws) const noexcept = 0; - + + virtual xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept = 0; + virtual xena::future<> fold(workspace& ws, const T& r) const noexcept = 0; }; } // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/driver_test.cc b/sxt/proof/sumcheck2/driver_test.cc index d3da244e5..b74f47ada 100644 --- a/sxt/proof/sumcheck2/driver_test.cc +++ b/sxt/proof/sumcheck2/driver_test.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/driver_test.h" #include diff --git a/sxt/proof/sumcheck2/driver_test.h b/sxt/proof/sumcheck2/driver_test.h index 3fc355147..167554551 100644 --- a/sxt/proof/sumcheck2/driver_test.h +++ b/sxt/proof/sumcheck2/driver_test.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/proof/sumcheck2/driver.h" diff --git a/sxt/proof/sumcheck2/fold_gpu.cc b/sxt/proof/sumcheck2/fold_gpu.cc index 5b542af66..332097aae 100644 --- a/sxt/proof/sumcheck2/fold_gpu.cc +++ b/sxt/proof/sumcheck2/fold_gpu.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/fold_gpu.h" diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck2/fold_gpu.h index 757b7dba2..8dc89d6a7 100644 --- a/sxt/proof/sumcheck2/fold_gpu.h +++ b/sxt/proof/sumcheck2/fold_gpu.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -88,9 +104,8 @@ xena::future<> fold_impl(basct::span mles_p, basct::cspan mles, unsigned n // fold_gpu //-------------------------------------------------------------------------------------------------- template -xena::future<> fold_gpu(basct::span mles_p, - const basit::split_options& split_options, basct::cspan mles, - unsigned n, const T& r) noexcept { +xena::future<> fold_gpu(basct::span mles_p, const basit::split_options& split_options, + basct::cspan mles, unsigned n, const T& r) noexcept { auto num_mles = mles.size() / n; auto num_variables = std::max(basn::ceil_log2(n), 1); auto mid = 1u << (num_variables - 1u); @@ -113,8 +128,8 @@ xena::future<> fold_gpu(basct::span mles_p, } template -xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, - unsigned n, const T& r) noexcept { +xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, unsigned n, + const T& r) noexcept { basit::split_options split_options{ .min_chunk_size = 1024u * 128u, .max_chunk_size = 1024u * 256u, diff --git a/sxt/proof/sumcheck2/fold_gpu.t.cc b/sxt/proof/sumcheck2/fold_gpu.t.cc index ecda894be..d8f82b1cd 100644 --- a/sxt/proof/sumcheck2/fold_gpu.t.cc +++ b/sxt/proof/sumcheck2/fold_gpu.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/fold_gpu.h" #include diff --git a/sxt/proof/sumcheck2/gpu_driver.cc b/sxt/proof/sumcheck2/gpu_driver.cc index ab4d321ef..2951ac714 100644 --- a/sxt/proof/sumcheck2/gpu_driver.cc +++ b/sxt/proof/sumcheck2/gpu_driver.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/gpu_driver.h" diff --git a/sxt/proof/sumcheck2/gpu_driver.h b/sxt/proof/sumcheck2/gpu_driver.h index e20ce32a9..537df4d3a 100644 --- a/sxt/proof/sumcheck2/gpu_driver.h +++ b/sxt/proof/sumcheck2/gpu_driver.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -18,8 +34,7 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // gpu_driver //-------------------------------------------------------------------------------------------------- -template -class gpu_driver final : public driver { +template class gpu_driver final : public driver { public: struct gpu_workspace final : public workspace { memmg::managed_array mles; @@ -35,8 +50,7 @@ class gpu_driver final : public driver { // driver xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, + make_workspace(basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned n) const noexcept override { auto ws = std::make_unique(); diff --git a/sxt/proof/sumcheck2/gpu_driver.t.cc b/sxt/proof/sumcheck2/gpu_driver.t.cc index b6168a7a7..1a58e5823 100644 --- a/sxt/proof/sumcheck2/gpu_driver.t.cc +++ b/sxt/proof/sumcheck2/gpu_driver.t.cc @@ -1,7 +1,23 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/gpu_driver.h" -#include "sxt/proof/sumcheck2/driver_test.h" #include "sxt/base/test/unit_test.h" +#include "sxt/proof/sumcheck2/driver_test.h" using namespace sxt; using namespace sxt::prfsk2; diff --git a/sxt/proof/sumcheck2/mle_utility.cc b/sxt/proof/sumcheck2/mle_utility.cc index df389d796..c4bf89222 100644 --- a/sxt/proof/sumcheck2/mle_utility.cc +++ b/sxt/proof/sumcheck2/mle_utility.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/mle_utility.h" diff --git a/sxt/proof/sumcheck2/mle_utility.h b/sxt/proof/sumcheck2/mle_utility.h index ccaa550d9..d5b2e3413 100644 --- a/sxt/proof/sumcheck2/mle_utility.h +++ b/sxt/proof/sumcheck2/mle_utility.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -20,8 +36,7 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- template void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, - basct::cspan mles, unsigned n, unsigned a, - unsigned b) noexcept { + basct::cspan mles, unsigned n, unsigned a, unsigned b) noexcept { auto num_variables = std::max(basn::ceil_log2(n), 1); auto mid = 1u << (num_variables - 1u); auto num_mles = mles.size() / n; @@ -55,9 +70,8 @@ void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& str // copy_folded_mles //-------------------------------------------------------------------------------------------------- template -void copy_folded_mles(basct::span host_mles, basdv::stream& stream, - basct::cspan device_mles, unsigned np, unsigned a, - unsigned b) noexcept { +void copy_folded_mles(basct::span host_mles, basdv::stream& stream, basct::cspan device_mles, + unsigned np, unsigned a, unsigned b) noexcept { auto num_mles = host_mles.size() / np; auto slice_n = device_mles.size() / num_mles; auto slice_np = b - a; @@ -78,8 +92,7 @@ void copy_folded_mles(basct::span host_mles, basdv::stream& stream, //-------------------------------------------------------------------------------------------------- // get_gpu_memory_fraction //-------------------------------------------------------------------------------------------------- -template -double get_gpu_memory_fraction(basct::cspan mles) noexcept { +template double get_gpu_memory_fraction(basct::cspan mles) noexcept { auto total_memory = static_cast(basdv::get_total_device_memory()); return static_cast(mles.size() * sizeof(T)) / total_memory; } diff --git a/sxt/proof/sumcheck2/mle_utility.t.cc b/sxt/proof/sumcheck2/mle_utility.t.cc index 222d79a1c..b0b53d1b6 100644 --- a/sxt/proof/sumcheck2/mle_utility.t.cc +++ b/sxt/proof/sumcheck2/mle_utility.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/mle_utility.h" #include diff --git a/sxt/proof/sumcheck2/polynomial_mapper.cc b/sxt/proof/sumcheck2/polynomial_mapper.cc index 6884b6798..cb2f5c790 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.cc +++ b/sxt/proof/sumcheck2/polynomial_mapper.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/polynomial_mapper.h" diff --git a/sxt/proof/sumcheck2/polynomial_mapper.h b/sxt/proof/sumcheck2/polynomial_mapper.h index 932506fdc..b8eab6ed6 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.h +++ b/sxt/proof/sumcheck2/polynomial_mapper.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/base/field/element.h" diff --git a/sxt/proof/sumcheck2/polynomial_mapper.t.cc b/sxt/proof/sumcheck2/polynomial_mapper.t.cc index 38b38a767..57b837240 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.t.cc +++ b/sxt/proof/sumcheck2/polynomial_mapper.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/polynomial_mapper.h" #include "sxt/base/test/unit_test.h" diff --git a/sxt/proof/sumcheck2/polynomial_reducer.cc b/sxt/proof/sumcheck2/polynomial_reducer.cc index 42f9483b2..47e62134d 100644 --- a/sxt/proof/sumcheck2/polynomial_reducer.cc +++ b/sxt/proof/sumcheck2/polynomial_reducer.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/polynomial_reducer.h" diff --git a/sxt/proof/sumcheck2/polynomial_reducer.h b/sxt/proof/sumcheck2/polynomial_reducer.h index 426be8716..565fb1f49 100644 --- a/sxt/proof/sumcheck2/polynomial_reducer.h +++ b/sxt/proof/sumcheck2/polynomial_reducer.h @@ -1,7 +1,23 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once -#include "sxt/base/macro/cuda_callable.h" #include "sxt/base/field/element.h" +#include "sxt/base/macro/cuda_callable.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/polynomial_utility.cc b/sxt/proof/sumcheck2/polynomial_utility.cc index f69c84beb..701fe438d 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.cc +++ b/sxt/proof/sumcheck2/polynomial_utility.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/polynomial_utility.h" diff --git a/sxt/proof/sumcheck2/polynomial_utility.h b/sxt/proof/sumcheck2/polynomial_utility.h index 1f9dc59d2..3ae6aa949 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.h +++ b/sxt/proof/sumcheck2/polynomial_utility.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -14,8 +30,7 @@ namespace sxt::prfsk2 { // f_a(X) = a[0] + a[1] * X + a[2] * X^2 + ... // compute the sum // f_a(0) + f_a(1) -template -void sum_polynomial_01(T& e, basct::cspan polynomial) noexcept { +template void sum_polynomial_01(T& e, basct::cspan polynomial) noexcept { if (polynomial.empty()) { e = T{}; return; @@ -48,9 +63,8 @@ void evaluate_polynomial(T& e, basct::cspan polynomial, const T& x) noexcept // expand_products //-------------------------------------------------------------------------------------------------- template -CUDA_CALLABLE -void expand_products(basct::span p, const T* mles, unsigned n, - unsigned step, basct::cspan terms) noexcept { +CUDA_CALLABLE void expand_products(basct::span p, const T* mles, unsigned n, unsigned step, + basct::cspan terms) noexcept { auto num_terms = terms.size(); assert( // clang-format off @@ -89,9 +103,8 @@ void expand_products(basct::span p, const T* mles, unsigned n, // partial_expand_products //-------------------------------------------------------------------------------------------------- template -CUDA_CALLABLE -void partial_expand_products(basct::span p, const T* mles, unsigned n, - basct::cspan terms) noexcept { +CUDA_CALLABLE void partial_expand_products(basct::span p, const T* mles, unsigned n, + basct::cspan terms) noexcept { auto num_terms = terms.size(); assert( // clang-format off diff --git a/sxt/proof/sumcheck2/polynomial_utility.t.cc b/sxt/proof/sumcheck2/polynomial_utility.t.cc index f08dc9829..4081ed049 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.t.cc +++ b/sxt/proof/sumcheck2/polynomial_utility.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/polynomial_utility.h" #include diff --git a/sxt/proof/sumcheck2/proof_computation.cc b/sxt/proof/sumcheck2/proof_computation.cc index 8d7efecf6..a87c23c27 100644 --- a/sxt/proof/sumcheck2/proof_computation.cc +++ b/sxt/proof/sumcheck2/proof_computation.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/proof_computation.h" diff --git a/sxt/proof/sumcheck2/proof_computation.h b/sxt/proof/sumcheck2/proof_computation.h index a0dd847e5..727542e4c 100644 --- a/sxt/proof/sumcheck2/proof_computation.h +++ b/sxt/proof/sumcheck2/proof_computation.h @@ -1,23 +1,37 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" #include "sxt/base/num/ceil_log2.h" #include "sxt/execution/async/coroutine.h" -#include "sxt/base/field/element.h" +#include "sxt/execution/async/future.h" #include "sxt/proof/sumcheck2/driver.h" #include "sxt/proof/sumcheck2/sumcheck_transcript.h" -#include "sxt/execution/async/future.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // prove_sum //-------------------------------------------------------------------------------------------------- template -xena::future<> prove_sum(basct::span polynomials, - basct::span evaluation_point, +xena::future<> prove_sum(basct::span polynomials, basct::span evaluation_point, sumcheck_transcript& transcript, const driver& drv, - basct::cspan mles, - basct::cspan> product_table, + basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned n) noexcept { SXT_RELEASE_ASSERT(0 < n); auto num_variables = std::max(basn::ceil_log2(n), 1); @@ -53,4 +67,4 @@ xena::future<> prove_sum(basct::span polynomials, } } } -} // namespace sxt:prfsk2 +} // namespace sxt::prfsk2 diff --git a/sxt/proof/sumcheck2/proof_computation.t.cc b/sxt/proof/sumcheck2/proof_computation.t.cc index caa5f76eb..022b501a2 100644 --- a/sxt/proof/sumcheck2/proof_computation.t.cc +++ b/sxt/proof/sumcheck2/proof_computation.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/proof_computation.h" #include @@ -163,7 +179,7 @@ static void test_proof(const driver& drv) noexcept { prft::transcript base_transcript{"abc"}; reference_transcript transcript{base_transcript}; auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, n); + product_terms, n); xens::get_scheduler().run(); } @@ -174,7 +190,7 @@ static void test_proof(const driver& drv) noexcept { s25t::element expected_sum; sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - polynomials, polynomial_length - 1u); + polynomials, polynomial_length - 1u); REQUIRE(valid); } @@ -186,7 +202,7 @@ static void test_proof(const driver& drv) noexcept { sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); polynomials[polynomials.size() - 1] = polynomials[0] + polynomials[1]; auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - polynomials, polynomial_length - 1u); + polynomials, polynomial_length - 1u); REQUIRE(!valid); } } diff --git a/sxt/proof/sumcheck2/reduction_gpu.cc b/sxt/proof/sumcheck2/reduction_gpu.cc index ff9abf6af..1b4232869 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.cc +++ b/sxt/proof/sumcheck2/reduction_gpu.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/reduction_gpu.h" diff --git a/sxt/proof/sumcheck2/reduction_gpu.h b/sxt/proof/sumcheck2/reduction_gpu.h index 3624e5bee..638c17802 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.h +++ b/sxt/proof/sumcheck2/reduction_gpu.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/algorithm/base/identity_mapper.h" @@ -21,8 +37,7 @@ namespace sxt::prfsk2 { // reduction_kernel //-------------------------------------------------------------------------------------------------- template -__global__ static void reduction_kernel(T* __restrict__ out, - const T* __restrict__ partials, +__global__ static void reduction_kernel(T* __restrict__ out, const T* __restrict__ partials, unsigned n) noexcept { auto thread_index = threadIdx.x; auto block_index = blockIdx.x; diff --git a/sxt/proof/sumcheck2/reduction_gpu.t.cc b/sxt/proof/sumcheck2/reduction_gpu.t.cc index 6fd017a1f..4d7659d53 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.t.cc +++ b/sxt/proof/sumcheck2/reduction_gpu.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/reduction_gpu.h" #include diff --git a/sxt/proof/sumcheck2/reference_transcript.cc b/sxt/proof/sumcheck2/reference_transcript.cc index 44733b41e..f49c8a511 100644 --- a/sxt/proof/sumcheck2/reference_transcript.cc +++ b/sxt/proof/sumcheck2/reference_transcript.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/reference_transcript.h" diff --git a/sxt/proof/sumcheck2/reference_transcript.h b/sxt/proof/sumcheck2/reference_transcript.h index 3c75da067..be65c6dc4 100644 --- a/sxt/proof/sumcheck2/reference_transcript.h +++ b/sxt/proof/sumcheck2/reference_transcript.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/proof/sumcheck2/sumcheck_transcript.h" @@ -8,11 +24,9 @@ namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // reference_transcript //-------------------------------------------------------------------------------------------------- -template -class reference_transcript final : public sumcheck_transcript { +template class reference_transcript final : public sumcheck_transcript { public: - explicit reference_transcript(prft::transcript& transcript) noexcept - : transcript_{transcript} {} + explicit reference_transcript(prft::transcript& transcript) noexcept : transcript_{transcript} {} void init(size_t num_variables, size_t round_degree) noexcept { prft::set_domain(transcript_, "sumcheck proof v1"); diff --git a/sxt/proof/sumcheck2/reference_transcript.t.cc b/sxt/proof/sumcheck2/reference_transcript.t.cc index 307647110..adac2b364 100644 --- a/sxt/proof/sumcheck2/reference_transcript.t.cc +++ b/sxt/proof/sumcheck2/reference_transcript.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/reference_transcript.h" #include "sxt/base/test/unit_test.h" diff --git a/sxt/proof/sumcheck2/sum_gpu.cc b/sxt/proof/sumcheck2/sum_gpu.cc index 89c02f6fd..26d71a8af 100644 --- a/sxt/proof/sumcheck2/sum_gpu.cc +++ b/sxt/proof/sumcheck2/sum_gpu.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/sum_gpu.h" diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck2/sum_gpu.h index 80e4bcc59..a5e8fc909 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck2/sum_gpu.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include @@ -63,11 +79,11 @@ __device__ static void partial_sum_kernel_impl(T* __restrict__ shared_data, // partial_sum_kernel //-------------------------------------------------------------------------------------------------- template -__global__ static void -partial_sum_kernel(T* __restrict__ out, const T* __restrict__ mles, - const std::pair* __restrict__ product_table, - const unsigned* __restrict__ product_terms, unsigned num_coefficients, - unsigned split, unsigned n) noexcept { +__global__ static void partial_sum_kernel(T* __restrict__ out, const T* __restrict__ mles, + const std::pair* __restrict__ product_table, + const unsigned* __restrict__ product_terms, + unsigned num_coefficients, unsigned split, + unsigned n) noexcept { auto product_index = blockIdx.y; auto num_terms = product_table[product_index].second; auto thread_index = threadIdx.x; @@ -105,8 +121,7 @@ partial_sum_kernel(T* __restrict__ out, const T* __restrict__ mles, // partial_sum //-------------------------------------------------------------------------------------------------- template -static xena::future<> partial_sum(basct::span p, basdv::stream& stream, - basct::cspan mles, +static xena::future<> partial_sum(basct::span p, basdv::stream& stream, basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned split, unsigned n) noexcept { @@ -116,8 +131,7 @@ static xena::future<> partial_sum(basct::span p, basdv::stream& stream, memr::async_device_resource resource{stream}; // partials - memmg::managed_array partials{num_coefficients * dims.num_blocks * num_products, - &resource}; + memmg::managed_array partials{num_coefficients * dims.num_blocks * num_products, &resource}; xenk::launch_kernel(dims.block_size, [&]( std::integral_constant) noexcept { partial_sum_kernel<<>>( @@ -164,8 +178,8 @@ xena::future<> sum_gpu(basct::span p, device_cache& cache, // compute memmg::managed_array partial_p(num_coefficients); - co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, split, - np); + co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, + split, np); // fill in the result if (counter == 0) { @@ -182,8 +196,8 @@ xena::future<> sum_gpu(basct::span p, device_cache& cache, } template -xena::future<> sum_gpu(basct::span p, device_cache& cache, - basct::cspan mles, unsigned n) noexcept { +xena::future<> sum_gpu(basct::span p, device_cache& cache, basct::cspan mles, + unsigned n) noexcept { basit::split_options options{ .min_chunk_size = 100'000u, .max_chunk_size = 200'000u, diff --git a/sxt/proof/sumcheck2/sum_gpu.t.cc b/sxt/proof/sumcheck2/sum_gpu.t.cc index 67740bc26..c6759ab61 100644 --- a/sxt/proof/sumcheck2/sum_gpu.t.cc +++ b/sxt/proof/sumcheck2/sum_gpu.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/sum_gpu.h" #include diff --git a/sxt/proof/sumcheck2/sumcheck_random.cc b/sxt/proof/sumcheck2/sumcheck_random.cc index 19dd9bba3..17c32d430 100644 --- a/sxt/proof/sumcheck2/sumcheck_random.cc +++ b/sxt/proof/sumcheck2/sumcheck_random.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/sumcheck_random.h" #include diff --git a/sxt/proof/sumcheck2/sumcheck_random.h b/sxt/proof/sumcheck2/sumcheck_random.h index 780c7ec4d..7bfec674f 100644 --- a/sxt/proof/sumcheck2/sumcheck_random.h +++ b/sxt/proof/sumcheck2/sumcheck_random.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include diff --git a/sxt/proof/sumcheck2/sumcheck_transcript.cc b/sxt/proof/sumcheck2/sumcheck_transcript.cc index ae78804db..d04730afa 100644 --- a/sxt/proof/sumcheck2/sumcheck_transcript.cc +++ b/sxt/proof/sumcheck2/sumcheck_transcript.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/sumcheck_transcript.h" diff --git a/sxt/proof/sumcheck2/sumcheck_transcript.h b/sxt/proof/sumcheck2/sumcheck_transcript.h index a23e98bb8..8447212e0 100644 --- a/sxt/proof/sumcheck2/sumcheck_transcript.h +++ b/sxt/proof/sumcheck2/sumcheck_transcript.h @@ -1,14 +1,29 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once -#include "sxt/base/field/element.h" #include "sxt/base/container/span.h" +#include "sxt/base/field/element.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // sumcheck_transcript //-------------------------------------------------------------------------------------------------- -template -class sumcheck_transcript { +template class sumcheck_transcript { public: virtual ~sumcheck_transcript() noexcept = default; diff --git a/sxt/proof/sumcheck2/verification.cc b/sxt/proof/sumcheck2/verification.cc index 0746fc7e9..e1d4e844f 100644 --- a/sxt/proof/sumcheck2/verification.cc +++ b/sxt/proof/sumcheck2/verification.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/verification.h" diff --git a/sxt/proof/sumcheck2/verification.h b/sxt/proof/sumcheck2/verification.h index 6acac8863..40166a92b 100644 --- a/sxt/proof/sumcheck2/verification.h +++ b/sxt/proof/sumcheck2/verification.h @@ -1,18 +1,33 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/base/container/span.h" #include "sxt/base/error/assert.h" #include "sxt/base/log/log.h" -#include "sxt/proof/sumcheck2/sumcheck_transcript.h" #include "sxt/proof/sumcheck2/polynomial_utility.h" +#include "sxt/proof/sumcheck2/sumcheck_transcript.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- // verify_sumcheck_no_evaluation //-------------------------------------------------------------------------------------------------- template -bool verify_sumcheck_no_evaluation(T& expected_sum, - basct::span evaluation_point, +bool verify_sumcheck_no_evaluation(T& expected_sum, basct::span evaluation_point, sumcheck_transcript& transcript, basct::cspan round_polynomials, unsigned round_degree) noexcept { diff --git a/sxt/proof/sumcheck2/verification.t.cc b/sxt/proof/sumcheck2/verification.t.cc index 4d6502753..f48605338 100644 --- a/sxt/proof/sumcheck2/verification.t.cc +++ b/sxt/proof/sumcheck2/verification.t.cc @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/verification.h" #include @@ -23,13 +39,13 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { SECTION("verification fails if dimensions don't match") { auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 2); + round_polynomials, 2); REQUIRE(!res); } SECTION("we can verify a single round") { auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + round_polynomials, 1); REQUIRE(res); REQUIRE(evaluation_point[0] != 0x0_s25); } @@ -37,7 +53,7 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { SECTION("verification fails if the round polynomial doesn't match the sum") { round_polynomials[1] = 0x1_s25; auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + round_polynomials, 1); REQUIRE(!res); } @@ -67,7 +83,7 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { // prove evaluation_point.resize(2); auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + round_polynomials, 1); REQUIRE(evaluation_point[0] == r); REQUIRE(res); } @@ -92,7 +108,7 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { // prove evaluation_point.resize(2); auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + round_polynomials, 1); REQUIRE(!res); } @@ -107,7 +123,7 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { }; expected_sum = 0x3_s25 * -0x2_s25 - 0x7_s25 * 0x4_s25; auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 2); + round_polynomials, 2); REQUIRE(res); REQUIRE(evaluation_point[0] != 0x0_s25); } diff --git a/sxt/proof/sumcheck2/workspace.cc b/sxt/proof/sumcheck2/workspace.cc index f5547993c..32a934d6b 100644 --- a/sxt/proof/sumcheck2/workspace.cc +++ b/sxt/proof/sumcheck2/workspace.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/proof/sumcheck2/workspace.h" diff --git a/sxt/proof/sumcheck2/workspace.h b/sxt/proof/sumcheck2/workspace.h index a422cb703..8fae4355d 100644 --- a/sxt/proof/sumcheck2/workspace.h +++ b/sxt/proof/sumcheck2/workspace.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once namespace sxt::prfsk2 { diff --git a/sxt/scalar25/realization/BUILD b/sxt/scalar25/realization/BUILD index aa4e2ed87..bca94a28a 100644 --- a/sxt/scalar25/realization/BUILD +++ b/sxt/scalar25/realization/BUILD @@ -7,12 +7,12 @@ sxt_cc_component( name = "field", with_test = False, deps = [ - "//sxt/base/field:element", - "//sxt/scalar25/type:element", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/operation:neg", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:muladd", + "//sxt/base/field:element", + "//sxt/scalar25/operation:add", + "//sxt/scalar25/operation:mul", + "//sxt/scalar25/operation:muladd", + "//sxt/scalar25/operation:neg", + "//sxt/scalar25/operation:sub", + "//sxt/scalar25/type:element", ], ) diff --git a/sxt/scalar25/realization/field.cc b/sxt/scalar25/realization/field.cc index 6a4616b8e..a2452d2e1 100644 --- a/sxt/scalar25/realization/field.cc +++ b/sxt/scalar25/realization/field.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/scalar25/realization/field.h" diff --git a/sxt/scalar25/realization/field.h b/sxt/scalar25/realization/field.h index c4f5501cc..8ce413ac8 100644 --- a/sxt/scalar25/realization/field.h +++ b/sxt/scalar25/realization/field.h @@ -1,3 +1,19 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "sxt/base/field/element.h" @@ -8,5 +24,4 @@ #include "sxt/scalar25/operation/sub.h" #include "sxt/scalar25/type/element.h" -static_assert( - sxt::basfld::element); +static_assert(sxt::basfld::element); diff --git a/sxt/scalar25/type/BUILD b/sxt/scalar25/type/BUILD index 22a845243..cf154a0b9 100644 --- a/sxt/scalar25/type/BUILD +++ b/sxt/scalar25/type/BUILD @@ -17,7 +17,7 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], deps = [ - ":operation_adl_stub", + ":operation_adl_stub", "//sxt/base/macro:cuda_callable", ], ) diff --git a/sxt/scalar25/type/operation_adl_stub.cc b/sxt/scalar25/type/operation_adl_stub.cc index 82f1feee6..95c90e307 100644 --- a/sxt/scalar25/type/operation_adl_stub.cc +++ b/sxt/scalar25/type/operation_adl_stub.cc @@ -1 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "sxt/scalar25/type/operation_adl_stub.h" diff --git a/sxt/scalar25/type/operation_adl_stub.h b/sxt/scalar25/type/operation_adl_stub.h index 134fafbf6..34a7e577f 100644 --- a/sxt/scalar25/type/operation_adl_stub.h +++ b/sxt/scalar25/type/operation_adl_stub.h @@ -1,6 +1,21 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once - namespace sxt::s25o { //-------------------------------------------------------------------------------------------------- // operation_adl_stub From 9dc06dd7e0d411f83f7fe6c16c6f65b808a65100 Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 18:59:59 -0800 Subject: [PATCH 71/83] update benchmark --- benchmark/sumcheck/BUILD | 8 ++++---- benchmark/sumcheck/benchmark.m.cc | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmark/sumcheck/BUILD b/benchmark/sumcheck/BUILD index 14d521e5f..d7e31ca8f 100644 --- a/benchmark/sumcheck/BUILD +++ b/benchmark/sumcheck/BUILD @@ -14,11 +14,11 @@ sxt_cc_benchmark( "//sxt/execution/async:future", "//sxt/execution/schedule:scheduler", "//sxt/memory/management:managed_array", - "//sxt/proof/sumcheck:gpu_driver", - "//sxt/proof/sumcheck:proof_computation", - "//sxt/proof/sumcheck:reference_transcript", + "//sxt/proof/sumcheck2:gpu_driver", + "//sxt/proof/sumcheck2:proof_computation", + "//sxt/proof/sumcheck2:reference_transcript", "//sxt/proof/transcript", "//sxt/scalar25/random:element", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", ], ) diff --git a/benchmark/sumcheck/benchmark.m.cc b/benchmark/sumcheck/benchmark.m.cc index 348e43e4f..60697ebb6 100644 --- a/benchmark/sumcheck/benchmark.m.cc +++ b/benchmark/sumcheck/benchmark.m.cc @@ -27,12 +27,12 @@ #include "sxt/execution/async/future.h" #include "sxt/execution/schedule/scheduler.h" #include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck/gpu_driver.h" -#include "sxt/proof/sumcheck/proof_computation.h" -#include "sxt/proof/sumcheck/reference_transcript.h" +#include "sxt/proof/sumcheck2/gpu_driver.h" +#include "sxt/proof/sumcheck2/proof_computation.h" +#include "sxt/proof/sumcheck2/reference_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/random/element.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" using namespace sxt; @@ -111,13 +111,13 @@ int main(int argc, char* argv[]) { memmg::managed_array polynomials((p.degree + 1u) * num_rounds); memmg::managed_array evaluation_point(num_rounds); prft::transcript base_transcript{"abc123"}; - prfsk::reference_transcript transcript{base_transcript}; - prfsk::gpu_driver drv; + prfsk2::reference_transcript transcript{base_transcript}; + prfsk2::gpu_driver drv; // initial run { - auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, p.n); + auto fut = prfsk2::prove_sum(polynomials, evaluation_point, transcript, drv, + mles, product_table, product_terms, p.n); xens::get_scheduler().run(); } @@ -125,8 +125,8 @@ int main(int argc, char* argv[]) { double elapse = 0; for (unsigned i = 0; i < (p.num_samples + 1u); ++i) { auto t1 = std::chrono::steady_clock::now(); - auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, p.n); + auto fut = prfsk2::prove_sum(polynomials, evaluation_point, transcript, drv, + mles, product_table, product_terms, p.n); xens::get_scheduler().run(); auto t2 = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast(t2 - t1); From e50c455c98cec822925362fc9a661ea975a80d8b Mon Sep 17 00:00:00 2001 From: rnburn Date: Mon, 24 Feb 2025 19:22:37 -0800 Subject: [PATCH 72/83] rework sumcheck --- sxt/cbindings/backend/BUILD | 1 + .../backend/callback_sumcheck_transcript.h | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index a4a2e1b42..68e7860d5 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -10,6 +10,7 @@ sxt_cc_component( with_test = False, deps = [ "//sxt/proof/sumcheck:sumcheck_transcript", + "//sxt/proof/sumcheck2:sumcheck_transcript", ], ) diff --git a/sxt/cbindings/backend/callback_sumcheck_transcript.h b/sxt/cbindings/backend/callback_sumcheck_transcript.h index 8f4c9c908..ba72f7bd7 100644 --- a/sxt/cbindings/backend/callback_sumcheck_transcript.h +++ b/sxt/cbindings/backend/callback_sumcheck_transcript.h @@ -17,6 +17,7 @@ #pragma once #include "sxt/proof/sumcheck/sumcheck_transcript.h" +#include "sxt/proof/sumcheck2/sumcheck_transcript.h" namespace sxt::cbnbck { //-------------------------------------------------------------------------------------------------- @@ -35,6 +36,24 @@ class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript { f_(&r, context_, polynomial.data(), static_cast(polynomial.size())); } +private: + callback_t f_; + void* context_; +}; + +template +class callback_sumcheck_transcript2 final : public prfsk2::sumcheck_transcript { +public: + using callback_t = void (*)(T* r, void* context, T* polynomial, unsigned polynomial_len); + + callback_sumcheck_transcript2(callback_t f, void* context) noexcept : f_{f}, context_{context} {} + + void init(size_t /*num_variables*/, size_t /*round_degree*/) noexcept override {} + + void round_challenge(T& r, basct::cspan polynomial) noexcept override { + f_(&r, context_, polynomial.data(), static_cast(polynomial.size())); + } + private: callback_t f_; void* context_; From 913176aaa61b6fbf005e2c7508437ed37ec24072 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 13:07:23 -0800 Subject: [PATCH 73/83] rework sumcheck --- sxt/cbindings/backend/BUILD | 1 + .../backend/callback_sumcheck_transcript.h | 2 +- sxt/cbindings/backend/gpu_backend.cc | 41 ++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 68e7860d5..e3aa6ee29 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -100,6 +100,7 @@ sxt_cc_component( "//sxt/proof/inner_product:gpu_driver", "//sxt/proof/sumcheck:chunked_gpu_driver", "//sxt/proof/sumcheck:proof_computation", + "//sxt/scalar25/realization:field", ], with_test = False, deps = [ diff --git a/sxt/cbindings/backend/callback_sumcheck_transcript.h b/sxt/cbindings/backend/callback_sumcheck_transcript.h index ba72f7bd7..6924c5687 100644 --- a/sxt/cbindings/backend/callback_sumcheck_transcript.h +++ b/sxt/cbindings/backend/callback_sumcheck_transcript.h @@ -44,7 +44,7 @@ class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript { template class callback_sumcheck_transcript2 final : public prfsk2::sumcheck_transcript { public: - using callback_t = void (*)(T* r, void* context, T* polynomial, unsigned polynomial_len); + using callback_t = void (*)(T* r, void* context, const T* polynomial, unsigned polynomial_len); callback_sumcheck_transcript2(callback_t f, void* context) noexcept : f_{f}, context_{context} {} diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index 6dc4208c4..899b1b320 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -65,11 +65,13 @@ #include "sxt/proof/inner_product/proof_descriptor.h" #include "sxt/proof/sumcheck/chunked_gpu_driver.h" #include "sxt/proof/sumcheck/proof_computation.h" +#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" +#include "sxt/proof/sumcheck2/proof_computation.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" #include "sxt/ristretto/type/literal.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/seqcommit/generator/precomputed_generators.h" using sxt::rstt::operator""_rs; @@ -144,6 +146,43 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi product_table_span, product_terms_span, descriptor.n); xens::get_scheduler().run(); }); + return; + cbnb::switch_field_type( + static_cast(field_id), [&](std::type_identity) noexcept { + static_assert(std::same_as, "only support curve-255 right now"); + // transcript + callback_sumcheck_transcript2 transcript{ + reinterpret_cast::callback_t>( + const_cast(transcript_callback)), + transcript_context}; + + // prove + basct::span polynomials_span{ + static_cast(polynomials), + (descriptor.round_degree + 1u) * num_variables, + }; + basct::span evaluation_point_span{ + static_cast(evaluation_point), + num_variables, + }; + basct::cspan mles_span{ + static_cast(descriptor.mles), + descriptor.n * descriptor.num_mles, + }; + basct::cspan> product_table_span{ + static_cast*>(descriptor.product_table), + descriptor.num_products, + }; + basct::cspan product_terms_span{ + descriptor.product_terms, + descriptor.num_product_terms, + }; + prfsk2::chunked_gpu_driver drv; + auto fut = + prfsk2::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, + mles_span, product_table_span, product_terms_span, descriptor.n); + xens::get_scheduler().run(); + }); } //-------------------------------------------------------------------------------------------------- From a0dde1cd33a4e143a53fb3200f53f1e796dc48b2 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 13:10:02 -0800 Subject: [PATCH 74/83] rework sumcheck --- sxt/cbindings/backend/BUILD | 4 +-- sxt/cbindings/backend/gpu_backend.cc | 39 ---------------------------- 2 files changed, 2 insertions(+), 41 deletions(-) diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index e3aa6ee29..19eb3a09f 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -98,8 +98,8 @@ sxt_cc_component( "//sxt/proof/inner_product:proof_descriptor", "//sxt/proof/inner_product:proof_computation", "//sxt/proof/inner_product:gpu_driver", - "//sxt/proof/sumcheck:chunked_gpu_driver", - "//sxt/proof/sumcheck:proof_computation", + "//sxt/proof/sumcheck2:chunked_gpu_driver", + "//sxt/proof/sumcheck2:proof_computation", "//sxt/scalar25/realization:field", ], with_test = False, diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index 899b1b320..e28bc62a8 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -63,8 +63,6 @@ #include "sxt/proof/inner_product/gpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" -#include "sxt/proof/sumcheck/chunked_gpu_driver.h" -#include "sxt/proof/sumcheck/proof_computation.h" #include "sxt/proof/sumcheck2/chunked_gpu_driver.h" #include "sxt/proof/sumcheck2/proof_computation.h" #include "sxt/proof/transcript/transcript.h" @@ -110,43 +108,6 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi const cbnb::sumcheck_descriptor& descriptor, void* transcript_callback, void* transcript_context) noexcept { auto num_variables = static_cast(std::max(basn::ceil_log2(descriptor.n), 1)); - cbnb::switch_field_type( - static_cast(field_id), [&](std::type_identity) noexcept { - static_assert(std::same_as, "only support curve-255 right now"); - // transcript - callback_sumcheck_transcript transcript{ - reinterpret_cast( - const_cast(transcript_callback)), - transcript_context}; - - // prove - basct::span polynomials_span{ - static_cast(polynomials), - (descriptor.round_degree + 1u) * num_variables, - }; - basct::span evaluation_point_span{ - static_cast(evaluation_point), - num_variables, - }; - basct::cspan mles_span{ - static_cast(descriptor.mles), - descriptor.n * descriptor.num_mles, - }; - basct::cspan> product_table_span{ - static_cast*>(descriptor.product_table), - descriptor.num_products, - }; - basct::cspan product_terms_span{ - descriptor.product_terms, - descriptor.num_product_terms, - }; - prfsk::chunked_gpu_driver drv; - auto fut = - prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, - product_table_span, product_terms_span, descriptor.n); - xens::get_scheduler().run(); - }); - return; cbnb::switch_field_type( static_cast(field_id), [&](std::type_identity) noexcept { static_assert(std::same_as, "only support curve-255 right now"); From 25693c4747d16032e5e964d0677ebfaf4ce3bbaf Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 13:48:38 -0800 Subject: [PATCH 75/83] rework sumcheck --- sxt/cbindings/backend/BUILD | 5 +++-- sxt/cbindings/backend/cpu_backend.cc | 21 ++++++++------------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 19eb3a09f..8906c5984 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -157,8 +157,9 @@ sxt_cc_component( "//sxt/proof/inner_product:proof_descriptor", "//sxt/proof/inner_product:proof_computation", "//sxt/proof/inner_product:cpu_driver", - "//sxt/proof/sumcheck:cpu_driver", - "//sxt/proof/sumcheck:proof_computation", + "//sxt/proof/sumcheck2:cpu_driver", + "//sxt/proof/sumcheck2:proof_computation", + "//sxt/scalar25/realization:field", ], with_test = False, deps = [ diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index 29d33d48e..c8e980ab5 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -59,22 +59,18 @@ #include "sxt/proof/inner_product/cpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" -#include "sxt/proof/sumcheck/cpu_driver.h" -#include "sxt/proof/sumcheck/proof_computation.h" +#include "sxt/proof/sumcheck2/cpu_driver.h" +#include "sxt/proof/sumcheck2/proof_computation.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/seqcommit/generator/precomputed_generators.h" namespace sxt::cbnbck { //-------------------------------------------------------------------------------------------------- // prove_sumcheck //-------------------------------------------------------------------------------------------------- -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-parameter" void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsigned field_id, const cbnb::sumcheck_descriptor& descriptor, void* transcript_callback, void* transcript_context) noexcept { @@ -83,8 +79,8 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi static_cast(field_id), [&](std::type_identity) noexcept { static_assert(std::same_as, "only support curve-255 right now"); // transcript - callback_sumcheck_transcript transcript{ - reinterpret_cast( + callback_sumcheck_transcript2 transcript{ + reinterpret_cast::callback_t>( const_cast(transcript_callback)), transcript_context}; @@ -109,14 +105,13 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi descriptor.product_terms, descriptor.num_product_terms, }; - prfsk::cpu_driver drv; + prfsk2::cpu_driver drv; auto fut = - prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, - product_table_span, product_terms_span, descriptor.n); + prfsk2::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, + mles_span, product_table_span, product_terms_span, descriptor.n); SXT_RELEASE_ASSERT(fut.ready()); }); } -#pragma clang diagnostic pop //-------------------------------------------------------------------------------------------------- // compute_commitments From 3b9dcba8d4258ba2aebe3b7904586c244f61dd0c Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 16:53:22 -0800 Subject: [PATCH 76/83] rework sumcheck --- sxt/cbindings/backend/BUILD | 1 - .../backend/callback_sumcheck_transcript.h | 23 ++----------------- sxt/cbindings/backend/cpu_backend.cc | 4 ++-- sxt/cbindings/backend/gpu_backend.cc | 4 ++-- 4 files changed, 6 insertions(+), 26 deletions(-) diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 8906c5984..26f6cddfe 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -9,7 +9,6 @@ sxt_cc_component( ], with_test = False, deps = [ - "//sxt/proof/sumcheck:sumcheck_transcript", "//sxt/proof/sumcheck2:sumcheck_transcript", ], ) diff --git a/sxt/cbindings/backend/callback_sumcheck_transcript.h b/sxt/cbindings/backend/callback_sumcheck_transcript.h index 6924c5687..8c35e446a 100644 --- a/sxt/cbindings/backend/callback_sumcheck_transcript.h +++ b/sxt/cbindings/backend/callback_sumcheck_transcript.h @@ -16,37 +16,18 @@ */ #pragma once -#include "sxt/proof/sumcheck/sumcheck_transcript.h" #include "sxt/proof/sumcheck2/sumcheck_transcript.h" namespace sxt::cbnbck { //-------------------------------------------------------------------------------------------------- // callback_sumcheck_transcript //-------------------------------------------------------------------------------------------------- -class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript { -public: - using callback_t = void (*)(s25t::element* r, void* context, const s25t::element* polynomial, - unsigned polynomial_len); - - callback_sumcheck_transcript(callback_t f, void* context) noexcept : f_{f}, context_{context} {} - - void init(size_t /*num_variables*/, size_t /*round_degree*/) noexcept override {} - - void round_challenge(s25t::element& r, basct::cspan polynomial) noexcept override { - f_(&r, context_, polynomial.data(), static_cast(polynomial.size())); - } - -private: - callback_t f_; - void* context_; -}; - template -class callback_sumcheck_transcript2 final : public prfsk2::sumcheck_transcript { +class callback_sumcheck_transcript final : public prfsk2::sumcheck_transcript { public: using callback_t = void (*)(T* r, void* context, const T* polynomial, unsigned polynomial_len); - callback_sumcheck_transcript2(callback_t f, void* context) noexcept : f_{f}, context_{context} {} + callback_sumcheck_transcript(callback_t f, void* context) noexcept : f_{f}, context_{context} {} void init(size_t /*num_variables*/, size_t /*round_degree*/) noexcept override {} diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index c8e980ab5..b245c0c61 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -79,8 +79,8 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi static_cast(field_id), [&](std::type_identity) noexcept { static_assert(std::same_as, "only support curve-255 right now"); // transcript - callback_sumcheck_transcript2 transcript{ - reinterpret_cast::callback_t>( + callback_sumcheck_transcript transcript{ + reinterpret_cast::callback_t>( const_cast(transcript_callback)), transcript_context}; diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index e28bc62a8..6ee151a74 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -112,8 +112,8 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi static_cast(field_id), [&](std::type_identity) noexcept { static_assert(std::same_as, "only support curve-255 right now"); // transcript - callback_sumcheck_transcript2 transcript{ - reinterpret_cast::callback_t>( + callback_sumcheck_transcript transcript{ + reinterpret_cast::callback_t>( const_cast(transcript_callback)), transcript_context}; From ecbde3848d0ada2e9cd3714dde6f908a570b11ee Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 16:56:53 -0800 Subject: [PATCH 77/83] rework sumcheck --- cbindings/BUILD | 4 ++-- cbindings/sumcheck.t.cc | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cbindings/BUILD b/cbindings/BUILD index b975cfa9d..56f548a23 100644 --- a/cbindings/BUILD +++ b/cbindings/BUILD @@ -224,9 +224,9 @@ sxt_cc_component( test_deps = [ ":backend", "//sxt/base/test:unit_test", - "//sxt/proof/sumcheck:reference_transcript", + "//sxt/proof/sumcheck2:reference_transcript", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ diff --git a/cbindings/sumcheck.t.cc b/cbindings/sumcheck.t.cc index 282204848..c397ff62e 100644 --- a/cbindings/sumcheck.t.cc +++ b/cbindings/sumcheck.t.cc @@ -20,9 +20,9 @@ #include "cbindings/backend.h" #include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck/reference_transcript.h" +#include "sxt/proof/sumcheck2/reference_transcript.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -30,7 +30,7 @@ using s25t::operator""_s25; TEST_CASE("we can create sumcheck proofs") { prft::transcript base_transcript{"abc"}; - prfsk::reference_transcript transcript{base_transcript}; + prfsk2::reference_transcript transcript{base_transcript}; std::vector polynomials(2); std::vector evaluation_point(1); @@ -55,7 +55,7 @@ TEST_CASE("we can create sumcheck proofs") { auto f = [](s25t::element* r, void* context, const s25t::element* polynomial, unsigned polynomial_len) noexcept { - static_cast(context)->round_challenge( + static_cast*>(context)->round_challenge( *r, {polynomial, polynomial_len}); }; @@ -70,7 +70,7 @@ TEST_CASE("we can create sumcheck proofs") { REQUIRE(polynomials[1] == mles[1] - mles[0]); { prft::transcript base_transcript_p{"abc"}; - prfsk::reference_transcript transcript_p{base_transcript_p}; + prfsk2::reference_transcript transcript_p{base_transcript_p}; s25t::element r; transcript_p.round_challenge(r, polynomials); REQUIRE(evaluation_point[0] == r); @@ -88,7 +88,7 @@ TEST_CASE("we can create sumcheck proofs") { REQUIRE(polynomials[1] == mles[1] - mles[0]); { prft::transcript base_transcript_p{"abc"}; - prfsk::reference_transcript transcript_p{base_transcript_p}; + prfsk2::reference_transcript transcript_p{base_transcript_p}; s25t::element r; transcript_p.round_challenge(r, polynomials); REQUIRE(evaluation_point[0] == r); From 85659b3d72d6ee7d3949bd6a227367b21612a40e Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 16:59:06 -0800 Subject: [PATCH 78/83] rework sumcheck --- sxt/proof/sumcheck/BUILD | 778 ++++++++++++--------------- sxt/proof/sumcheck/constant.cc | 17 - sxt/proof/sumcheck/constant.h | 26 - sxt/proof/sumcheck/device_cache.cc | 70 --- sxt/proof/sumcheck/device_cache.h | 58 -- sxt/proof/sumcheck/device_cache.t.cc | 76 --- sxt/proof/sumcheck/fold_gpu.cc | 156 ------ sxt/proof/sumcheck/fold_gpu.h | 40 -- sxt/proof/sumcheck/fold_gpu.t.cc | 81 --- 9 files changed, 353 insertions(+), 949 deletions(-) delete mode 100644 sxt/proof/sumcheck/constant.cc delete mode 100644 sxt/proof/sumcheck/constant.h delete mode 100644 sxt/proof/sumcheck/device_cache.cc delete mode 100644 sxt/proof/sumcheck/device_cache.h delete mode 100644 sxt/proof/sumcheck/device_cache.t.cc delete mode 100644 sxt/proof/sumcheck/fold_gpu.cc delete mode 100644 sxt/proof/sumcheck/fold_gpu.h delete mode 100644 sxt/proof/sumcheck/fold_gpu.t.cc diff --git a/sxt/proof/sumcheck/BUILD b/sxt/proof/sumcheck/BUILD index 6fdc8dbd3..99f9cf019 100644 --- a/sxt/proof/sumcheck/BUILD +++ b/sxt/proof/sumcheck/BUILD @@ -1,425 +1,353 @@ -load( - "//bazel:sxt_build_system.bzl", - "sxt_cc_component", -) - -sxt_cc_component( - name = "constant", - with_test = False, -) - -sxt_cc_component( - name = "device_cache", - impl_deps = [ - "//sxt/base/device:memory_utility", - "//sxt/base/device:state", - "//sxt/base/device:stream", - "//sxt/memory/resource:device_resource", - ], - test_deps = [ - "//sxt/base/device:memory_utility", - "//sxt/base/device:stream", - "//sxt/base/device:synchronization", - "//sxt/base/test:unit_test", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/base/device:device_map", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/type:element", - ], -) - -sxt_cc_component( - name = "fold_gpu", - impl_deps = [ - ":mle_utility", - "//sxt/algorithm/iteration:kernel_fit", - "//sxt/base/error:assert", - "//sxt/base/device:property", - "//sxt/base/device:memory_utility", - "//sxt/base/device:stream", - "//sxt/base/iterator:split", - "//sxt/base/num:ceil_log2", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/operation:sub", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/execution/device:for_each", - "//sxt/execution/device:synchronization", - "//sxt/execution/kernel:kernel_dims", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:async_device_resource", - "//sxt/memory/resource:device_resource", - ], - test_deps = [ - "//sxt/base/iterator:split", - "//sxt/base/test:unit_test", - "//sxt/execution/async:future", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], -) - -sxt_cc_component( - name = "mle_utility", - impl_deps = [ - "//sxt/base/container:span_utility", - "//sxt/base/device:memory_utility", - "//sxt/base/device:property", - "//sxt/base/num:divide_up", - "//sxt/base/num:ceil_log2", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/type:element", - ], - test_deps = [ - "//sxt/base/device:stream", - "//sxt/base/device:synchronization", - "//sxt/base/test:unit_test", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:managed_device_resource", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/memory/management:managed_array_fwd", - ], -) - -sxt_cc_component( - name = "sum_gpu", - impl_deps = [ - ":constant", - ":device_cache", - ":mle_utility", - ":polynomial_mapper", - ":reduction_gpu", - "//sxt/algorithm/reduction:kernel_fit", - "//sxt/algorithm/reduction:thread_reduction", - "//sxt/base/device:memory_utility", - "//sxt/base/device:stream", - "//sxt/base/device:state", - "//sxt/base/iterator:split", - "//sxt/base/num:ceil_log2", - "//sxt/base/num:constexpr_switch", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/execution/device:for_each", - "//sxt/execution/kernel:kernel_dims", - "//sxt/memory/resource:device_resource", - "//sxt/memory/resource:async_device_resource", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/type:element", - ], - test_deps = [ - ":device_cache", - "//sxt/base/iterator:split", - "//sxt/base/test:unit_test", - "//sxt/execution/async:future", - "//sxt/execution/schedule:scheduler", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], -) - -sxt_cc_component( - name = "sumcheck_transcript", - with_test = False, - deps = [ - "//sxt/base/container:span", - ], -) - -sxt_cc_component( - name = "reference_transcript", - impl_deps = [ - "//sxt/scalar25/type:element", - "//sxt/proof/transcript:transcript_utility", - ], - test_deps = [ - "//sxt/base/test:unit_test", - "//sxt/scalar25/type:literal", - ], - deps = [ - ":sumcheck_transcript", - "//sxt/proof/transcript", - ], -) - -sxt_cc_component( - name = "sumcheck_random", - impl_deps = [ - "//sxt/base/error:assert", - "//sxt/base/num:fast_random_number_generator", - "//sxt/scalar25/random:element", - "//sxt/scalar25/type:element", - ], - with_test = False, - deps = [ - ":constant", - ], -) - -sxt_cc_component( - name = "workspace", - with_test = False, -) - -sxt_cc_component( - name = "driver", - with_test = False, - deps = [ - ":workspace", - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], -) - -sxt_cc_component( - name = "driver_test", - impl_deps = [ - ":driver", - ":workspace", - "//sxt/base/test:unit_test", - "//sxt/execution/async:future", - "//sxt/execution/schedule:scheduler", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - with_test = False, -) - -sxt_cc_component( - name = "cpu_driver", - impl_deps = [ - ":polynomial_utility", - "//sxt/base/container:stack_array", - "//sxt/base/error:panic", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:future", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - test_deps = [ - ":driver_test", - "//sxt/base/test:unit_test", - ], - deps = [ - ":driver", - ":workspace", - ], -) - -sxt_cc_component( - name = "gpu_driver", - impl_deps = [ - ":constant", - ":polynomial_mapper", - ":sum_gpu", - "//sxt/algorithm/iteration:for_each", - "//sxt/algorithm/reduction:reduction", - "//sxt/base/container:stack_array", - "//sxt/base/device:stream", - "//sxt/base/device:memory_utility", - "//sxt/base/error:panic", - "//sxt/base/num:ceil_log2", - "//sxt/base/num:constexpr_switch", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/execution/device:synchronization", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:async_device_resource", - "//sxt/memory/resource:device_resource", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - test_deps = [ - ":driver_test", - "//sxt/base/test:unit_test", - ], - deps = [ - ":driver", - ], -) - -sxt_cc_component( - name = "polynomial_utility", - impl_deps = [ - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/operation:neg", - "//sxt/scalar25/type:element", - ], - test_deps = [ - "//sxt/base/test:unit_test", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - ], -) - -sxt_cc_component( - name = "chunked_gpu_driver", - impl_deps = [ - ":device_cache", - ":fold_gpu", - ":gpu_driver", - ":mle_utility", - ":sum_gpu", - "//sxt/algorithm/iteration:transform", - "//sxt/base/error:assert", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - test_deps = [ - ":driver_test", - "//sxt/base/test:unit_test", - ], - deps = [ - ":driver", - ], -) - -sxt_cc_component( - name = "proof_computation", - impl_deps = [ - ":driver", - ":sumcheck_transcript", - "//sxt/execution/async:future", - "//sxt/base/error:assert", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:coroutine", - "//sxt/scalar25/type:element", - ], - test_deps = [ - ":chunked_gpu_driver", - ":cpu_driver", - ":gpu_driver", - ":mle_utility", - ":reference_transcript", - ":sumcheck_random", - ":verification", - "//sxt/base/test:unit_test", - "//sxt/execution/async:future", - "//sxt/execution/schedule:scheduler", - "//sxt/proof/transcript", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], -) - -sxt_cc_component( - name = "verification", - impl_deps = [ - ":polynomial_utility", - ":sumcheck_transcript", - "//sxt/base/error:assert", - "//sxt/base/log:log", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - ], - test_deps = [ - ":reference_transcript", - "//sxt/base/test:unit_test", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - ], -) - -sxt_cc_component( - name = "reduction_gpu", - impl_deps = [ - "//sxt/algorithm/base:identity_mapper", - "//sxt/algorithm/reduction:kernel_fit", - "//sxt/algorithm/reduction:thread_reduction", - "//sxt/base/device:memory_utility", - "//sxt/base/device:stream", - "//sxt/base/error:assert", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/execution/device:synchronization", - "//sxt/execution/kernel:kernel_dims", - "//sxt/execution/kernel:launch", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:async_device_resource", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:accumulator", - "//sxt/scalar25/type:element", - ], - test_deps = [ - "//sxt/base/device:property", - "//sxt/base/device:stream", - "//sxt/base/test:unit_test", - "//sxt/execution/async:future", - "//sxt/execution/schedule:scheduler", - "//sxt/memory/resource:managed_device_resource", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], -) - -sxt_cc_component( - name = "polynomial_mapper", - test_deps = [ - "//sxt/algorithm/base:mapper", - "//sxt/base/test:unit_test", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - ":polynomial_utility", - "//sxt/base/macro:cuda_callable", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], -) +# sxt_cc_component( +# name = "mle_utility", +# impl_deps = [ +# "//sxt/base/container:span_utility", +# "//sxt/base/device:memory_utility", +# "//sxt/base/device:property", +# "//sxt/base/num:divide_up", +# "//sxt/base/num:ceil_log2", +# "//sxt/memory/management:managed_array", +# "//sxt/scalar25/type:element", +# ], +# test_deps = [ +# "//sxt/base/device:stream", +# "//sxt/base/device:synchronization", +# "//sxt/base/test:unit_test", +# "//sxt/memory/management:managed_array", +# "//sxt/memory/resource:managed_device_resource", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# "//sxt/base/container:span", +# "//sxt/memory/management:managed_array_fwd", +# ], +# ) +# +# sxt_cc_component( +# name = "sum_gpu", +# impl_deps = [ +# ":constant", +# ":device_cache", +# ":mle_utility", +# ":polynomial_mapper", +# ":reduction_gpu", +# "//sxt/algorithm/reduction:kernel_fit", +# "//sxt/algorithm/reduction:thread_reduction", +# "//sxt/base/device:memory_utility", +# "//sxt/base/device:stream", +# "//sxt/base/device:state", +# "//sxt/base/iterator:split", +# "//sxt/base/num:ceil_log2", +# "//sxt/base/num:constexpr_switch", +# "//sxt/execution/async:coroutine", +# "//sxt/execution/async:future", +# "//sxt/execution/device:for_each", +# "//sxt/execution/kernel:kernel_dims", +# "//sxt/memory/resource:device_resource", +# "//sxt/memory/resource:async_device_resource", +# "//sxt/scalar25/operation:add", +# "//sxt/scalar25/operation:mul", +# "//sxt/scalar25/type:element", +# ], +# test_deps = [ +# ":device_cache", +# "//sxt/base/iterator:split", +# "//sxt/base/test:unit_test", +# "//sxt/execution/async:future", +# "//sxt/execution/schedule:scheduler", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# "//sxt/base/container:span", +# "//sxt/execution/async:future_fwd", +# ], +# ) +# +# sxt_cc_component( +# name = "sumcheck_transcript", +# with_test = False, +# deps = [ +# "//sxt/base/container:span", +# ], +# ) +# +# sxt_cc_component( +# name = "reference_transcript", +# impl_deps = [ +# "//sxt/scalar25/type:element", +# "//sxt/proof/transcript:transcript_utility", +# ], +# test_deps = [ +# "//sxt/base/test:unit_test", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# ":sumcheck_transcript", +# "//sxt/proof/transcript", +# ], +# ) +# +# sxt_cc_component( +# name = "sumcheck_random", +# impl_deps = [ +# "//sxt/base/error:assert", +# "//sxt/base/num:fast_random_number_generator", +# "//sxt/scalar25/random:element", +# "//sxt/scalar25/type:element", +# ], +# with_test = False, +# deps = [ +# ":constant", +# ], +# ) +# +# sxt_cc_component( +# name = "workspace", +# with_test = False, +# ) +# +# sxt_cc_component( +# name = "driver", +# with_test = False, +# deps = [ +# ":workspace", +# "//sxt/base/container:span", +# "//sxt/execution/async:future_fwd", +# ], +# ) +# +# sxt_cc_component( +# name = "driver_test", +# impl_deps = [ +# ":driver", +# ":workspace", +# "//sxt/base/test:unit_test", +# "//sxt/execution/async:future", +# "//sxt/execution/schedule:scheduler", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# with_test = False, +# ) +# +# sxt_cc_component( +# name = "cpu_driver", +# impl_deps = [ +# ":polynomial_utility", +# "//sxt/base/container:stack_array", +# "//sxt/base/error:panic", +# "//sxt/base/num:ceil_log2", +# "//sxt/execution/async:future", +# "//sxt/memory/management:managed_array", +# "//sxt/scalar25/operation:mul", +# "//sxt/scalar25/operation:sub", +# "//sxt/scalar25/operation:muladd", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# test_deps = [ +# ":driver_test", +# "//sxt/base/test:unit_test", +# ], +# deps = [ +# ":driver", +# ":workspace", +# ], +# ) +# +# sxt_cc_component( +# name = "gpu_driver", +# impl_deps = [ +# ":constant", +# ":polynomial_mapper", +# ":sum_gpu", +# "//sxt/algorithm/iteration:for_each", +# "//sxt/algorithm/reduction:reduction", +# "//sxt/base/container:stack_array", +# "//sxt/base/device:stream", +# "//sxt/base/device:memory_utility", +# "//sxt/base/error:panic", +# "//sxt/base/num:ceil_log2", +# "//sxt/base/num:constexpr_switch", +# "//sxt/execution/async:coroutine", +# "//sxt/execution/async:future", +# "//sxt/execution/device:synchronization", +# "//sxt/memory/management:managed_array", +# "//sxt/memory/resource:async_device_resource", +# "//sxt/memory/resource:device_resource", +# "//sxt/scalar25/operation:mul", +# "//sxt/scalar25/operation:sub", +# "//sxt/scalar25/operation:muladd", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# test_deps = [ +# ":driver_test", +# "//sxt/base/test:unit_test", +# ], +# deps = [ +# ":driver", +# ], +# ) +# +# sxt_cc_component( +# name = "polynomial_utility", +# impl_deps = [ +# "//sxt/scalar25/operation:add", +# "//sxt/scalar25/operation:mul", +# "//sxt/scalar25/operation:sub", +# "//sxt/scalar25/operation:muladd", +# "//sxt/scalar25/operation:neg", +# "//sxt/scalar25/type:element", +# ], +# test_deps = [ +# "//sxt/base/test:unit_test", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# "//sxt/base/container:span", +# ], +# ) +# +# sxt_cc_component( +# name = "chunked_gpu_driver", +# impl_deps = [ +# ":device_cache", +# ":fold_gpu", +# ":gpu_driver", +# ":mle_utility", +# ":sum_gpu", +# "//sxt/algorithm/iteration:transform", +# "//sxt/base/error:assert", +# "//sxt/base/num:ceil_log2", +# "//sxt/execution/async:coroutine", +# "//sxt/execution/async:future", +# "//sxt/memory/management:managed_array", +# "//sxt/scalar25/operation:sub", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# test_deps = [ +# ":driver_test", +# "//sxt/base/test:unit_test", +# ], +# deps = [ +# ":driver", +# ], +# ) +# +# sxt_cc_component( +# name = "proof_computation", +# impl_deps = [ +# ":driver", +# ":sumcheck_transcript", +# "//sxt/execution/async:future", +# "//sxt/base/error:assert", +# "//sxt/base/num:ceil_log2", +# "//sxt/execution/async:coroutine", +# "//sxt/scalar25/type:element", +# ], +# test_deps = [ +# ":chunked_gpu_driver", +# ":cpu_driver", +# ":gpu_driver", +# ":mle_utility", +# ":reference_transcript", +# ":sumcheck_random", +# ":verification", +# "//sxt/base/test:unit_test", +# "//sxt/execution/async:future", +# "//sxt/execution/schedule:scheduler", +# "//sxt/proof/transcript", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# ], +# deps = [ +# "//sxt/base/container:span", +# "//sxt/execution/async:future_fwd", +# ], +# ) +# +# sxt_cc_component( +# name = "verification", +# impl_deps = [ +# ":polynomial_utility", +# ":sumcheck_transcript", +# "//sxt/base/error:assert", +# "//sxt/base/log:log", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# ], +# test_deps = [ +# ":reference_transcript", +# "//sxt/base/test:unit_test", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# "//sxt/base/container:span", +# ], +# ) +# +# sxt_cc_component( +# name = "reduction_gpu", +# impl_deps = [ +# "//sxt/algorithm/base:identity_mapper", +# "//sxt/algorithm/reduction:kernel_fit", +# "//sxt/algorithm/reduction:thread_reduction", +# "//sxt/base/device:memory_utility", +# "//sxt/base/device:stream", +# "//sxt/base/error:assert", +# "//sxt/execution/async:coroutine", +# "//sxt/execution/async:future", +# "//sxt/execution/device:synchronization", +# "//sxt/execution/kernel:kernel_dims", +# "//sxt/execution/kernel:launch", +# "//sxt/memory/management:managed_array", +# "//sxt/memory/resource:async_device_resource", +# "//sxt/scalar25/operation:add", +# "//sxt/scalar25/operation:accumulator", +# "//sxt/scalar25/type:element", +# ], +# test_deps = [ +# "//sxt/base/device:property", +# "//sxt/base/device:stream", +# "//sxt/base/test:unit_test", +# "//sxt/execution/async:future", +# "//sxt/execution/schedule:scheduler", +# "//sxt/memory/resource:managed_device_resource", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# "//sxt/base/container:span", +# "//sxt/execution/async:future_fwd", +# ], +# ) +# +# sxt_cc_component( +# name = "polynomial_mapper", +# test_deps = [ +# "//sxt/algorithm/base:mapper", +# "//sxt/base/test:unit_test", +# "//sxt/scalar25/operation:overload", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# deps = [ +# ":polynomial_utility", +# "//sxt/base/macro:cuda_callable", +# "//sxt/scalar25/operation:add", +# "//sxt/scalar25/operation:mul", +# "//sxt/scalar25/operation:muladd", +# "//sxt/scalar25/type:element", +# "//sxt/scalar25/type:literal", +# ], +# ) diff --git a/sxt/proof/sumcheck/constant.cc b/sxt/proof/sumcheck/constant.cc deleted file mode 100644 index 396197800..000000000 --- a/sxt/proof/sumcheck/constant.cc +++ /dev/null @@ -1,17 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/constant.h" diff --git a/sxt/proof/sumcheck/constant.h b/sxt/proof/sumcheck/constant.h deleted file mode 100644 index aa70ad9a0..000000000 --- a/sxt/proof/sumcheck/constant.h +++ /dev/null @@ -1,26 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// max_degree_v -//-------------------------------------------------------------------------------------------------- -// the maximum degree of the round polynomial -// used in sumcheck -constexpr unsigned max_degree_v = 5u; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/device_cache.cc b/sxt/proof/sumcheck/device_cache.cc deleted file mode 100644 index 2daaadc4b..000000000 --- a/sxt/proof/sumcheck/device_cache.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/device_cache.h" - -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/state.h" -#include "sxt/base/device/stream.h" -#include "sxt/memory/resource/device_resource.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// make_device_copy -//-------------------------------------------------------------------------------------------------- -static std::unique_ptr -make_device_copy(basct::cspan> product_table, - basct::cspan product_terms, basdv::stream& stream) noexcept { - device_cache_data res{ - .product_table{product_table.size(), memr::get_device_resource()}, - .product_terms{product_terms.size(), memr::get_device_resource()}, - }; - basdv::async_copy_host_to_device(res.product_table, product_table, stream); - basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); - return std::make_unique(std::move(res)); -} - -//-------------------------------------------------------------------------------------------------- -// constructor -//-------------------------------------------------------------------------------------------------- -device_cache::device_cache(basct::cspan> product_table, - basct::cspan product_terms) noexcept - : product_table_{product_table}, product_terms_{product_terms} {} - -//-------------------------------------------------------------------------------------------------- -// lookup -//-------------------------------------------------------------------------------------------------- -void device_cache::lookup(basct::cspan>& product_table, - basct::cspan& product_terms, basdv::stream& stream) noexcept { - auto& ptr = data_[basdv::get_device()]; - if (ptr == nullptr) { - ptr = make_device_copy(product_table_, product_terms_, stream); - } - product_table = ptr->product_table; - product_terms = ptr->product_terms; -} - -//-------------------------------------------------------------------------------------------------- -// clear -//-------------------------------------------------------------------------------------------------- -std::unique_ptr device_cache::clear() noexcept { - auto res{std::move(data_[basdv::get_device()])}; - for (auto& ptr : data_) { - ptr.reset(); - } - return res; -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/device_cache.h b/sxt/proof/sumcheck/device_cache.h deleted file mode 100644 index c61038392..000000000 --- a/sxt/proof/sumcheck/device_cache.h +++ /dev/null @@ -1,58 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -#include "sxt/base/container/span.h" -#include "sxt/base/device/device_map.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::basdv { -class stream; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// device_cache_data -//-------------------------------------------------------------------------------------------------- -struct device_cache_data { - memmg::managed_array> product_table; - memmg::managed_array product_terms; -}; - -//-------------------------------------------------------------------------------------------------- -// device_cache -//-------------------------------------------------------------------------------------------------- -class device_cache { -public: - device_cache(basct::cspan> product_table, - basct::cspan product_terms) noexcept; - - void lookup(basct::cspan>& product_table, - basct::cspan& product_terms, basdv::stream& stream) noexcept; - - std::unique_ptr clear() noexcept; - -private: - basct::cspan> product_table_; - basct::cspan product_terms_; - basdv::device_map> data_; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/device_cache.t.cc b/sxt/proof/sumcheck/device_cache.t.cc deleted file mode 100644 index 008440faa..000000000 --- a/sxt/proof/sumcheck/device_cache.t.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/device_cache.h" - -#include - -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/device/synchronization.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we can cache device values that don't change as a proof is computed") { - std::vector> product_table; - std::vector product_terms; - - basdv::stream stream; - - basct::cspan> product_table_dev; - basct::cspan product_terms_dev; - - SECTION("we can access values from device memory") { - product_table = {{0x123_s25, 0}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - cache.lookup(product_table_dev, product_terms_dev, stream); - - std::vector> product_table_p(product_table.size()); - basdv::async_copy_device_to_host(product_table_p, product_table_dev, stream); - - std::vector product_terms_p(product_terms.size()); - basdv::async_copy_device_to_host(product_terms_p, product_terms_dev, stream); - - basdv::synchronize_stream(stream); - REQUIRE(product_table_p == product_table); - REQUIRE(product_terms_p == product_terms); - } - - SECTION("we can clear the device cache") { - product_table = {{0x123_s25, 0}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - cache.lookup(product_table_dev, product_terms_dev, stream); - - std::vector> product_table_p(product_table.size()); - basdv::async_copy_device_to_host(product_table_p, product_table_dev, stream); - - std::vector product_terms_p(product_terms.size()); - basdv::async_copy_device_to_host(product_terms_p, product_terms_dev, stream); - - auto data = cache.clear(); - basdv::async_copy_device_to_host(product_table_p, data->product_table, stream); - basdv::async_copy_device_to_host(product_terms_p, data->product_terms, stream); - basdv::synchronize_stream(stream); - REQUIRE(product_table_p == product_table); - REQUIRE(product_terms_p == product_terms); - } -} diff --git a/sxt/proof/sumcheck/fold_gpu.cc b/sxt/proof/sumcheck/fold_gpu.cc deleted file mode 100644 index 5e55e226e..000000000 --- a/sxt/proof/sumcheck/fold_gpu.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/fold_gpu.h" - -#include - -#include "sxt/algorithm/iteration/kernel_fit.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/property.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/iterator/split.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/for_each.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/execution/kernel/kernel_dims.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// fold_kernel -//-------------------------------------------------------------------------------------------------- -static __global__ void fold_kernel(s25t::element* __restrict__ mles, unsigned np, unsigned split, - s25t::element r, s25t::element one_m_r) noexcept { - auto thread_index = threadIdx.x; - auto block_index = blockIdx.x; - auto block_size = blockDim.x; - auto k = basn::divide_up(split, gridDim.x * block_size) * block_size; - auto block_first = block_index * k; - assert(block_first < split && "every block should be active"); - auto m = umin(block_first + k, split); - - // adjust mles - mles += np * blockIdx.y; - - // fold - auto index = block_first + thread_index; - for (; index < m; index += block_size) { - auto x = mles[index]; - s25o::mul(x, x, one_m_r); - auto index_p = split + index; - if (index_p < np) { - s25o::muladd(x, mles[index_p], r, x); - } - mles[index] = x; - } -} - -//-------------------------------------------------------------------------------------------------- -// fold_impl -//-------------------------------------------------------------------------------------------------- -static xena::future<> fold_impl(basct::span mles_p, basct::cspan mles, - unsigned n, unsigned mid, unsigned a, unsigned b, - const s25t::element& r, const s25t::element one_m_r) noexcept { - auto num_mles = mles.size() / n; - auto split = b - a; - - // copy MLEs to device - basdv::stream stream; - memr::async_device_resource resource{stream}; - memmg::managed_array mles_dev{&resource}; - copy_partial_mles(mles_dev, stream, mles, n, a, b); - - // fold - auto np = mles_dev.size() / num_mles; - auto dims = algi::fit_iteration_kernel(split); - fold_kernel<<(dims.block_size), 0, - stream>>>(mles_dev.data(), np, split, r, one_m_r); - - // copy results back - copy_folded_mles(mles_p, stream, mles_dev, mid, a, b); - - co_await xendv::await_stream(stream); -} - -//-------------------------------------------------------------------------------------------------- -// fold_gpu -//-------------------------------------------------------------------------------------------------- -xena::future<> fold_gpu(basct::span mles_p, - const basit::split_options& split_options, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept { - using s25t::operator""_s25; - auto num_mles = mles.size() / n; - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - SXT_DEBUG_ASSERT( - // clang-format off - n > 1 && mles.size() == num_mles * n - // clang-format on - ); - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - // split - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, split_options); - - // fold - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { - co_await fold_impl(mles_p, mles, n, mid, rng.a(), rng.b(), r, one_m_r); - }); -} - -xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept { - using s25t::operator""_s25; - auto num_mles = mles.size() / n; - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - SXT_DEBUG_ASSERT( - // clang-format off - n > 1 && mles.size() == num_mles * n - // clang-format on - ); - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - // split - basit::split_options split_options{ - .min_chunk_size = 1024u * 128u, - .max_chunk_size = 1024u * 256u, - .split_factor = basdv::get_num_devices(), - }; - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, split_options); - - // fold - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { - co_await fold_impl(mles_p, mles, n, mid, rng.a(), rng.b(), r, one_m_r); - }); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.h b/sxt/proof/sumcheck/fold_gpu.h deleted file mode 100644 index 611d37522..000000000 --- a/sxt/proof/sumcheck/fold_gpu.h +++ /dev/null @@ -1,40 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" - -namespace sxt::s25t { -class element; -} - -namespace sxt::basit { -struct split_options; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// fold_gpu -//-------------------------------------------------------------------------------------------------- -xena::future<> fold_gpu(basct::span mles_p, - const basit::split_options& split_options, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept; - -xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.t.cc b/sxt/proof/sumcheck/fold_gpu.t.cc deleted file mode 100644 index c6d6ae693..000000000 --- a/sxt/proof/sumcheck/fold_gpu.t.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/fold_gpu.h" - -#include - -#include "sxt/base/iterator/split.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/schedule/scheduler.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we can fold scalars using the gpu") { - std::vector mles, mles_p, expected; - - auto r = 0xabc123_s25; - auto one_m_r = 0x1_s25 - r; - - SECTION("we can fold a single mle with n=2") { - mles = {0x1_s25, 0x2_s25}; - mles_p.resize(1); - auto fut = fold_gpu(mles_p, mles, 2, r); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - expected = { - one_m_r * mles[0] + r * mles[1], - }; - REQUIRE(mles_p == expected); - } - - SECTION("we can fold a single mle with n=3") { - mles = {0x123_s25, 0x456_s25, 0x789_s25}; - mles_p.resize(2); - auto fut = fold_gpu(mles_p, mles, 3, r); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - expected = { - one_m_r * mles[0] + r * mles[2], - one_m_r * mles[1], - }; - REQUIRE(mles_p == expected); - } - - SECTION("we can split folds") { - basit::split_options split_options{ - .min_chunk_size = 1, - .max_chunk_size = 1, - .split_factor = 2, - }; - mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x101112_s25}; - mles_p.resize(2); - auto fut = fold_gpu(mles_p, split_options, mles, 4, r); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - expected = { - one_m_r * mles[0] + r * mles[2], - one_m_r * mles[1] + r * mles[3], - }; - REQUIRE(mles_p == expected); - } -} From 606e38a53b1fb470859dc55e19ef31fc15d0f381 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 19:57:19 -0800 Subject: [PATCH 79/83] rework sumcheck --- sxt/proof/sumcheck/BUILD | 196 ---------------- sxt/proof/sumcheck/chunked_gpu_driver.cc | 154 ------------- sxt/proof/sumcheck/chunked_gpu_driver.h | 42 ---- sxt/proof/sumcheck/chunked_gpu_driver.t.cc | 35 --- sxt/proof/sumcheck/cpu_driver.cc | 161 ------------- sxt/proof/sumcheck/cpu_driver.h | 37 --- sxt/proof/sumcheck/cpu_driver.t.cc | 30 --- sxt/proof/sumcheck/driver.cc | 17 -- sxt/proof/sumcheck/driver.h | 47 ---- sxt/proof/sumcheck/driver_test.cc | 133 ----------- sxt/proof/sumcheck/driver_test.h | 26 --- sxt/proof/sumcheck/gpu_driver.cc | 200 ---------------- sxt/proof/sumcheck/gpu_driver.h | 43 ---- sxt/proof/sumcheck/gpu_driver.t.cc | 28 --- sxt/proof/sumcheck/mle_utility.cc | 98 -------- sxt/proof/sumcheck/mle_utility.h | 48 ---- sxt/proof/sumcheck/mle_utility.t.cc | 94 -------- sxt/proof/sumcheck/reference_transcript.cc | 46 ---- sxt/proof/sumcheck/reference_transcript.h | 37 --- sxt/proof/sumcheck/reference_transcript.t.cc | 58 ----- sxt/proof/sumcheck/sum_gpu.cc | 231 ------------------- sxt/proof/sumcheck/sum_gpu.h | 58 ----- sxt/proof/sumcheck/sum_gpu.t.cc | 138 ----------- sxt/proof/sumcheck/sumcheck_random.cc | 77 ------- sxt/proof/sumcheck/sumcheck_random.h | 57 ----- sxt/proof/sumcheck/sumcheck_transcript.cc | 17 -- sxt/proof/sumcheck/sumcheck_transcript.h | 40 ---- sxt/proof/sumcheck/workspace.cc | 17 -- sxt/proof/sumcheck/workspace.h | 27 --- sxt/proof/sumcheck2/BUILD | 3 + sxt/proof/sumcheck2/polynomial_mapper.t.cc | 44 +++- 31 files changed, 46 insertions(+), 2193 deletions(-) delete mode 100644 sxt/proof/sumcheck/chunked_gpu_driver.cc delete mode 100644 sxt/proof/sumcheck/chunked_gpu_driver.h delete mode 100644 sxt/proof/sumcheck/chunked_gpu_driver.t.cc delete mode 100644 sxt/proof/sumcheck/cpu_driver.cc delete mode 100644 sxt/proof/sumcheck/cpu_driver.h delete mode 100644 sxt/proof/sumcheck/cpu_driver.t.cc delete mode 100644 sxt/proof/sumcheck/driver.cc delete mode 100644 sxt/proof/sumcheck/driver.h delete mode 100644 sxt/proof/sumcheck/driver_test.cc delete mode 100644 sxt/proof/sumcheck/driver_test.h delete mode 100644 sxt/proof/sumcheck/gpu_driver.cc delete mode 100644 sxt/proof/sumcheck/gpu_driver.h delete mode 100644 sxt/proof/sumcheck/gpu_driver.t.cc delete mode 100644 sxt/proof/sumcheck/mle_utility.cc delete mode 100644 sxt/proof/sumcheck/mle_utility.h delete mode 100644 sxt/proof/sumcheck/mle_utility.t.cc delete mode 100644 sxt/proof/sumcheck/reference_transcript.cc delete mode 100644 sxt/proof/sumcheck/reference_transcript.h delete mode 100644 sxt/proof/sumcheck/reference_transcript.t.cc delete mode 100644 sxt/proof/sumcheck/sum_gpu.cc delete mode 100644 sxt/proof/sumcheck/sum_gpu.h delete mode 100644 sxt/proof/sumcheck/sum_gpu.t.cc delete mode 100644 sxt/proof/sumcheck/sumcheck_random.cc delete mode 100644 sxt/proof/sumcheck/sumcheck_random.h delete mode 100644 sxt/proof/sumcheck/sumcheck_transcript.cc delete mode 100644 sxt/proof/sumcheck/sumcheck_transcript.h delete mode 100644 sxt/proof/sumcheck/workspace.cc delete mode 100644 sxt/proof/sumcheck/workspace.h diff --git a/sxt/proof/sumcheck/BUILD b/sxt/proof/sumcheck/BUILD index 99f9cf019..a80a29ed6 100644 --- a/sxt/proof/sumcheck/BUILD +++ b/sxt/proof/sumcheck/BUILD @@ -1,200 +1,4 @@ # sxt_cc_component( -# name = "mle_utility", -# impl_deps = [ -# "//sxt/base/container:span_utility", -# "//sxt/base/device:memory_utility", -# "//sxt/base/device:property", -# "//sxt/base/num:divide_up", -# "//sxt/base/num:ceil_log2", -# "//sxt/memory/management:managed_array", -# "//sxt/scalar25/type:element", -# ], -# test_deps = [ -# "//sxt/base/device:stream", -# "//sxt/base/device:synchronization", -# "//sxt/base/test:unit_test", -# "//sxt/memory/management:managed_array", -# "//sxt/memory/resource:managed_device_resource", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# "//sxt/base/container:span", -# "//sxt/memory/management:managed_array_fwd", -# ], -# ) -# -# sxt_cc_component( -# name = "sum_gpu", -# impl_deps = [ -# ":constant", -# ":device_cache", -# ":mle_utility", -# ":polynomial_mapper", -# ":reduction_gpu", -# "//sxt/algorithm/reduction:kernel_fit", -# "//sxt/algorithm/reduction:thread_reduction", -# "//sxt/base/device:memory_utility", -# "//sxt/base/device:stream", -# "//sxt/base/device:state", -# "//sxt/base/iterator:split", -# "//sxt/base/num:ceil_log2", -# "//sxt/base/num:constexpr_switch", -# "//sxt/execution/async:coroutine", -# "//sxt/execution/async:future", -# "//sxt/execution/device:for_each", -# "//sxt/execution/kernel:kernel_dims", -# "//sxt/memory/resource:device_resource", -# "//sxt/memory/resource:async_device_resource", -# "//sxt/scalar25/operation:add", -# "//sxt/scalar25/operation:mul", -# "//sxt/scalar25/type:element", -# ], -# test_deps = [ -# ":device_cache", -# "//sxt/base/iterator:split", -# "//sxt/base/test:unit_test", -# "//sxt/execution/async:future", -# "//sxt/execution/schedule:scheduler", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# "//sxt/base/container:span", -# "//sxt/execution/async:future_fwd", -# ], -# ) -# -# sxt_cc_component( -# name = "sumcheck_transcript", -# with_test = False, -# deps = [ -# "//sxt/base/container:span", -# ], -# ) -# -# sxt_cc_component( -# name = "reference_transcript", -# impl_deps = [ -# "//sxt/scalar25/type:element", -# "//sxt/proof/transcript:transcript_utility", -# ], -# test_deps = [ -# "//sxt/base/test:unit_test", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# ":sumcheck_transcript", -# "//sxt/proof/transcript", -# ], -# ) -# -# sxt_cc_component( -# name = "sumcheck_random", -# impl_deps = [ -# "//sxt/base/error:assert", -# "//sxt/base/num:fast_random_number_generator", -# "//sxt/scalar25/random:element", -# "//sxt/scalar25/type:element", -# ], -# with_test = False, -# deps = [ -# ":constant", -# ], -# ) -# -# sxt_cc_component( -# name = "workspace", -# with_test = False, -# ) -# -# sxt_cc_component( -# name = "driver", -# with_test = False, -# deps = [ -# ":workspace", -# "//sxt/base/container:span", -# "//sxt/execution/async:future_fwd", -# ], -# ) -# -# sxt_cc_component( -# name = "driver_test", -# impl_deps = [ -# ":driver", -# ":workspace", -# "//sxt/base/test:unit_test", -# "//sxt/execution/async:future", -# "//sxt/execution/schedule:scheduler", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# with_test = False, -# ) -# -# sxt_cc_component( -# name = "cpu_driver", -# impl_deps = [ -# ":polynomial_utility", -# "//sxt/base/container:stack_array", -# "//sxt/base/error:panic", -# "//sxt/base/num:ceil_log2", -# "//sxt/execution/async:future", -# "//sxt/memory/management:managed_array", -# "//sxt/scalar25/operation:mul", -# "//sxt/scalar25/operation:sub", -# "//sxt/scalar25/operation:muladd", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# test_deps = [ -# ":driver_test", -# "//sxt/base/test:unit_test", -# ], -# deps = [ -# ":driver", -# ":workspace", -# ], -# ) -# -# sxt_cc_component( -# name = "gpu_driver", -# impl_deps = [ -# ":constant", -# ":polynomial_mapper", -# ":sum_gpu", -# "//sxt/algorithm/iteration:for_each", -# "//sxt/algorithm/reduction:reduction", -# "//sxt/base/container:stack_array", -# "//sxt/base/device:stream", -# "//sxt/base/device:memory_utility", -# "//sxt/base/error:panic", -# "//sxt/base/num:ceil_log2", -# "//sxt/base/num:constexpr_switch", -# "//sxt/execution/async:coroutine", -# "//sxt/execution/async:future", -# "//sxt/execution/device:synchronization", -# "//sxt/memory/management:managed_array", -# "//sxt/memory/resource:async_device_resource", -# "//sxt/memory/resource:device_resource", -# "//sxt/scalar25/operation:mul", -# "//sxt/scalar25/operation:sub", -# "//sxt/scalar25/operation:muladd", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# test_deps = [ -# ":driver_test", -# "//sxt/base/test:unit_test", -# ], -# deps = [ -# ":driver", -# ], -# ) -# -# sxt_cc_component( # name = "polynomial_utility", # impl_deps = [ # "//sxt/scalar25/operation:add", diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.cc b/sxt/proof/sumcheck/chunked_gpu_driver.cc deleted file mode 100644 index 1358c4cc7..000000000 --- a/sxt/proof/sumcheck/chunked_gpu_driver.cc +++ /dev/null @@ -1,154 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/chunked_gpu_driver.h" - -#include - -#include "sxt/algorithm/iteration/transform.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck/device_cache.h" -#include "sxt/proof/sumcheck/fold_gpu.h" -#include "sxt/proof/sumcheck/gpu_driver.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/proof/sumcheck/sum_gpu.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// chunked_gpu_workspace -//-------------------------------------------------------------------------------------------------- -namespace { -struct chunked_gpu_workspace final : public workspace { - std::unique_ptr single_gpu_workspace; - - device_cache cache; - memmg::managed_array mles_data; - basct::cspan mles; - unsigned n; - unsigned num_variables; - - chunked_gpu_workspace(basct::cspan> product_table, - basct::cspan product_terms) noexcept - : cache{product_table, product_terms} {} -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// try_make_single_gpu_workspace -//-------------------------------------------------------------------------------------------------- -static xena::future<> try_make_single_gpu_workspace(chunked_gpu_workspace& work, - double no_chunk_cutoff) noexcept { - auto gpu_memory_fraction = get_gpu_memory_fraction(work.mles); - if (gpu_memory_fraction > no_chunk_cutoff) { - co_return; - } - - // construct single gpu workspace - auto cache_data = work.cache.clear(); - gpu_driver drv; - work.single_gpu_workspace = - co_await drv.make_workspace(work.mles, std::move(cache_data->product_table), - std::move(cache_data->product_terms), work.n); - - // free data we no longer need - work.mles_data.reset(); - work.mles = {}; -} - -//-------------------------------------------------------------------------------------------------- -// constructor -//-------------------------------------------------------------------------------------------------- -chunked_gpu_driver::chunked_gpu_driver(double no_chunk_cutoff) noexcept - : no_chunk_cutoff_{no_chunk_cutoff} { - SXT_RELEASE_ASSERT(0 <= no_chunk_cutoff_ && no_chunk_cutoff_ <= 1.0); -} - -//-------------------------------------------------------------------------------------------------- -// make_workspace -//-------------------------------------------------------------------------------------------------- -xena::future> -chunked_gpu_driver::make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, - unsigned n) const noexcept { - auto res = std::make_unique(product_table, product_terms); - res->mles = mles; - res->n = n; - res->num_variables = std::max(basn::ceil_log2(n), 1); - auto gpu_memory_fraction = get_gpu_memory_fraction(mles); - if (gpu_memory_fraction <= no_chunk_cutoff_) { - gpu_driver drv; - res->single_gpu_workspace = co_await drv.make_workspace(mles, product_table, product_terms, n); - } - co_return std::unique_ptr(std::move(res)); -} - -//-------------------------------------------------------------------------------------------------- -// sum -//-------------------------------------------------------------------------------------------------- -xena::future<> chunked_gpu_driver::sum(basct::span polynomial, - workspace& ws) const noexcept { - auto& work = static_cast(ws); - if (work.single_gpu_workspace) { - gpu_driver drv; - co_return co_await drv.sum(polynomial, *work.single_gpu_workspace); - } - co_await sum_gpu(polynomial, work.cache, work.mles, work.n); -} - -//-------------------------------------------------------------------------------------------------- -// fold -//-------------------------------------------------------------------------------------------------- -xena::future<> chunked_gpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { - using s25t::operator""_s25; - auto& work = static_cast(ws); - if (work.single_gpu_workspace) { - gpu_driver drv; - co_return co_await drv.fold(*work.single_gpu_workspace, r); - } - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); - - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - // fold - memmg::managed_array mles_p(num_mles * mid); - co_await fold_gpu(mles_p, work.mles, n, r); - - // update - work.n = mid; - --work.num_variables; - work.mles_data = std::move(mles_p); - work.mles = work.mles_data; - - // check if we should fall back to single gpu workspace - co_await try_make_single_gpu_workspace(work, no_chunk_cutoff_); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.h b/sxt/proof/sumcheck/chunked_gpu_driver.h deleted file mode 100644 index ed93decbe..000000000 --- a/sxt/proof/sumcheck/chunked_gpu_driver.h +++ /dev/null @@ -1,42 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/proof/sumcheck/driver.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// chunked_gpu_driver -//-------------------------------------------------------------------------------------------------- -class chunked_gpu_driver final : public driver { -public: - explicit chunked_gpu_driver(double no_chunk_cutoff = 0.5) noexcept; - - // driver - xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept override; - - xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; - - xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; - -private: - double no_chunk_cutoff_; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc deleted file mode 100644 index 8fb4a3aac..000000000 --- a/sxt/proof/sumcheck/chunked_gpu_driver.t.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/chunked_gpu_driver.h" - -#include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck/driver_test.h" - -using namespace sxt; -using namespace sxt::prfsk; - -TEST_CASE("we can perform the primitive operations for sumcheck proofs") { - SECTION("we handle the case when only chunking is used") { - chunked_gpu_driver drv{0.0}; - exercise_driver(drv); - } - - SECTION("we handle the case when the chunked driver falls back to the single gpu driver") { - chunked_gpu_driver drv{1.0}; - exercise_driver(drv); - } -} diff --git a/sxt/proof/sumcheck/cpu_driver.cc b/sxt/proof/sumcheck/cpu_driver.cc deleted file mode 100644 index cc8d69192..000000000 --- a/sxt/proof/sumcheck/cpu_driver.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/cpu_driver.h" - -#include - -#include "sxt/base/container/stack_array.h" -#include "sxt/base/error/panic.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/future.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// cpu_workspace -//-------------------------------------------------------------------------------------------------- -namespace { -struct cpu_workspace final : public workspace { - memmg::managed_array mles; - basct::cspan> product_table; - basct::cspan product_terms; - unsigned n; - unsigned num_variables; -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// make_workspace -//-------------------------------------------------------------------------------------------------- -xena::future> -cpu_driver::make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept { - auto res = std::make_unique(); - res->mles = memmg::managed_array{mles.begin(), mles.end()}; - res->product_table = product_table; - res->product_terms = product_terms; - res->n = n; - res->num_variables = std::max(basn::ceil_log2(n), 1); - return xena::make_ready_future>(std::move(res)); -} - -//-------------------------------------------------------------------------------------------------- -// sum -//-------------------------------------------------------------------------------------------------- -xena::future<> cpu_driver::sum(basct::span polynomial, - workspace& ws) const noexcept { - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - SXT_RELEASE_ASSERT(work.n >= mid); - - auto mles = work.mles.data(); - auto product_table = work.product_table; - auto product_terms = work.product_terms; - - for (auto& val : polynomial) { - val = {}; - } - - // expand paired terms - auto n1 = work.n - mid; - for (unsigned i = 0; i < n1; ++i) { - unsigned term_first = 0; - for (auto [mult, num_terms] : product_table) { - SXT_RELEASE_ASSERT(num_terms < polynomial.size()); - auto terms = product_terms.subspan(term_first, num_terms); - SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element); - expand_products(p, mles + i, n, mid, terms); - for (unsigned term_index = 0; term_index < p.size(); ++term_index) { - s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); - } - term_first += num_terms; - } - } - - // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) - for (unsigned i = n1; i < mid; ++i) { - unsigned term_first = 0; - for (auto [mult, num_terms] : product_table) { - auto terms = product_terms.subspan(term_first, num_terms); - SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element); - partial_expand_products(p, mles + i, n, terms); - for (unsigned term_index = 0; term_index < p.size(); ++term_index) { - s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); - } - term_first += num_terms; - } - } - - return xena::make_ready_future(); -} - -//-------------------------------------------------------------------------------------------------- -// fold -//-------------------------------------------------------------------------------------------------- -xena::future<> cpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { - using s25t::operator""_s25; - - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); - - auto mles = work.mles.data(); - memmg::managed_array mles_p(num_mles * mid); - - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - auto n1 = work.n - mid; - for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { - auto data = mles + n * mle_index; - auto data_p = mles_p.data() + mid * mle_index; - - // fold paired terms - for (unsigned i = 0; i < n1; ++i) { - auto val = data[i]; - s25o::mul(val, val, one_m_r); - s25o::muladd(val, r, data[mid + i], val); - data_p[i] = val; - } - - // fold terms paired with zero - for (unsigned i = n1; i < mid; ++i) { - auto val = data[i]; - s25o::mul(val, val, one_m_r); - data_p[i] = val; - } - } - - work.n = mid; - --work.num_variables; - work.mles = std::move(mles_p); - return xena::make_ready_future(); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.h b/sxt/proof/sumcheck/cpu_driver.h deleted file mode 100644 index b017396d1..000000000 --- a/sxt/proof/sumcheck/cpu_driver.h +++ /dev/null @@ -1,37 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/proof/sumcheck/driver.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// cpu_driver -//-------------------------------------------------------------------------------------------------- -class cpu_driver final : public driver { -public: - // driver - xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept override; - - xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; - - xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.t.cc b/sxt/proof/sumcheck/cpu_driver.t.cc deleted file mode 100644 index 42ae1e7c4..000000000 --- a/sxt/proof/sumcheck/cpu_driver.t.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/cpu_driver.h" - -#include - -#include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck/driver_test.h" - -using namespace sxt; -using namespace sxt::prfsk; - -TEST_CASE("we can perform the primitive operations for sumcheck proofs") { - cpu_driver drv; - exercise_driver(drv); -} diff --git a/sxt/proof/sumcheck/driver.cc b/sxt/proof/sumcheck/driver.cc deleted file mode 100644 index 6e46927f6..000000000 --- a/sxt/proof/sumcheck/driver.cc +++ /dev/null @@ -1,17 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/driver.h" diff --git a/sxt/proof/sumcheck/driver.h b/sxt/proof/sumcheck/driver.h deleted file mode 100644 index 422bb88ad..000000000 --- a/sxt/proof/sumcheck/driver.h +++ /dev/null @@ -1,47 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" -#include "sxt/proof/sumcheck/workspace.h" - -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// driver -//-------------------------------------------------------------------------------------------------- -class driver { -public: - virtual ~driver() noexcept = default; - - virtual xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept = 0; - - virtual xena::future<> sum(basct::span polynomial, - workspace& ws) const noexcept = 0; - - virtual xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept = 0; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/driver_test.cc b/sxt/proof/sumcheck/driver_test.cc deleted file mode 100644 index 97fa80df1..000000000 --- a/sxt/proof/sumcheck/driver_test.cc +++ /dev/null @@ -1,133 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/driver_test.h" - -#include - -#include "sxt/base/test/unit_test.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/schedule/scheduler.h" -#include "sxt/proof/sumcheck/driver.h" -#include "sxt/proof/sumcheck/workspace.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -using s25t::operator""_s25; - -//-------------------------------------------------------------------------------------------------- -// exercise_driver -//-------------------------------------------------------------------------------------------------- -void exercise_driver(const driver& drv) { - std::vector mles; - std::vector> product_table{ - {0x1_s25, 1}, - }; - std::vector product_terms = {0}; - - std::vector p(2); - - SECTION("we can sum a polynomial with n = 1") { - std::vector mles = {0x123_s25}; - auto ws = drv.make_workspace(mles, product_table, product_terms, 1); - xens::get_scheduler().run(); - auto fut = drv.sum(p, *ws.value()); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == -mles[0]); - } - - SECTION("we can sum a polynomial with a non-unity multiplier") { - std::vector mles = {0x123_s25}; - product_table[0].first = 0x2_s25; - auto ws = drv.make_workspace(mles, product_table, product_terms, 1); - xens::get_scheduler().run(); - auto fut = drv.sum(p, *ws.value()); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == 0x2_s25 * mles[0]); - REQUIRE(p[1] == -0x2_s25 * mles[0]); - } - - SECTION("we can sum a polynomial with n = 2") { - std::vector mles = {0x123_s25, 0x456_s25}; - auto ws = drv.make_workspace(mles, product_table, product_terms, 2); - xens::get_scheduler().run(); - auto fut = drv.sum(p, *ws.value()); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == mles[1] - mles[0]); - } - - SECTION("we can sum a polynomial with two MLEs added together") { - std::vector mles = {0x123_s25, 0x456_s25}; - std::vector> product_table{ - {0x1_s25, 1}, - {0x1_s25, 1}, - }; - std::vector product_terms = {0, 1}; - - auto ws = drv.make_workspace(mles, product_table, product_terms, 1); - xens::get_scheduler().run(); - auto fut = drv.sum(p, *ws.value()); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0] + mles[1]); - REQUIRE(p[1] == -mles[0] - mles[1]); - } - - SECTION("we can sum a polynomial with two MLEs multiplied together") { - std::vector mles = {0x123_s25, 0x456_s25}; - std::vector> product_table{ - {0x1_s25, 2}, - }; - std::vector product_terms = {0, 1}; - p.resize(3); - - auto ws = drv.make_workspace(mles, product_table, product_terms, 1); - xens::get_scheduler().run(); - auto fut = drv.sum(p, *ws.value()); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0] * mles[1]); - REQUIRE(p[1] == -mles[0] * mles[1] - mles[1] * mles[0]); - REQUIRE(p[2] == mles[0] * mles[1]); - } - - SECTION("we can fold mles") { - std::vector mles = {0x123_s25, 0x456_s25, 0x789_s25}; - auto ws = drv.make_workspace(mles, product_table, product_terms, 3); - xens::get_scheduler().run(); - auto r = 0xabc123_s25; - auto fut = drv.fold(*ws.value(), r); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - fut = drv.sum(p, *ws.value()); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - - mles[0] = (0x1_s25 - r) * mles[0] + r * mles[2]; - mles[1] = (0x1_s25 - r) * mles[1]; - - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == mles[1] - mles[0]); - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/driver_test.h b/sxt/proof/sumcheck/driver_test.h deleted file mode 100644 index 69c0a53ea..000000000 --- a/sxt/proof/sumcheck/driver_test.h +++ /dev/null @@ -1,26 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace sxt::prfsk { -class driver; - -//-------------------------------------------------------------------------------------------------- -// exercise_driver -//-------------------------------------------------------------------------------------------------- -void exercise_driver(const driver& drv); -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/gpu_driver.cc b/sxt/proof/sumcheck/gpu_driver.cc deleted file mode 100644 index de996b487..000000000 --- a/sxt/proof/sumcheck/gpu_driver.cc +++ /dev/null @@ -1,200 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/gpu_driver.h" - -#include -#include - -#include "sxt/algorithm/iteration/for_each.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/device/synchronization.h" -#include "sxt/base/error/panic.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/constexpr_switch.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/constant.h" -#include "sxt/proof/sumcheck/sum_gpu.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// gpu_workspace -//-------------------------------------------------------------------------------------------------- -namespace { -struct gpu_workspace final : public workspace { - memmg::managed_array mles; - memmg::managed_array> product_table; - memmg::managed_array product_terms; - unsigned n; - unsigned num_variables; - - gpu_workspace() noexcept - : mles{memr::get_device_resource()}, product_table{memr::get_device_resource()}, - product_terms{memr::get_device_resource()} {} -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// make_workspace -//-------------------------------------------------------------------------------------------------- -xena::future> gpu_driver::make_workspace( - basct::cspan mles, - memmg::managed_array>&& product_table_dev, - memmg::managed_array&& product_terms_dev, unsigned n) const noexcept { - auto ws = std::make_unique(); - - // dimensions - ws->n = n; - ws->num_variables = std::max(basn::ceil_log2(n), 1); - - // mles - ws->mles = memmg::managed_array{ - mles.size(), - memr::get_device_resource(), - }; - basdv::stream mle_stream; - basdv::async_copy_host_to_device(ws->mles, mles, mle_stream); - - // product_table - ws->product_table = std::move(product_table_dev); - - // product_terms - ws->product_terms = std::move(product_terms_dev); - - // await - co_await xendv::await_stream(mle_stream); - co_return ws; -} - -xena::future> -gpu_driver::make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept { - auto ws = std::make_unique(); - - // dimensions - ws->n = n; - ws->num_variables = std::max(basn::ceil_log2(n), 1); - - // mles - ws->mles = memmg::managed_array{ - mles.size(), - memr::get_device_resource(), - }; - basdv::stream mle_stream; - basdv::async_copy_host_to_device(ws->mles, mles, mle_stream); - - // product_table - ws->product_table = memmg::managed_array>{ - product_table.size(), - memr::get_device_resource(), - }; - basdv::stream product_table_stream; - basdv::async_copy_host_to_device(ws->product_table, product_table, product_table_stream); - - // product_terms - ws->product_terms = memmg::managed_array{ - product_terms.size(), - memr::get_device_resource(), - }; - basdv::stream product_terms_stream; - basdv::async_copy_host_to_device(ws->product_terms, product_terms, product_terms_stream); - - // await - co_await xendv::await_stream(mle_stream); - co_await xendv::await_stream(product_table_stream); - co_await xendv::await_stream(product_terms_stream); - co_return ws; -} - -//-------------------------------------------------------------------------------------------------- -// sum -//-------------------------------------------------------------------------------------------------- -xena::future<> gpu_driver::sum(basct::span polynomial, - workspace& ws) const noexcept { - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && - polynomial.size() - 1u <= max_degree_v - // clang-format on - ); - co_await sum_gpu(polynomial, work.mles, work.product_table, work.product_terms, n); -} - -//-------------------------------------------------------------------------------------------------- -// fold -//-------------------------------------------------------------------------------------------------- -xena::future<> gpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { - using s25t::operator""_s25; - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); - - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - memmg::managed_array mles_p{num_mles * mid, memr::get_device_resource()}; - - auto f = [ - // clang-format off - mles_p = mles_p.data(), - mles = work.mles.data(), - n = n, - num_mles = num_mles, - r = r, - one_m_r = one_m_r - // clang-format on - ] __device__ - __host__(unsigned mid, unsigned i) noexcept { - for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { - auto val = mles[i + mle_index * n]; - s25o::mul(val, val, one_m_r); - if (mid + i < n) { - s25o::muladd(val, r, mles[mid + i + mle_index * n], val); - } - mles_p[i + mle_index * mid] = val; - } - }; - auto fut = algi::for_each(f, mid); - - // complete - co_await std::move(fut); - - work.n = mid; - --work.num_variables; - work.mles = std::move(mles_p); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/gpu_driver.h b/sxt/proof/sumcheck/gpu_driver.h deleted file mode 100644 index 891acb583..000000000 --- a/sxt/proof/sumcheck/gpu_driver.h +++ /dev/null @@ -1,43 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/memory/management/managed_array_fwd.h" -#include "sxt/proof/sumcheck/driver.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// gpu_driver -//-------------------------------------------------------------------------------------------------- -class gpu_driver final : public driver { -public: - xena::future> - make_workspace(basct::cspan mles, - memmg::managed_array>&& product_table_dev, - memmg::managed_array&& product_terms_dev, unsigned n) const noexcept; - - // driver - xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept override; - - xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; - - xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/gpu_driver.t.cc b/sxt/proof/sumcheck/gpu_driver.t.cc deleted file mode 100644 index cf0acbba0..000000000 --- a/sxt/proof/sumcheck/gpu_driver.t.cc +++ /dev/null @@ -1,28 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/gpu_driver.h" - -#include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck/driver_test.h" - -using namespace sxt; -using namespace sxt::prfsk; - -TEST_CASE("we can perform the primitive operations for sumcheck proofs") { - gpu_driver drv; - exercise_driver(drv); -} diff --git a/sxt/proof/sumcheck/mle_utility.cc b/sxt/proof/sumcheck/mle_utility.cc deleted file mode 100644 index 8bcf33f36..000000000 --- a/sxt/proof/sumcheck/mle_utility.cc +++ /dev/null @@ -1,98 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/mle_utility.h" - -#include -#include - -#include "sxt/base/container/span_utility.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/property.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/divide_up.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// copy_partial_mles -//-------------------------------------------------------------------------------------------------- -void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, - basct::cspan mles, unsigned n, unsigned a, - unsigned b) noexcept { - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - auto num_mles = mles.size() / n; - auto part1_size = b - a; - SXT_DEBUG_ASSERT(a < b && b <= n); - auto ap = std::min(mid + a, n); - auto bp = std::min(mid + b, n); - auto part2_size = bp - ap; - - // resize array - auto partial_length = part1_size + part2_size; - partial_mles.resize(partial_length * num_mles); - - // copy data - for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { - // first part - auto src = mles.subspan(n * mle_index + a, part1_size); - auto dst = basct::subspan(partial_mles, partial_length * mle_index, part1_size); - basdv::async_copy_host_to_device(dst, src, stream); - - // second part - src = mles.subspan(n * mle_index + ap, part2_size); - dst = basct::subspan(partial_mles, partial_length * mle_index + part1_size, part2_size); - if (!src.empty()) { - basdv::async_copy_host_to_device(dst, src, stream); - } - } -} - -//-------------------------------------------------------------------------------------------------- -// copy_folded_mles -//-------------------------------------------------------------------------------------------------- -void copy_folded_mles(basct::span host_mles, basdv::stream& stream, - basct::cspan device_mles, unsigned np, unsigned a, - unsigned b) noexcept { - auto num_mles = host_mles.size() / np; - auto slice_n = device_mles.size() / num_mles; - auto slice_np = b - a; - SXT_DEBUG_ASSERT( - // clang-format off - host_mles.size() == num_mles * np && - device_mles.size() == num_mles * slice_n && - b <= np - // clang-format on - ); - for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { - auto src = device_mles.subspan(mle_index * slice_n, slice_np); - auto dst = host_mles.subspan(mle_index * np + a, slice_np); - basdv::async_copy_device_to_host(dst, src, stream); - } -} - -//-------------------------------------------------------------------------------------------------- -// get_gpu_memory_fraction -//-------------------------------------------------------------------------------------------------- -double get_gpu_memory_fraction(basct::cspan mles) noexcept { - auto total_memory = static_cast(basdv::get_total_device_memory()); - return static_cast(mles.size() * sizeof(s25t::element)) / total_memory; -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/mle_utility.h b/sxt/proof/sumcheck/mle_utility.h deleted file mode 100644 index 9f3285a89..000000000 --- a/sxt/proof/sumcheck/mle_utility.h +++ /dev/null @@ -1,48 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/base/container/span.h" -#include "sxt/memory/management/managed_array_fwd.h" - -namespace sxt::basdv { -class stream; -} -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// copy_partial_mles -//-------------------------------------------------------------------------------------------------- -void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, - basct::cspan mles, unsigned n, unsigned a, - unsigned b) noexcept; - -//-------------------------------------------------------------------------------------------------- -// copy_folded_mles -//-------------------------------------------------------------------------------------------------- -void copy_folded_mles(basct::span host_mles, basdv::stream& stream, - basct::cspan device_mles, unsigned np, unsigned a, - unsigned b) noexcept; - -//-------------------------------------------------------------------------------------------------- -// get_gpu_memory_fraction -//-------------------------------------------------------------------------------------------------- -double get_gpu_memory_fraction(basct::cspan mles) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/mle_utility.t.cc b/sxt/proof/sumcheck/mle_utility.t.cc deleted file mode 100644 index f930943dd..000000000 --- a/sxt/proof/sumcheck/mle_utility.t.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/mle_utility.h" - -#include - -#include "sxt/base/device/stream.h" -#include "sxt/base/device/synchronization.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/managed_device_resource.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we can copy a slice of mles to device memory") { - std::pmr::vector mles{memr::get_managed_device_resource()}; - memmg::managed_array partial_mles{memr::get_managed_device_resource()}; - - basdv::stream stream; - - SECTION("we can copy an mle with a single element") { - mles = {0x123_s25}; - copy_partial_mles(partial_mles, stream, mles, 1, 0, 1); - basdv::synchronize_stream(stream); - memmg::managed_array expected = {0x123_s25}; - REQUIRE(partial_mles == expected); - } - - SECTION("we can copy a slice of MLEs") { - mles = {0x1_s25, 0x2_s25, 0x3_s25, 0x4_s25, 0x5_s25, 0x6_s25}; - copy_partial_mles(partial_mles, stream, mles, 3, 0, 1); - basdv::synchronize_stream(stream); - memmg::managed_array expected = {0x1_s25, 0x3_s25, 0x4_s25, 0x6_s25}; - REQUIRE(partial_mles == expected); - } -} - -TEST_CASE("we can copy partially folded MLEs to the host") { - std::pmr::vector device_mles{memr::get_managed_device_resource()}; - std::vector host_mles; - - basdv::stream stream; - - SECTION("we can copy a single element") { - device_mles = {0x123_s25}; - host_mles.resize(1); - copy_folded_mles(host_mles, stream, device_mles, 1, 0, 1); - basdv::synchronize_stream(stream); - std::vector expected = {0x123_s25}; - REQUIRE(host_mles == expected); - } - - SECTION("we can copy partially folded MLEs") { - device_mles = {0x123_s25, 0x456_s25}; - host_mles.resize(4); - copy_folded_mles(host_mles, stream, device_mles, 2, 0, 1); - basdv::synchronize_stream(stream); - std::vector expected = {0x123_s25, 0x0_s25, 0x456_s25, 0x0_s25}; - REQUIRE(host_mles == expected); - } -} - -TEST_CASE("we can query the fraction of device memory taken by MLEs") { - std::vector mles; - - SECTION("we handle the zero case") { REQUIRE(get_gpu_memory_fraction(mles) == 0.0); } - - SECTION("the fractions doubles if the length of mles doubles") { - mles.resize(1); - auto f1 = get_gpu_memory_fraction(mles); - REQUIRE(f1 > 0); - mles.resize(2); - auto f2 = get_gpu_memory_fraction(mles); - REQUIRE(f2 == Catch::Approx(2 * f1)); - } -} diff --git a/sxt/proof/sumcheck/reference_transcript.cc b/sxt/proof/sumcheck/reference_transcript.cc deleted file mode 100644 index fcbeb6b5b..000000000 --- a/sxt/proof/sumcheck/reference_transcript.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/reference_transcript.h" - -#include "sxt/proof/transcript/transcript_utility.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// constructor -//-------------------------------------------------------------------------------------------------- -reference_transcript::reference_transcript(prft::transcript& transcript) noexcept - : transcript_{transcript} {} - -//-------------------------------------------------------------------------------------------------- -// init -//-------------------------------------------------------------------------------------------------- -void reference_transcript::init(size_t num_variables, size_t round_degree) noexcept { - prft::set_domain(transcript_, "sumcheck proof v1"); - prft::append_value(transcript_, "n", num_variables); - prft::append_value(transcript_, "k", round_degree); -} - -//-------------------------------------------------------------------------------------------------- -// round_challenge -//-------------------------------------------------------------------------------------------------- -void reference_transcript::round_challenge(s25t::element& r, - basct::cspan polynomial) noexcept { - prft::append_values(transcript_, "P", polynomial); - prft::challenge_value(r, transcript_, "R"); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reference_transcript.h b/sxt/proof/sumcheck/reference_transcript.h deleted file mode 100644 index 3ee41ed0b..000000000 --- a/sxt/proof/sumcheck/reference_transcript.h +++ /dev/null @@ -1,37 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/proof/sumcheck/sumcheck_transcript.h" -#include "sxt/proof/transcript/transcript.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// reference_transcript -//-------------------------------------------------------------------------------------------------- -class reference_transcript final : public sumcheck_transcript { -public: - explicit reference_transcript(prft::transcript& transcript) noexcept; - - void init(size_t num_variables, size_t round_degree) noexcept override; - - void round_challenge(s25t::element& r, basct::cspan polynomial) noexcept override; - -private: - prft::transcript& transcript_; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reference_transcript.t.cc b/sxt/proof/sumcheck/reference_transcript.t.cc deleted file mode 100644 index 890312c51..000000000 --- a/sxt/proof/sumcheck/reference_transcript.t.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/reference_transcript.h" - -#include - -#include "sxt/base/test/unit_test.h" -#include "sxt/proof/transcript/transcript.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we provide an implementation of sumcheck transcript") { - prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - std::vector p = {0x123_s25}; - s25t::element r, rp; - - SECTION("we don't draw the same challenge from a transcript") { - transcript.round_challenge(r, p); - transcript.round_challenge(rp, p); - REQUIRE(r != rp); - - prft::transcript base_transcript_p{"abc"}; - reference_transcript transcript_p{base_transcript_p}; - p[0] = 0x456_s25; - transcript_p.round_challenge(rp, p); - REQUIRE(r != rp); - } - - SECTION("init_transcript produces different results based on parameters") { - transcript.init(1, 2); - transcript.round_challenge(r, p); - - prft::transcript base_transcript_p{"abc"}; - reference_transcript transcript_p{base_transcript_p}; - transcript.init(2, 1); - transcript.round_challenge(rp, p); - - REQUIRE(r != rp); - } -} diff --git a/sxt/proof/sumcheck/sum_gpu.cc b/sxt/proof/sumcheck/sum_gpu.cc deleted file mode 100644 index 8066fd249..000000000 --- a/sxt/proof/sumcheck/sum_gpu.cc +++ /dev/null @@ -1,231 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/sum_gpu.h" - -#include - -#include "sxt/algorithm/reduction/kernel_fit.h" -#include "sxt/algorithm/reduction/thread_reduction.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/state.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/iterator/split.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/constexpr_switch.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/for_each.h" -#include "sxt/execution/kernel/kernel_dims.h" -#include "sxt/execution/kernel/launch.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/constant.h" -#include "sxt/proof/sumcheck/device_cache.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/proof/sumcheck/polynomial_mapper.h" -#include "sxt/proof/sumcheck/reduction_gpu.h" -#include "sxt/scalar25/operation/add.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// polynomial_reducer -//-------------------------------------------------------------------------------------------------- -namespace { -template struct polynomial_reducer { - using value_type = std::array; - - CUDA_CALLABLE static void accumulate_inplace(value_type& res, const value_type& e) noexcept { - for (unsigned i = 0; i < res.size(); ++i) { - s25o::add(res[i], res[i], e[i]); - } - } -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// partial_sum_kernel_impl -//-------------------------------------------------------------------------------------------------- -template -__device__ static void partial_sum_kernel_impl(s25t::element* __restrict__ shared_data, - const s25t::element* __restrict__ mles, - const unsigned* __restrict__ product_terms, - unsigned split, unsigned n) noexcept { - using Mapper = polynomial_mapper; - using Reducer = polynomial_reducer; - using T = Mapper::value_type; - Mapper mapper{ - .mles = mles, - .product_terms = product_terms, - .split = split, - .n = n, - }; - auto index = blockIdx.x * (BlockSize * 2) + threadIdx.x; - auto step = BlockSize * 2 * gridDim.x; - algr::thread_reduce(reinterpret_cast(shared_data), mapper, split, step, - threadIdx.x, index); -} - -//-------------------------------------------------------------------------------------------------- -// partial_sum_kernel -//-------------------------------------------------------------------------------------------------- -template -__global__ static void -partial_sum_kernel(s25t::element* __restrict__ out, const s25t::element* __restrict__ mles, - const std::pair* __restrict__ product_table, - const unsigned* __restrict__ product_terms, unsigned num_coefficients, - unsigned split, unsigned n) noexcept { - auto product_index = blockIdx.y; - auto num_terms = product_table[product_index].second; - auto thread_index = threadIdx.x; - auto output_step = gridDim.x * gridDim.y; - - // shared data for reduction - __shared__ s25t::element shared_data[2 * BlockSize * (max_degree_v + 1u)]; - - // adjust pointers - out += blockIdx.x; - out += gridDim.x * product_index; - for (unsigned i = 0; i < product_index; ++i) { - product_terms += product_table[i].second; - } - - // sum - basn::constexpr_switch<1, max_degree_v + 1u>( - num_terms, [&](std::integral_constant) noexcept { - partial_sum_kernel_impl(shared_data, mles, product_terms, split, n); - }); - - // write out result - auto mult = product_table[product_index].first; - for (unsigned i = thread_index; i < num_coefficients; i += BlockSize) { - auto output_index = output_step * i; - if (i < num_terms + 1u) { - s25o::mul(out[output_index], mult, shared_data[i]); - } else { - out[output_index] = s25t::element{}; - } - } -} - -//-------------------------------------------------------------------------------------------------- -// partial_sum -//-------------------------------------------------------------------------------------------------- -static xena::future<> partial_sum(basct::span p, basdv::stream& stream, - basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned split, - unsigned n) noexcept { - auto num_coefficients = p.size(); - auto num_products = product_table.size(); - auto dims = algr::fit_reduction_kernel(split); - memr::async_device_resource resource{stream}; - - // partials - memmg::managed_array partials{num_coefficients * dims.num_blocks * num_products, - &resource}; - xenk::launch_kernel(dims.block_size, [&]( - std::integral_constant) noexcept { - partial_sum_kernel<<>>( - partials.data(), mles.data(), product_table.data(), product_terms.data(), num_coefficients, - split, n); - }); - - // reduce partials - co_await reduce_sums(p, stream, partials); -} - -//-------------------------------------------------------------------------------------------------- -// sum_gpu -//-------------------------------------------------------------------------------------------------- -xena::future<> sum_gpu(basct::span p, device_cache& cache, - const basit::split_options& options, basct::cspan mles, - unsigned n) noexcept { - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - auto num_mles = mles.size() / n; - auto num_coefficients = p.size(); - - // split - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, options); - - // sum - size_t counter = 0; - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { - basdv::stream stream; - memr::async_device_resource resource{stream}; - - // copy partial mles to device - memmg::managed_array partial_mles{&resource}; - copy_partial_mles(partial_mles, stream, mles, n, rng.a(), rng.b()); - auto split = rng.b() - rng.a(); - auto np = partial_mles.size() / num_mles; - - // lookup problem descriptor - basct::cspan> product_table; - basct::cspan product_terms; - cache.lookup(product_table, product_terms, stream); - - // compute - memmg::managed_array partial_p(num_coefficients); - co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, split, - np); - - // fill in the result - if (counter == 0) { - for (unsigned i = 0; i < num_coefficients; ++i) { - p[i] = partial_p[i]; - } - } else { - for (unsigned i = 0; i < num_coefficients; ++i) { - s25o::add(p[i], p[i], partial_p[i]); - } - } - ++counter; - }); -} - -xena::future<> sum_gpu(basct::span p, device_cache& cache, - basct::cspan mles, unsigned n) noexcept { - basit::split_options options{ - .min_chunk_size = 100'000u, - .max_chunk_size = 200'000u, - .split_factor = basdv::get_num_devices(), - }; - co_await sum_gpu(p, cache, options, mles, n); -} - -xena::future<> sum_gpu(basct::span p, basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept { - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - SXT_DEBUG_ASSERT( - // clang-format off - basdv::is_host_pointer(p.data()) && - basdv::is_active_device_pointer(mles.data()) && - basdv::is_active_device_pointer(product_table.data()) && - basdv::is_active_device_pointer(product_terms.data()) - // clang-format on - ); - basdv::stream stream; - co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sum_gpu.h b/sxt/proof/sumcheck/sum_gpu.h deleted file mode 100644 index a693c7da3..000000000 --- a/sxt/proof/sumcheck/sum_gpu.h +++ /dev/null @@ -1,58 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "sxt/base/container/span.h" -#include "sxt/base/device/property.h" -#include "sxt/execution/async/future_fwd.h" - -namespace sxt::s25t { -class element; -} - -namespace sxt::basit { -struct split_options; -} - -namespace sxt::prfsk { -class device_cache; - -//-------------------------------------------------------------------------------------------------- -// sum_options -//-------------------------------------------------------------------------------------------------- -struct sum_options { - unsigned min_chunk_size = 100'000u; - unsigned max_chunk_size = 250'000u; - unsigned split_factor = unsigned(basdv::get_num_devices()); -}; - -//-------------------------------------------------------------------------------------------------- -// sum_gpu -//-------------------------------------------------------------------------------------------------- -xena::future<> sum_gpu(basct::span p, device_cache& cache, - const basit::split_options& options, basct::cspan mles, - unsigned n) noexcept; - -xena::future<> sum_gpu(basct::span p, device_cache& cache, - basct::cspan mles, unsigned n) noexcept; - -xena::future<> sum_gpu(basct::span p, basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sum_gpu.t.cc b/sxt/proof/sumcheck/sum_gpu.t.cc deleted file mode 100644 index 35807fd55..000000000 --- a/sxt/proof/sumcheck/sum_gpu.t.cc +++ /dev/null @@ -1,138 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/sum_gpu.h" - -#include - -#include "sxt/base/iterator/split.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/schedule/scheduler.h" -#include "sxt/proof/sumcheck/device_cache.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we can sum MLEs") { - std::vector> product_table; - std::vector product_terms; - std::vector mles; - std::vector p(2); - - SECTION("we can sum an MLE with a single term and n=1") { - product_table = {{0x1_s25, 1}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - mles = {0x123_s25}; - auto fut = sum_gpu(p, cache, mles, 1); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == -mles[0]); - } - - SECTION("we can sum an MLE with a single term, n=1, and a non-unity multiplier") { - product_table = {{0x2_s25, 1}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - mles = {0x123_s25}; - auto fut = sum_gpu(p, cache, mles, 1); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == product_table[0].first * mles[0]); - REQUIRE(p[1] == -product_table[0].first * mles[0]); - } - - SECTION("we can sum an MLE with a single term and n=2") { - product_table = {{0x1_s25, 1}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - mles = {0x123_s25, 0x456_s25}; - auto fut = sum_gpu(p, cache, mles, 2); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == mles[1] - mles[0]); - } - - SECTION("we can sum an MLE with multiple terms and n=1") { - p.resize(3); - product_table = {{0x1_s25, 2}}; - product_terms = {0, 1}; - device_cache cache{product_table, product_terms}; - mles = {0x123_s25, 0x456_s25}; - auto fut = sum_gpu(p, cache, mles, 1); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0] * mles[1]); - REQUIRE(p[1] == -mles[0] * mles[1] - mles[1] * mles[0]); - REQUIRE(p[2] == mles[0] * mles[1]); - } - - SECTION("we can sum multiple mles") { - product_table = { - {0x1_s25, 1}, - {0x1_s25, 1}, - }; - product_terms = {0, 1}; - device_cache cache{product_table, product_terms}; - mles = {0x123_s25, 0x456_s25}; - auto fut = sum_gpu(p, cache, mles, 1); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0] + mles[1]); - REQUIRE(p[1] == -mles[0] - mles[1]); - } - - SECTION("we can chunk sums with n=4") { - product_table = {{0x1_s25, 1}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x91011_s25}; - basit::split_options options{ - .min_chunk_size = 1, - .max_chunk_size = 1, - .split_factor = 2, - }; - auto fut = sum_gpu(p, cache, options, mles, 4); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0] + mles[1]); - REQUIRE(p[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); - } - - SECTION("we can chunk sums with n=4") { - product_table = {{0x1_s25, 1}}; - product_terms = {0}; - device_cache cache{product_table, product_terms}; - mles = {0x2_s25, 0x4_s25, 0x7_s25, 0x9_s25}; - basit::split_options options{ - .min_chunk_size = 16, - .max_chunk_size = 16, - .split_factor = 2, - }; - auto fut = sum_gpu(p, cache, options, mles, 4); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == mles[0] + mles[1]); - REQUIRE(p[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); - } -} diff --git a/sxt/proof/sumcheck/sumcheck_random.cc b/sxt/proof/sumcheck/sumcheck_random.cc deleted file mode 100644 index d3f164e3a..000000000 --- a/sxt/proof/sumcheck/sumcheck_random.cc +++ /dev/null @@ -1,77 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/sumcheck_random.h" - -#include - -#include "sxt/base/error/assert.h" -#include "sxt/base/num/fast_random_number_generator.h" -#include "sxt/scalar25/random/element.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// generate_random_sumcheck_problem -//-------------------------------------------------------------------------------------------------- -void generate_random_sumcheck_problem( - std::vector& mles, - std::vector>& product_table, - std::vector& product_terms, unsigned& n, basn::fast_random_number_generator& rng, - const random_sumcheck_descriptor& descriptor) noexcept { - std::mt19937 rng_p{rng()}; - - // n - SXT_RELEASE_ASSERT(descriptor.min_length <= descriptor.max_length); - std::uniform_int_distribution n_dist{descriptor.min_length, descriptor.max_length}; - n = n_dist(rng_p); - - // num_mles - SXT_RELEASE_ASSERT(descriptor.min_num_mles <= descriptor.max_num_mles); - std::uniform_int_distribution num_mles_dist{descriptor.min_num_mles, - descriptor.max_num_mles}; - auto num_mles = num_mles_dist(rng_p); - - // num_products - SXT_RELEASE_ASSERT(descriptor.min_num_products <= descriptor.max_num_products); - std::uniform_int_distribution num_products_dist{descriptor.min_num_products, - descriptor.max_num_products}; - auto num_products = num_products_dist(rng_p); - - // mles - mles.resize(n * num_mles); - s25rn::generate_random_elements(mles, rng); - - // product_table - unsigned num_terms = 0; - product_table.resize(num_products); - SXT_RELEASE_ASSERT(descriptor.min_product_length <= descriptor.max_product_length); - std::uniform_int_distribution product_length_dist{descriptor.min_product_length, - descriptor.max_product_length}; - for (auto& [s, len] : product_table) { - s25rn::generate_random_element(s, rng); - len = product_length_dist(rng_p); - num_terms += len; - } - - // product_terms - product_terms.resize(num_terms); - std::uniform_int_distribution mle_dist{0, num_mles - 1}; - for (auto& term : product_terms) { - term = mle_dist(rng_p); - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sumcheck_random.h b/sxt/proof/sumcheck/sumcheck_random.h deleted file mode 100644 index f58cd0f8e..000000000 --- a/sxt/proof/sumcheck/sumcheck_random.h +++ /dev/null @@ -1,57 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -#include "sxt/proof/sumcheck/constant.h" - -namespace sxt::s25t { -class element; -} -namespace sxt::basn { -class fast_random_number_generator; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// random_sumcheck_descriptor -//-------------------------------------------------------------------------------------------------- -struct random_sumcheck_descriptor { - unsigned min_length = 1; - unsigned max_length = 10; - - unsigned min_num_products = 1; - unsigned max_num_products = 5; - - unsigned min_product_length = 2; - unsigned max_product_length = max_degree_v; - - unsigned min_num_mles = 1; - unsigned max_num_mles = 5; -}; - -//-------------------------------------------------------------------------------------------------- -// generate_random_sumcheck_problem -//-------------------------------------------------------------------------------------------------- -void generate_random_sumcheck_problem( - std::vector& mles, - std::vector>& product_table, - std::vector& product_terms, unsigned& n, basn::fast_random_number_generator& rng, - const random_sumcheck_descriptor& descriptor) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sumcheck_transcript.cc b/sxt/proof/sumcheck/sumcheck_transcript.cc deleted file mode 100644 index 36b1ecd89..000000000 --- a/sxt/proof/sumcheck/sumcheck_transcript.cc +++ /dev/null @@ -1,17 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/sumcheck_transcript.h" diff --git a/sxt/proof/sumcheck/sumcheck_transcript.h b/sxt/proof/sumcheck/sumcheck_transcript.h deleted file mode 100644 index 2d690ded2..000000000 --- a/sxt/proof/sumcheck/sumcheck_transcript.h +++ /dev/null @@ -1,40 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "sxt/base/container/span.h" - -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// sumcheck_transcript -//-------------------------------------------------------------------------------------------------- -class sumcheck_transcript { -public: - virtual ~sumcheck_transcript() noexcept = default; - - virtual void init(size_t num_variables, size_t round_degree) noexcept = 0; - - virtual void round_challenge(s25t::element& r, - basct::cspan polynomial) noexcept = 0; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/workspace.cc b/sxt/proof/sumcheck/workspace.cc deleted file mode 100644 index d356b4af7..000000000 --- a/sxt/proof/sumcheck/workspace.cc +++ /dev/null @@ -1,17 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/workspace.h" diff --git a/sxt/proof/sumcheck/workspace.h b/sxt/proof/sumcheck/workspace.h deleted file mode 100644 index edacbd2b5..000000000 --- a/sxt/proof/sumcheck/workspace.h +++ /dev/null @@ -1,27 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// workspace -//-------------------------------------------------------------------------------------------------- -class workspace { -public: - virtual ~workspace() noexcept = default; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck2/BUILD index 41937f96d..47dfd00fe 100644 --- a/sxt/proof/sumcheck2/BUILD +++ b/sxt/proof/sumcheck2/BUILD @@ -192,6 +192,9 @@ sxt_cc_component( name = "polynomial_mapper", test_deps = [ "//sxt/base/test:unit_test", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", ], deps = [ ":polynomial_utility", diff --git a/sxt/proof/sumcheck2/polynomial_mapper.t.cc b/sxt/proof/sumcheck2/polynomial_mapper.t.cc index 57b837240..11d26d27e 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.t.cc +++ b/sxt/proof/sumcheck2/polynomial_mapper.t.cc @@ -16,6 +16,48 @@ */ #include "sxt/proof/sumcheck2/polynomial_mapper.h" +#include + #include "sxt/base/test/unit_test.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" +#include "sxt/scalar25/type/literal.h" + +using namespace sxt; +using namespace sxt::prfsk2; +using s25t::operator""_s25; + +using T = s25t::element; + +TEST_CASE("we can map indexes to expanded polynomials") { + std::vector mles; + std::vector product_terms; + + SECTION("we can map a single element mle") { + mles = {0x123_s25}; + product_terms = {0}; + polynomial_mapper<1, T> m{ + .mles = mles.data(), + .product_terms = product_terms.data(), + .split = 1, + .n = 1, + }; + auto p = m.map_index(0); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == -mles[0]); + } -TEST_CASE("todo") {} + SECTION("we can map an mle with two elements") { + mles = {0x123_s25, 0x456_s25}; + product_terms = {0}; + polynomial_mapper<1, T> m{ + .mles = mles.data(), + .product_terms = product_terms.data(), + .split = 1, + .n = 2, + }; + auto p = m.map_index(0); + REQUIRE(p[0] == mles[0]); + REQUIRE(p[1] == mles[1] - mles[0]); + } +} From 4f236ca5b220329423584da31e254a2eeef92c0f Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 19:59:37 -0800 Subject: [PATCH 80/83] rework sumcheck --- sxt/proof/sumcheck/BUILD | 157 -------------- sxt/proof/sumcheck/polynomial_mapper.cc | 17 -- sxt/proof/sumcheck/polynomial_mapper.h | 51 ----- sxt/proof/sumcheck/polynomial_mapper.t.cc | 61 ------ sxt/proof/sumcheck/polynomial_utility.cc | 137 ------------ sxt/proof/sumcheck/polynomial_utility.h | 55 ----- sxt/proof/sumcheck/polynomial_utility.t.cc | 131 ------------ sxt/proof/sumcheck/proof_computation.cc | 71 ------- sxt/proof/sumcheck/proof_computation.h | 44 ---- sxt/proof/sumcheck/proof_computation.t.cc | 232 --------------------- sxt/proof/sumcheck/reduction_gpu.cc | 114 ---------- sxt/proof/sumcheck/reduction_gpu.h | 35 ---- sxt/proof/sumcheck/reduction_gpu.t.cc | 67 ------ sxt/proof/sumcheck/verification.cc | 77 ------- sxt/proof/sumcheck/verification.h | 39 ---- sxt/proof/sumcheck/verification.t.cc | 129 ------------ 16 files changed, 1417 deletions(-) delete mode 100644 sxt/proof/sumcheck/BUILD delete mode 100644 sxt/proof/sumcheck/polynomial_mapper.cc delete mode 100644 sxt/proof/sumcheck/polynomial_mapper.h delete mode 100644 sxt/proof/sumcheck/polynomial_mapper.t.cc delete mode 100644 sxt/proof/sumcheck/polynomial_utility.cc delete mode 100644 sxt/proof/sumcheck/polynomial_utility.h delete mode 100644 sxt/proof/sumcheck/polynomial_utility.t.cc delete mode 100644 sxt/proof/sumcheck/proof_computation.cc delete mode 100644 sxt/proof/sumcheck/proof_computation.h delete mode 100644 sxt/proof/sumcheck/proof_computation.t.cc delete mode 100644 sxt/proof/sumcheck/reduction_gpu.cc delete mode 100644 sxt/proof/sumcheck/reduction_gpu.h delete mode 100644 sxt/proof/sumcheck/reduction_gpu.t.cc delete mode 100644 sxt/proof/sumcheck/verification.cc delete mode 100644 sxt/proof/sumcheck/verification.h delete mode 100644 sxt/proof/sumcheck/verification.t.cc diff --git a/sxt/proof/sumcheck/BUILD b/sxt/proof/sumcheck/BUILD deleted file mode 100644 index a80a29ed6..000000000 --- a/sxt/proof/sumcheck/BUILD +++ /dev/null @@ -1,157 +0,0 @@ -# sxt_cc_component( -# name = "polynomial_utility", -# impl_deps = [ -# "//sxt/scalar25/operation:add", -# "//sxt/scalar25/operation:mul", -# "//sxt/scalar25/operation:sub", -# "//sxt/scalar25/operation:muladd", -# "//sxt/scalar25/operation:neg", -# "//sxt/scalar25/type:element", -# ], -# test_deps = [ -# "//sxt/base/test:unit_test", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# "//sxt/base/container:span", -# ], -# ) -# -# sxt_cc_component( -# name = "chunked_gpu_driver", -# impl_deps = [ -# ":device_cache", -# ":fold_gpu", -# ":gpu_driver", -# ":mle_utility", -# ":sum_gpu", -# "//sxt/algorithm/iteration:transform", -# "//sxt/base/error:assert", -# "//sxt/base/num:ceil_log2", -# "//sxt/execution/async:coroutine", -# "//sxt/execution/async:future", -# "//sxt/memory/management:managed_array", -# "//sxt/scalar25/operation:sub", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# test_deps = [ -# ":driver_test", -# "//sxt/base/test:unit_test", -# ], -# deps = [ -# ":driver", -# ], -# ) -# -# sxt_cc_component( -# name = "proof_computation", -# impl_deps = [ -# ":driver", -# ":sumcheck_transcript", -# "//sxt/execution/async:future", -# "//sxt/base/error:assert", -# "//sxt/base/num:ceil_log2", -# "//sxt/execution/async:coroutine", -# "//sxt/scalar25/type:element", -# ], -# test_deps = [ -# ":chunked_gpu_driver", -# ":cpu_driver", -# ":gpu_driver", -# ":mle_utility", -# ":reference_transcript", -# ":sumcheck_random", -# ":verification", -# "//sxt/base/test:unit_test", -# "//sxt/execution/async:future", -# "//sxt/execution/schedule:scheduler", -# "//sxt/proof/transcript", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# ], -# deps = [ -# "//sxt/base/container:span", -# "//sxt/execution/async:future_fwd", -# ], -# ) -# -# sxt_cc_component( -# name = "verification", -# impl_deps = [ -# ":polynomial_utility", -# ":sumcheck_transcript", -# "//sxt/base/error:assert", -# "//sxt/base/log:log", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# ], -# test_deps = [ -# ":reference_transcript", -# "//sxt/base/test:unit_test", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# "//sxt/base/container:span", -# ], -# ) -# -# sxt_cc_component( -# name = "reduction_gpu", -# impl_deps = [ -# "//sxt/algorithm/base:identity_mapper", -# "//sxt/algorithm/reduction:kernel_fit", -# "//sxt/algorithm/reduction:thread_reduction", -# "//sxt/base/device:memory_utility", -# "//sxt/base/device:stream", -# "//sxt/base/error:assert", -# "//sxt/execution/async:coroutine", -# "//sxt/execution/async:future", -# "//sxt/execution/device:synchronization", -# "//sxt/execution/kernel:kernel_dims", -# "//sxt/execution/kernel:launch", -# "//sxt/memory/management:managed_array", -# "//sxt/memory/resource:async_device_resource", -# "//sxt/scalar25/operation:add", -# "//sxt/scalar25/operation:accumulator", -# "//sxt/scalar25/type:element", -# ], -# test_deps = [ -# "//sxt/base/device:property", -# "//sxt/base/device:stream", -# "//sxt/base/test:unit_test", -# "//sxt/execution/async:future", -# "//sxt/execution/schedule:scheduler", -# "//sxt/memory/resource:managed_device_resource", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# "//sxt/base/container:span", -# "//sxt/execution/async:future_fwd", -# ], -# ) -# -# sxt_cc_component( -# name = "polynomial_mapper", -# test_deps = [ -# "//sxt/algorithm/base:mapper", -# "//sxt/base/test:unit_test", -# "//sxt/scalar25/operation:overload", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# deps = [ -# ":polynomial_utility", -# "//sxt/base/macro:cuda_callable", -# "//sxt/scalar25/operation:add", -# "//sxt/scalar25/operation:mul", -# "//sxt/scalar25/operation:muladd", -# "//sxt/scalar25/type:element", -# "//sxt/scalar25/type:literal", -# ], -# ) diff --git a/sxt/proof/sumcheck/polynomial_mapper.cc b/sxt/proof/sumcheck/polynomial_mapper.cc deleted file mode 100644 index a46364a4e..000000000 --- a/sxt/proof/sumcheck/polynomial_mapper.cc +++ /dev/null @@ -1,17 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/polynomial_mapper.h" diff --git a/sxt/proof/sumcheck/polynomial_mapper.h b/sxt/proof/sumcheck/polynomial_mapper.h deleted file mode 100644 index a10dc935a..000000000 --- a/sxt/proof/sumcheck/polynomial_mapper.h +++ /dev/null @@ -1,51 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/base/macro/cuda_callable.h" -#include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// polynomial_mapper -//-------------------------------------------------------------------------------------------------- -template struct polynomial_mapper { - using value_type = std::array; - - CUDA_CALLABLE - value_type map_index(unsigned index) const noexcept { - value_type res; - this->map_index(res, index); - return res; - } - - CUDA_CALLABLE - void map_index(value_type& p, unsigned index) const noexcept { - if (index + split < n) { - expand_products(p, mles + index, n, split, {product_terms, Degree}); - } else { - partial_expand_products(p, mles + index, n, {product_terms, Degree}); - } - } - - const s25t::element* __restrict__ mles; - const unsigned* __restrict__ product_terms; - unsigned split; - unsigned n; -}; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_mapper.t.cc b/sxt/proof/sumcheck/polynomial_mapper.t.cc deleted file mode 100644 index 6c0d85827..000000000 --- a/sxt/proof/sumcheck/polynomial_mapper.t.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/polynomial_mapper.h" - -#include - -#include "sxt/base/test/unit_test.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we can map indexes to expanded polynomials") { - std::vector mles; - std::vector product_terms; - - SECTION("we can map a single element mle") { - mles = {0x123_s25}; - product_terms = {0}; - polynomial_mapper<1> m{ - .mles = mles.data(), - .product_terms = product_terms.data(), - .split = 1, - .n = 1, - }; - auto p = m.map_index(0); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == -mles[0]); - } - - SECTION("we can map an mle with two elements") { - mles = {0x123_s25, 0x456_s25}; - product_terms = {0}; - polynomial_mapper<1> m{ - .mles = mles.data(), - .product_terms = product_terms.data(), - .split = 1, - .n = 2, - }; - auto p = m.map_index(0); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == mles[1] - mles[0]); - } -} diff --git a/sxt/proof/sumcheck/polynomial_utility.cc b/sxt/proof/sumcheck/polynomial_utility.cc deleted file mode 100644 index 4ba3d9169..000000000 --- a/sxt/proof/sumcheck/polynomial_utility.cc +++ /dev/null @@ -1,137 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/polynomial_utility.h" - -#include - -#include "sxt/scalar25/operation/add.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/neg.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// sum_polynomial_01 -//-------------------------------------------------------------------------------------------------- -void sum_polynomial_01(s25t::element& e, basct::cspan polynomial) noexcept { - if (polynomial.empty()) { - e = s25t::element{}; - return; - } - e = polynomial[0]; - for (unsigned i = 0; i < polynomial.size(); ++i) { - s25o::add(e, e, polynomial[i]); - } -} - -//-------------------------------------------------------------------------------------------------- -// evaluate_polynomial -//-------------------------------------------------------------------------------------------------- -void evaluate_polynomial(s25t::element& e, basct::cspan polynomial, - const s25t::element& x) noexcept { - if (polynomial.empty()) { - e = s25t::element{}; - return; - } - auto i = polynomial.size(); - --i; - e = polynomial[i]; - while (i > 0) { - --i; - s25o::muladd(e, e, x, polynomial[i]); - } -} - -//-------------------------------------------------------------------------------------------------- -// expand_products -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void expand_products(basct::span p, const s25t::element* mles, unsigned n, - unsigned step, basct::cspan terms) noexcept { - auto num_terms = terms.size(); - assert( - // clang-format off - num_terms > 0 && - n > step && - p.size() == num_terms + 1u - // clang-format on - ); - s25t::element a, b; - auto mle_index = terms[0]; - a = *(mles + mle_index * n); - b = *(mles + mle_index * n + step); - s25o::sub(b, b, a); - p[0] = a; - p[1] = b; - - for (unsigned i = 1; i < num_terms; ++i) { - auto mle_index = terms[i]; - a = *(mles + mle_index * n); - b = *(mles + mle_index * n + step); - s25o::sub(b, b, a); - - auto c_prev = p[0]; - s25o::mul(p[0], c_prev, a); - for (unsigned pow = 1u; pow < i + 1u; ++pow) { - auto c = p[pow]; - s25o::mul(p[pow], c, a); - s25o::muladd(p[pow], c_prev, b, p[pow]); - c_prev = c; - } - s25o::mul(p[i + 1u], c_prev, b); - } -} - -//-------------------------------------------------------------------------------------------------- -// partial_expand_products -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void partial_expand_products(basct::span p, const s25t::element* mles, unsigned n, - basct::cspan terms) noexcept { - auto num_terms = terms.size(); - assert( - // clang-format off - num_terms > 0 && - p.size() == num_terms + 1u - // clang-format on - ); - s25t::element a, b; - auto mle_index = terms[0]; - a = *(mles + mle_index * n); - s25o::neg(b, a); - p[0] = a; - p[1] = b; - - for (unsigned i = 1; i < num_terms; ++i) { - auto mle_index = terms[i]; - a = *(mles + mle_index * n); - s25o::neg(b, a); - - auto c_prev = p[0]; - s25o::mul(p[0], c_prev, a); - for (unsigned pow = 1u; pow < i + 1u; ++pow) { - auto c = p[pow]; - s25o::mul(p[pow], c, a); - s25o::muladd(p[pow], c_prev, b, p[pow]); - c_prev = c; - } - s25o::mul(p[i + 1u], c_prev, b); - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.h b/sxt/proof/sumcheck/polynomial_utility.h deleted file mode 100644 index 6754c4d74..000000000 --- a/sxt/proof/sumcheck/polynomial_utility.h +++ /dev/null @@ -1,55 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/base/container/span.h" -#include "sxt/base/macro/cuda_callable.h" - -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// sum_polynomial_01 -//-------------------------------------------------------------------------------------------------- -// Given a polynomial -// f_a(X) = a[0] + a[1] * X + a[2] * X^2 + ... -// compute the sum -// f_a(0) + f_a(1) -void sum_polynomial_01(s25t::element& e, basct::cspan polynomial) noexcept; - -//-------------------------------------------------------------------------------------------------- -// evaluate_polynomial -//-------------------------------------------------------------------------------------------------- -void evaluate_polynomial(s25t::element& e, basct::cspan polynomial, - const s25t::element& x) noexcept; - -//-------------------------------------------------------------------------------------------------- -// expand_products -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void expand_products(basct::span p, const s25t::element* mles, unsigned n, - unsigned step, basct::cspan terms) noexcept; - -//-------------------------------------------------------------------------------------------------- -// partial_expand_products -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void partial_expand_products(basct::span p, const s25t::element* mles, unsigned n, - basct::cspan terms) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.t.cc b/sxt/proof/sumcheck/polynomial_utility.t.cc deleted file mode 100644 index 281e65e3a..000000000 --- a/sxt/proof/sumcheck/polynomial_utility.t.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/polynomial_utility.h" - -#include - -#include "sxt/base/test/unit_test.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we perform basic operations on polynomials") { - s25t::element e; - - std::vector p; - - SECTION("we can compute the 0-1 sum of a zero polynomials") { - sum_polynomial_01(e, p); - REQUIRE(e == 0x0_s25); - } - - SECTION("we can compute the 0-1 sum of a constant polynomial") { - p = {0x123_s25}; - sum_polynomial_01(e, p); - REQUIRE(e == 0x246_s25); - } - - SECTION("we can compute the 0-1 sum of a 1 degree polynomial") { - p = {0x123_s25, 0x456_s25}; - sum_polynomial_01(e, p); - REQUIRE(e == 0x246_s25 + 0x456_s25); - } - - SECTION("we can evaluate the zero polynomial") { - evaluate_polynomial(e, p, 0x123_s25); - REQUIRE(e == 0x0_s25); - } - - SECTION("we can evaluate a constant polynomial") { - p = {0x123_s25}; - evaluate_polynomial(e, p, 0x321_s25); - REQUIRE(e == 0x123_s25); - } - - SECTION("we can evaluate a polynomial of degree 1") { - p = {0x123_s25, 0x456_s25}; - evaluate_polynomial(e, p, 0x321_s25); - REQUIRE(e == 0x123_s25 + 0x456_s25 * 0x321_s25); - } - - SECTION("we can evaluate a polynomial of degree 2") { - p = {0x123_s25, 0x456_s25, 0x789_s25}; - evaluate_polynomial(e, p, 0x321_s25); - REQUIRE(e == 0x123_s25 + 0x456_s25 * 0x321_s25 + 0x789_s25 * 0x321_s25 * 0x321_s25); - } -} - -TEST_CASE("we can expand a product of MLEs") { - std::vector p; - std::vector mles; - std::vector terms; - - SECTION("we can expand a single MLE") { - p.resize(2); - mles = {0x123_s25, 0x456_s25}; - terms = {0}; - expand_products(p, mles.data(), 2, 1, terms); - REQUIRE(p[0] == mles[0]); - REQUIRE(p[1] == mles[1] - mles[0]); - } - - SECTION("we can partially expand MLEs (where some terms are assumed to be zero)") { - mles = {0x123_s25, 0x0_s25}; - p.resize(2); - terms = {0}; - partial_expand_products(p, mles.data(), 1, terms); - - std::vector expected(2); - expand_products(expected, mles.data(), 2, 1, terms); - REQUIRE(p == expected); - } - - SECTION("we can expand two MLEs") { - p.resize(3); - mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25}; - terms = {0, 1}; - expand_products(p, mles.data(), 2, 1, terms); - auto a1 = mles[0]; - auto a2 = mles[1] - mles[0]; - auto b1 = mles[2]; - auto b2 = mles[3] - mles[2]; - REQUIRE(p[0] == a1 * b1); - REQUIRE(p[1] == a1 * b2 + a2 * b1); - REQUIRE(p[2] == a2 * b2); - } - - SECTION("we can expand three MLEs") { - p.resize(4); - mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25, 0x2233_s25, 0x5566_s25}; - terms = {0, 1, 2}; - expand_products(p, mles.data(), 2, 1, terms); - auto a1 = mles[0]; - auto a2 = mles[1] - mles[0]; - auto b1 = mles[2]; - auto b2 = mles[3] - mles[2]; - auto c1 = mles[4]; - auto c2 = mles[5] - mles[4]; - REQUIRE(p[0] == a1 * b1 * c1); - REQUIRE(p[1] == a1 * b1 * c2 + a1 * b2 * c1 + a2 * b1 * c1); - REQUIRE(p[2] == a1 * b2 * c2 + a2 * b1 * c2 + a2 * b2 * c1); - REQUIRE(p[3] == a2 * b2 * c2); - } -} diff --git a/sxt/proof/sumcheck/proof_computation.cc b/sxt/proof/sumcheck/proof_computation.cc deleted file mode 100644 index d77311d79..000000000 --- a/sxt/proof/sumcheck/proof_computation.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/proof_computation.h" - -#include "sxt/base/error/assert.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/proof/sumcheck/driver.h" -#include "sxt/proof/sumcheck/sumcheck_transcript.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// prove_sum -//-------------------------------------------------------------------------------------------------- -xena::future<> prove_sum(basct::span polynomials, - basct::span evaluation_point, - sumcheck_transcript& transcript, const driver& drv, - basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept { - SXT_RELEASE_ASSERT(0 < n); - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto polynomial_length = polynomials.size() / num_variables; - auto num_mles = mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - polynomial_length > 1 && - evaluation_point.size() == num_variables && - polynomials.size() == num_variables * polynomial_length && - mles.size() == n * num_mles - // clang-format on - ); - - transcript.init(num_variables, polynomial_length - 1); - - auto ws = co_await drv.make_workspace(mles, product_table, product_terms, n); - - for (unsigned round_index = 0; round_index < num_variables; ++round_index) { - auto polynomial = polynomials.subspan(round_index * polynomial_length, polynomial_length); - - // compute the round polynomial - co_await drv.sum(polynomial, *ws); - - // draw the next random challenge - s25t::element r; - transcript.round_challenge(r, polynomial); - evaluation_point[round_index] = r; - - // fold the polynomial - if (round_index < num_variables - 1u) { - co_await drv.fold(*ws, r); - } - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/proof_computation.h b/sxt/proof/sumcheck/proof_computation.h deleted file mode 100644 index 2569d7744..000000000 --- a/sxt/proof/sumcheck/proof_computation.h +++ /dev/null @@ -1,44 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" - -namespace sxt::prft { -class transcript; -} -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -class driver; -class sumcheck_transcript; - -//-------------------------------------------------------------------------------------------------- -// prove_sum -//-------------------------------------------------------------------------------------------------- -xena::future<> prove_sum(basct::span polynomials, - basct::span evaluation_point, - sumcheck_transcript& transcript, const driver& drv, - basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/proof_computation.t.cc b/sxt/proof/sumcheck/proof_computation.t.cc deleted file mode 100644 index b820e0a2b..000000000 --- a/sxt/proof/sumcheck/proof_computation.t.cc +++ /dev/null @@ -1,232 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/proof_computation.h" - -#include -#include - -#include "sxt/base/container/span_utility.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/fast_random_number_generator.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/schedule/scheduler.h" -#include "sxt/proof/sumcheck/chunked_gpu_driver.h" -#include "sxt/proof/sumcheck/cpu_driver.h" -#include "sxt/proof/sumcheck/gpu_driver.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/proof/sumcheck/reference_transcript.h" -#include "sxt/proof/sumcheck/sumcheck_random.h" -#include "sxt/proof/sumcheck/verification.h" -#include "sxt/proof/transcript/transcript.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -static void test_proof(const driver& drv) noexcept { - prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - std::vector polynomials(2); - std::vector evaluation_point(1); - std::vector mles = { - 0x8_s25, - 0x3_s25, - }; - std::vector> product_table = { - {0x1_s25, 1}, - }; - std::vector product_terms = {0}; - - SECTION("we can prove a sum with n=1") { - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 1); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == mles[0]); - REQUIRE(polynomials[1] == -mles[0]); - } - - SECTION("we can prove a sum with a single variable") { - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == mles[0]); - REQUIRE(polynomials[1] == mles[1] - mles[0]); - } - - SECTION("we can prove a sum degree greater than 1") { - product_table = { - {0x1_s25, 2}, - }; - product_terms = {0, 0}; - polynomials.resize(3); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == mles[0] * mles[0]); - REQUIRE(polynomials[1] == 0x2_s25 * (mles[1] - mles[0]) * mles[0]); - REQUIRE(polynomials[2] == (mles[1] - mles[0]) * (mles[1] - mles[0])); - } - - SECTION("we can prove a sum with multiple MLEs") { - product_table = { - {0x1_s25, 2}, - }; - product_terms = {0, 1}; - polynomials.resize(3); - mles.push_back(0x7_s25); - mles.push_back(0x10_s25); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == mles[0] * mles[2]); - REQUIRE(polynomials[1] == (mles[1] - mles[0]) * mles[2] + (mles[3] - mles[2]) * mles[0]); - REQUIRE(polynomials[2] == (mles[1] - mles[0]) * (mles[3] - mles[2])); - } - - SECTION("we can prove a sum where the term multiplier is different from one") { - product_table[0].first = 0x2_s25; - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == 0x2_s25 * mles[0]); - REQUIRE(polynomials[1] == 0x2_s25 * (mles[1] - mles[0])); - } - - SECTION("we can prove a sum with two variables") { - mles.push_back(0x4_s25); - mles.push_back(0x7_s25); - polynomials.resize(4); - evaluation_point.resize(2); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 4); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == mles[0] + mles[1]); - REQUIRE(polynomials[1] == (mles[2] - mles[0]) + (mles[3] - mles[1])); - - auto r = evaluation_point[0]; - mles[0] = mles[0] * (0x1_s25 - r) + mles[2] * r; - mles[1] = mles[1] * (0x1_s25 - r) + mles[3] * r; - - REQUIRE(polynomials[2] == mles[0]); - REQUIRE(polynomials[3] == mles[1] - mles[0]); - } - - SECTION("we can prove a sum with n=3") { - mles.push_back(0x4_s25); - polynomials.resize(4); - evaluation_point.resize(2); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 3); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(polynomials[0] == mles[0] + mles[1]); - REQUIRE(polynomials[1] == (mles[2] - mles[0]) - mles[1]); - - auto r = evaluation_point[0]; - mles[0] = mles[0] * (0x1_s25 - r) + mles[2] * r; - mles[1] = mles[1] * (0x1_s25 - r); - - REQUIRE(polynomials[2] == mles[0]); - REQUIRE(polynomials[3] == mles[1] - mles[0]); - } - - SECTION("we can verify random sumcheck problems") { - basn::fast_random_number_generator rng{1, 2}; - - for (unsigned i = 0; i < 10; ++i) { - random_sumcheck_descriptor descriptor; - unsigned n; - generate_random_sumcheck_problem(mles, product_table, product_terms, n, rng, descriptor); - - unsigned polynomial_length = 0; - for (auto [_, len] : product_table) { - polynomial_length = std::max(polynomial_length, len + 1u); - } - - auto num_variables = n == 1 ? 1 : basn::ceil_log2(n); - evaluation_point.resize(num_variables); - polynomials.resize(polynomial_length * num_variables); - - // prove - { - prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, n); - xens::get_scheduler().run(); - } - - // we can verify - { - prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - s25t::element expected_sum; - sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); - auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - polynomials, polynomial_length - 1u); - REQUIRE(valid); - } - - // verification fails if we break the proof - { - prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - s25t::element expected_sum; - sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); - polynomials[polynomials.size() - 1] = polynomials[0] + polynomials[1]; - auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - polynomials, polynomial_length - 1u); - REQUIRE(!valid); - } - } - } -} - -TEST_CASE("we can create a sumcheck proof") { - SECTION("we can prove with the cpu driver") { - cpu_driver drv; - test_proof(drv); - } - - SECTION("we can prove with the gpu driver") { - gpu_driver drv; - test_proof(drv); - } - - SECTION("we can prove with the chunked gpu driver") { - chunked_gpu_driver drv{0.0}; - test_proof(drv); - } - - SECTION("we can prove with a chunked driver that switches over to the single gpu driver") { - std::vector mles(4); - auto fraction = get_gpu_memory_fraction(mles); - chunked_gpu_driver drv{fraction}; - test_proof(drv); - } -} diff --git a/sxt/proof/sumcheck/reduction_gpu.cc b/sxt/proof/sumcheck/reduction_gpu.cc deleted file mode 100644 index e03c12d62..000000000 --- a/sxt/proof/sumcheck/reduction_gpu.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/reduction_gpu.h" - -#include "sxt/algorithm/base/identity_mapper.h" -#include "sxt/algorithm/reduction/kernel_fit.h" -#include "sxt/algorithm/reduction/thread_reduction.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/error/assert.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/execution/kernel/kernel_dims.h" -#include "sxt/execution/kernel/launch.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/scalar25/operation/accumulator.h" -#include "sxt/scalar25/operation/add.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// reduction_kernel -//-------------------------------------------------------------------------------------------------- -template -__global__ static void reduction_kernel(s25t::element* __restrict__ out, - const s25t::element* __restrict__ partials, - unsigned n) noexcept { - auto thread_index = threadIdx.x; - auto block_index = blockIdx.x; - auto coefficient_index = blockIdx.y; - auto index = block_index * (BlockSize * 2) + thread_index; - auto step = BlockSize * 2 * gridDim.x; - __shared__ s25t::element shared_data[2 * BlockSize]; - - // coefficient adjustment - out += coefficient_index; - partials += coefficient_index * n; - - // mapper - algb::identity_mapper mapper{partials}; - - // reduce - algr::thread_reduce(out + block_index, shared_data, mapper, n, step, - thread_index, index); -} - -//-------------------------------------------------------------------------------------------------- -// reduce_sums -//-------------------------------------------------------------------------------------------------- -xena::future<> reduce_sums(basct::span p, basdv::stream& stream, - basct::cspan partial_terms) noexcept { - auto num_coefficients = p.size(); - auto n = partial_terms.size() / num_coefficients; - SXT_DEBUG_ASSERT( - // clang-format off - n > 0 && - partial_terms.size() == num_coefficients * n && - basdv::is_host_pointer(p.data()) && - basdv::is_active_device_pointer(partial_terms.data()) - // clang-format on - ); - auto dims = algr::fit_reduction_kernel(n); - - // p_dev - memr::async_device_resource resource{stream}; - memmg::managed_array p_dev{num_coefficients * dims.num_blocks, &resource}; - - // launch kernel - xenk::launch_kernel(dims.block_size, [&]( - std::integral_constant) noexcept { - reduction_kernel - <<>>( - p_dev.data(), partial_terms.data(), n); - }); - - // copy polynomial to host - memmg::managed_array p_host_data; - basct::span p_host = p; - if (dims.num_blocks > 1) { - p_host_data.resize(p_dev.size()); - p_host = p_host_data; - } - basdv::async_copy_device_to_host(p_host, p_dev, stream); - co_await xendv::await_stream(stream); - - // complete reduction on host if necessary - if (dims.num_blocks == 1) { - co_return; - } - for (unsigned coefficient_index = 0; coefficient_index < num_coefficients; ++coefficient_index) { - p[coefficient_index] = p_host[coefficient_index * dims.num_blocks]; - for (unsigned block_index = 1; block_index < dims.num_blocks; ++block_index) { - s25o::add(p[coefficient_index], p[coefficient_index], - p_host[coefficient_index * dims.num_blocks + block_index]); - } - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reduction_gpu.h b/sxt/proof/sumcheck/reduction_gpu.h deleted file mode 100644 index eb89835c5..000000000 --- a/sxt/proof/sumcheck/reduction_gpu.h +++ /dev/null @@ -1,35 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" - -namespace sxt::basdv { -class stream; -} -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// reduce_sums -//-------------------------------------------------------------------------------------------------- -xena::future<> reduce_sums(basct::span p, basdv::stream& stream, - basct::cspan partial_terms) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reduction_gpu.t.cc b/sxt/proof/sumcheck/reduction_gpu.t.cc deleted file mode 100644 index 218d8f1b6..000000000 --- a/sxt/proof/sumcheck/reduction_gpu.t.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2025-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/reduction_gpu.h" - -#include - -#include "sxt/base/device/stream.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/schedule/scheduler.h" -#include "sxt/memory/resource/managed_device_resource.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using s25t::operator""_s25; - -TEST_CASE("we can reduce sumcheck polynomials") { - std::vector p; - std::pmr::vector partial_terms{memr::get_managed_device_resource()}; - - basdv::stream stream; - - SECTION("we can reduce a sum with a single term") { - p.resize(1); - partial_terms = {0x123_s25}; - auto fut = reduce_sums(p, stream, partial_terms); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == 0x123_s25); - } - - SECTION("we can reduce two terms") { - p.resize(1); - partial_terms = {0x123_s25, 0x456_s25}; - auto fut = reduce_sums(p, stream, partial_terms); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == 0x123_s25 + 0x456_s25); - } - - SECTION("we can reduce multiple coefficients") { - p.resize(2); - partial_terms = {0x123_s25, 0x456_s25}; - auto fut = reduce_sums(p, stream, partial_terms); - xens::get_scheduler().run(); - REQUIRE(fut.ready()); - REQUIRE(p[0] == 0x123_s25); - REQUIRE(p[1] == 0x456_s25); - } -} diff --git a/sxt/proof/sumcheck/verification.cc b/sxt/proof/sumcheck/verification.cc deleted file mode 100644 index 96e0eb66f..000000000 --- a/sxt/proof/sumcheck/verification.cc +++ /dev/null @@ -1,77 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/verification.h" - -#include "sxt/base/error/assert.h" -#include "sxt/base/log/log.h" -#include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/proof/sumcheck/sumcheck_transcript.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// verify_sumcheck_no_evaluation -//-------------------------------------------------------------------------------------------------- -bool verify_sumcheck_no_evaluation(s25t::element& expected_sum, - basct::span evaluation_point, - sumcheck_transcript& transcript, - basct::cspan round_polynomials, - unsigned round_degree) noexcept { - auto num_variables = evaluation_point.size(); - SXT_RELEASE_ASSERT( - // clang-format off - num_variables > 0 && round_degree > 0 - // clang-format on - ); - - basl::info("verifying sumcheck of {} variables and round degree {}", num_variables, round_degree); - - // check dimensions - if (auto expected_count = (round_degree + 1u) * num_variables; - round_polynomials.size() != expected_count) { - basl::info("sumcheck verification failed: expected {} scalars for round_polynomials but got {}", - expected_count, round_polynomials.size()); - return false; - } - - transcript.init(num_variables, round_degree); - - // go through sumcheck rounds - for (unsigned round_index = 0; round_index < num_variables; ++round_index) { - auto polynomial = - round_polynomials.subspan((round_degree + 1u) * round_index, round_degree + 1u); - - // check sum - s25t::element sum; - sum_polynomial_01(sum, polynomial); - if (expected_sum != sum) { - basl::info("sumcheck verification failed on round {}", round_index + 1); - return false; - } - - // draw a random scalar - s25t::element r; - transcript.round_challenge(r, polynomial); - evaluation_point[round_index] = r; - - // evaluate at random point - evaluate_polynomial(expected_sum, polynomial, r); - } - - return true; -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.h b/sxt/proof/sumcheck/verification.h deleted file mode 100644 index 3ceea1819..000000000 --- a/sxt/proof/sumcheck/verification.h +++ /dev/null @@ -1,39 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/base/container/span.h" - -namespace sxt::prft { -class transcript; -} -namespace sxt::s25t { -class element; -} - -namespace sxt::prfsk { -class sumcheck_transcript; - -//-------------------------------------------------------------------------------------------------- -// verify_sumcheck_no_evaluation -//-------------------------------------------------------------------------------------------------- -bool verify_sumcheck_no_evaluation(s25t::element& expected_sum, - basct::span evaluation_point, - sumcheck_transcript& transcript, - basct::cspan round_polynomials, - unsigned round_degree) noexcept; -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.t.cc b/sxt/proof/sumcheck/verification.t.cc deleted file mode 100644 index b2c2ccfe5..000000000 --- a/sxt/proof/sumcheck/verification.t.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/proof/sumcheck/verification.h" - -#include - -#include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck/reference_transcript.h" -#include "sxt/proof/transcript/transcript.h" -#include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -using namespace sxt; -using namespace sxt::prfsk; -using sxt::s25t::operator""_s25; - -TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { - s25t::element expected_sum = 0x0_s25; - std::vector evaluation_point = {0x0_s25}; - prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - std::vector round_polynomials = {0x0_s25, 0x0_s25}; - - SECTION("verification fails if dimensions don't match") { - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 2); - REQUIRE(!res); - } - - SECTION("we can verify a single round") { - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); - REQUIRE(res); - REQUIRE(evaluation_point[0] != 0x0_s25); - } - - SECTION("verification fails if the round polynomial doesn't match the sum") { - round_polynomials[1] = 0x1_s25; - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); - REQUIRE(!res); - } - - SECTION("we can verify a sum with two rounds") { - // Use the MLE: - // 3(1-x1)(1-x2) + 5(1-x1)x2 -7x1(1-x2) -1x1x2 - round_polynomials.resize(4); - - // round 1 - round_polynomials[0] = 0x3_s25 + 0x5_s25; - round_polynomials[1] = -0x3_s25 - 0x7_s25 - 0x5_s25 - 0x1_s25; - - // draw scalar - s25t::element r; - { - prft::transcript base_transcript_p{"abc"}; - reference_transcript transcript_p{base_transcript_p}; - transcript_p.init(2, 1); - transcript_p.round_challenge(r, basct::span{round_polynomials}.subspan(0, 2)); - } - - // round 2 - round_polynomials[2] = 0x3_s25 * (0x1_s25 - r) - 0x7_s25 * r; - round_polynomials[3] = - -0x3_s25 * (0x1_s25 - r) + 0x5_s25 * (0x1_s25 - r) + 0x7_s25 * r - 0x1_s25 * r; - - // prove - evaluation_point.resize(2); - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); - REQUIRE(evaluation_point[0] == r); - REQUIRE(res); - } - - SECTION("sumcheck verification fails if the random scalar used is wrong") { - // Use the MLE: - // 3(1-x1)(1-x2) + 5(1-x1)x2 -7x1(1-x2) -1x1x2 - round_polynomials.resize(4); - - // round 1 - round_polynomials[0] = 0x3_s25 + 0x5_s25; - round_polynomials[1] = -0x3_s25 - 0x7_s25 - 0x5_s25 - 0x1_s25; - - // draw scalar - s25t::element r = 0x112233_s25; - - // round 2 - round_polynomials[2] = 0x3_s25 * (0x1_s25 - r) - 0x7_s25 * r; - round_polynomials[3] = - -0x3_s25 * (0x1_s25 - r) + 0x5_s25 * (0x1_s25 - r) + 0x7_s25 * r - 0x1_s25 * r; - - // prove - evaluation_point.resize(2); - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); - REQUIRE(!res); - } - - SECTION("we can verify a polynomial of degree 2 with one round") { - // Use the MLEs: - // f(x1) = 3(1-x1) -7x1 - // g(x1) = -2 (1 - x1) + 4 x1 - round_polynomials = { - 0x3_s25 * -0x2_s25, - (-0x3_s25 - 0x7_s25) * -0x2_s25 + 0x3_s25 * (0x2_s25 + 0x4_s25), - (-0x3_s25 - 0x7_s25) * (0x2_s25 + 0x4_s25), - }; - expected_sum = 0x3_s25 * -0x2_s25 - 0x7_s25 * 0x4_s25; - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 2); - REQUIRE(res); - REQUIRE(evaluation_point[0] != 0x0_s25); - } -} From 52d9376450e500bdef80bd5083a5fcbcc12c634e Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 20:01:18 -0800 Subject: [PATCH 81/83] refactor sumcheck --- benchmark/sumcheck/BUILD | 6 +++--- benchmark/sumcheck/benchmark.m.cc | 6 +++--- cbindings/BUILD | 2 +- cbindings/sumcheck.t.cc | 2 +- sxt/cbindings/backend/BUILD | 10 +++++----- .../backend/callback_sumcheck_transcript.h | 2 +- sxt/cbindings/backend/cpu_backend.cc | 4 ++-- sxt/cbindings/backend/gpu_backend.cc | 4 ++-- sxt/proof/{sumcheck2 => sumcheck}/BUILD | 0 .../chunked_gpu_driver.cc | 2 +- .../chunked_gpu_driver.h | 12 ++++++------ .../chunked_gpu_driver.t.cc | 4 ++-- sxt/proof/{sumcheck2 => sumcheck}/constant.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/constant.h | 0 .../{sumcheck2 => sumcheck}/cpu_driver.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/cpu_driver.h | 4 ++-- .../{sumcheck2 => sumcheck}/cpu_driver.t.cc | 4 ++-- .../{sumcheck2 => sumcheck}/device_cache.cc | 2 +- .../{sumcheck2 => sumcheck}/device_cache.h | 0 .../{sumcheck2 => sumcheck}/device_cache.t.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/driver.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/driver.h | 2 +- .../{sumcheck2 => sumcheck}/driver_test.cc | 2 +- .../{sumcheck2 => sumcheck}/driver_test.h | 2 +- sxt/proof/{sumcheck2 => sumcheck}/fold_gpu.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/fold_gpu.h | 2 +- .../{sumcheck2 => sumcheck}/fold_gpu.t.cc | 2 +- .../{sumcheck2 => sumcheck}/gpu_driver.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/gpu_driver.h | 4 ++-- .../{sumcheck2 => sumcheck}/gpu_driver.t.cc | 4 ++-- .../{sumcheck2 => sumcheck}/mle_utility.cc | 2 +- .../{sumcheck2 => sumcheck}/mle_utility.h | 0 .../{sumcheck2 => sumcheck}/mle_utility.t.cc | 2 +- .../polynomial_mapper.cc | 2 +- .../polynomial_mapper.h | 2 +- .../polynomial_mapper.t.cc | 2 +- .../polynomial_reducer.cc | 2 +- .../polynomial_reducer.h | 0 .../polynomial_utility.cc | 2 +- .../polynomial_utility.h | 0 .../polynomial_utility.t.cc | 2 +- .../proof_computation.cc | 2 +- .../proof_computation.h | 4 ++-- .../proof_computation.t.cc | 18 +++++++++--------- .../{sumcheck2 => sumcheck}/reduction_gpu.cc | 2 +- .../{sumcheck2 => sumcheck}/reduction_gpu.h | 0 .../{sumcheck2 => sumcheck}/reduction_gpu.t.cc | 2 +- .../reference_transcript.cc | 2 +- .../reference_transcript.h | 2 +- .../reference_transcript.t.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/sum_gpu.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/sum_gpu.h | 12 ++++++------ sxt/proof/{sumcheck2 => sumcheck}/sum_gpu.t.cc | 2 +- .../{sumcheck2 => sumcheck}/sumcheck_random.cc | 2 +- .../{sumcheck2 => sumcheck}/sumcheck_random.h | 2 +- .../sumcheck_transcript.cc | 2 +- .../sumcheck_transcript.h | 0 .../{sumcheck2 => sumcheck}/verification.cc | 2 +- .../{sumcheck2 => sumcheck}/verification.h | 4 ++-- .../{sumcheck2 => sumcheck}/verification.t.cc | 4 ++-- sxt/proof/{sumcheck2 => sumcheck}/workspace.cc | 2 +- sxt/proof/{sumcheck2 => sumcheck}/workspace.h | 0 62 files changed, 89 insertions(+), 89 deletions(-) rename sxt/proof/{sumcheck2 => sumcheck}/BUILD (100%) rename sxt/proof/{sumcheck2 => sumcheck}/chunked_gpu_driver.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/chunked_gpu_driver.h (94%) rename sxt/proof/{sumcheck2 => sumcheck}/chunked_gpu_driver.t.cc (92%) rename sxt/proof/{sumcheck2 => sumcheck}/constant.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/constant.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/cpu_driver.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/cpu_driver.h (98%) rename sxt/proof/{sumcheck2 => sumcheck}/cpu_driver.t.cc (91%) rename sxt/proof/{sumcheck2 => sumcheck}/device_cache.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/device_cache.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/device_cache.t.cc (98%) rename sxt/proof/{sumcheck2 => sumcheck}/driver.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/driver.h (97%) rename sxt/proof/{sumcheck2 => sumcheck}/driver_test.cc (99%) rename sxt/proof/{sumcheck2 => sumcheck}/driver_test.h (96%) rename sxt/proof/{sumcheck2 => sumcheck}/fold_gpu.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/fold_gpu.h (99%) rename sxt/proof/{sumcheck2 => sumcheck}/fold_gpu.t.cc (98%) rename sxt/proof/{sumcheck2 => sumcheck}/gpu_driver.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/gpu_driver.h (98%) rename sxt/proof/{sumcheck2 => sumcheck}/gpu_driver.t.cc (91%) rename sxt/proof/{sumcheck2 => sumcheck}/mle_utility.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/mle_utility.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/mle_utility.t.cc (98%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_mapper.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_mapper.h (97%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_mapper.t.cc (97%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_reducer.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_reducer.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_utility.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_utility.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/polynomial_utility.t.cc (98%) rename sxt/proof/{sumcheck2 => sumcheck}/proof_computation.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/proof_computation.h (96%) rename sxt/proof/{sumcheck2 => sumcheck}/proof_computation.t.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/reduction_gpu.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/reduction_gpu.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/reduction_gpu.t.cc (97%) rename sxt/proof/{sumcheck2 => sumcheck}/reference_transcript.cc (92%) rename sxt/proof/{sumcheck2 => sumcheck}/reference_transcript.h (97%) rename sxt/proof/{sumcheck2 => sumcheck}/reference_transcript.t.cc (97%) rename sxt/proof/{sumcheck2 => sumcheck}/sum_gpu.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/sum_gpu.h (97%) rename sxt/proof/{sumcheck2 => sumcheck}/sum_gpu.t.cc (99%) rename sxt/proof/{sumcheck2 => sumcheck}/sumcheck_random.cc (98%) rename sxt/proof/{sumcheck2 => sumcheck}/sumcheck_random.h (97%) rename sxt/proof/{sumcheck2 => sumcheck}/sumcheck_transcript.cc (92%) rename sxt/proof/{sumcheck2 => sumcheck}/sumcheck_transcript.h (100%) rename sxt/proof/{sumcheck2 => sumcheck}/verification.cc (93%) rename sxt/proof/{sumcheck2 => sumcheck}/verification.h (96%) rename sxt/proof/{sumcheck2 => sumcheck}/verification.t.cc (97%) rename sxt/proof/{sumcheck2 => sumcheck}/workspace.cc (94%) rename sxt/proof/{sumcheck2 => sumcheck}/workspace.h (100%) diff --git a/benchmark/sumcheck/BUILD b/benchmark/sumcheck/BUILD index d7e31ca8f..22a847b72 100644 --- a/benchmark/sumcheck/BUILD +++ b/benchmark/sumcheck/BUILD @@ -14,9 +14,9 @@ sxt_cc_benchmark( "//sxt/execution/async:future", "//sxt/execution/schedule:scheduler", "//sxt/memory/management:managed_array", - "//sxt/proof/sumcheck2:gpu_driver", - "//sxt/proof/sumcheck2:proof_computation", - "//sxt/proof/sumcheck2:reference_transcript", + "//sxt/proof/sumcheck:gpu_driver", + "//sxt/proof/sumcheck:proof_computation", + "//sxt/proof/sumcheck:reference_transcript", "//sxt/proof/transcript", "//sxt/scalar25/random:element", "//sxt/scalar25/realization:field", diff --git a/benchmark/sumcheck/benchmark.m.cc b/benchmark/sumcheck/benchmark.m.cc index 60697ebb6..48516953f 100644 --- a/benchmark/sumcheck/benchmark.m.cc +++ b/benchmark/sumcheck/benchmark.m.cc @@ -27,9 +27,9 @@ #include "sxt/execution/async/future.h" #include "sxt/execution/schedule/scheduler.h" #include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck2/gpu_driver.h" -#include "sxt/proof/sumcheck2/proof_computation.h" -#include "sxt/proof/sumcheck2/reference_transcript.h" +#include "sxt/proof/sumcheck/gpu_driver.h" +#include "sxt/proof/sumcheck/proof_computation.h" +#include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/random/element.h" #include "sxt/scalar25/realization/field.h" diff --git a/cbindings/BUILD b/cbindings/BUILD index 56f548a23..0c49e4024 100644 --- a/cbindings/BUILD +++ b/cbindings/BUILD @@ -224,7 +224,7 @@ sxt_cc_component( test_deps = [ ":backend", "//sxt/base/test:unit_test", - "//sxt/proof/sumcheck2:reference_transcript", + "//sxt/proof/sumcheck:reference_transcript", "//sxt/scalar25/operation:overload", "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", diff --git a/cbindings/sumcheck.t.cc b/cbindings/sumcheck.t.cc index c397ff62e..e21c33f73 100644 --- a/cbindings/sumcheck.t.cc +++ b/cbindings/sumcheck.t.cc @@ -20,7 +20,7 @@ #include "cbindings/backend.h" #include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck2/reference_transcript.h" +#include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/scalar25/operation/overload.h" #include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 26f6cddfe..15b317d0a 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -9,7 +9,7 @@ sxt_cc_component( ], with_test = False, deps = [ - "//sxt/proof/sumcheck2:sumcheck_transcript", + "//sxt/proof/sumcheck:sumcheck_transcript", ], ) @@ -97,8 +97,8 @@ sxt_cc_component( "//sxt/proof/inner_product:proof_descriptor", "//sxt/proof/inner_product:proof_computation", "//sxt/proof/inner_product:gpu_driver", - "//sxt/proof/sumcheck2:chunked_gpu_driver", - "//sxt/proof/sumcheck2:proof_computation", + "//sxt/proof/sumcheck:chunked_gpu_driver", + "//sxt/proof/sumcheck:proof_computation", "//sxt/scalar25/realization:field", ], with_test = False, @@ -156,8 +156,8 @@ sxt_cc_component( "//sxt/proof/inner_product:proof_descriptor", "//sxt/proof/inner_product:proof_computation", "//sxt/proof/inner_product:cpu_driver", - "//sxt/proof/sumcheck2:cpu_driver", - "//sxt/proof/sumcheck2:proof_computation", + "//sxt/proof/sumcheck:cpu_driver", + "//sxt/proof/sumcheck:proof_computation", "//sxt/scalar25/realization:field", ], with_test = False, diff --git a/sxt/cbindings/backend/callback_sumcheck_transcript.h b/sxt/cbindings/backend/callback_sumcheck_transcript.h index 8c35e446a..0c7da99f2 100644 --- a/sxt/cbindings/backend/callback_sumcheck_transcript.h +++ b/sxt/cbindings/backend/callback_sumcheck_transcript.h @@ -16,7 +16,7 @@ */ #pragma once -#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" namespace sxt::cbnbck { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index b245c0c61..755b26610 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -59,8 +59,8 @@ #include "sxt/proof/inner_product/cpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" -#include "sxt/proof/sumcheck2/cpu_driver.h" -#include "sxt/proof/sumcheck2/proof_computation.h" +#include "sxt/proof/sumcheck/cpu_driver.h" +#include "sxt/proof/sumcheck/proof_computation.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index 6ee151a74..eccb0c7a0 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -63,8 +63,8 @@ #include "sxt/proof/inner_product/gpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" -#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" -#include "sxt/proof/sumcheck2/proof_computation.h" +#include "sxt/proof/sumcheck/chunked_gpu_driver.h" +#include "sxt/proof/sumcheck/proof_computation.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" diff --git a/sxt/proof/sumcheck2/BUILD b/sxt/proof/sumcheck/BUILD similarity index 100% rename from sxt/proof/sumcheck2/BUILD rename to sxt/proof/sumcheck/BUILD diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.cc b/sxt/proof/sumcheck/chunked_gpu_driver.cc similarity index 93% rename from sxt/proof/sumcheck2/chunked_gpu_driver.cc rename to sxt/proof/sumcheck/chunked_gpu_driver.cc index 8f2afc388..4a652d2fc 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.cc +++ b/sxt/proof/sumcheck/chunked_gpu_driver.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" +#include "sxt/proof/sumcheck/chunked_gpu_driver.h" diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.h b/sxt/proof/sumcheck/chunked_gpu_driver.h similarity index 94% rename from sxt/proof/sumcheck2/chunked_gpu_driver.h rename to sxt/proof/sumcheck/chunked_gpu_driver.h index 1871fe920..016775932 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.h +++ b/sxt/proof/sumcheck/chunked_gpu_driver.h @@ -24,12 +24,12 @@ #include "sxt/execution/async/coroutine.h" #include "sxt/execution/async/future.h" #include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck2/device_cache.h" -#include "sxt/proof/sumcheck2/driver.h" -#include "sxt/proof/sumcheck2/fold_gpu.h" -#include "sxt/proof/sumcheck2/gpu_driver.h" -#include "sxt/proof/sumcheck2/mle_utility.h" -#include "sxt/proof/sumcheck2/sum_gpu.h" +#include "sxt/proof/sumcheck/device_cache.h" +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/fold_gpu.h" +#include "sxt/proof/sumcheck/gpu_driver.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/proof/sumcheck/sum_gpu.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc similarity index 92% rename from sxt/proof/sumcheck2/chunked_gpu_driver.t.cc rename to sxt/proof/sumcheck/chunked_gpu_driver.t.cc index da60a2e21..14bed91a7 100644 --- a/sxt/proof/sumcheck2/chunked_gpu_driver.t.cc +++ b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc @@ -14,10 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" +#include "sxt/proof/sumcheck/chunked_gpu_driver.h" #include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck2/driver_test.h" +#include "sxt/proof/sumcheck/driver_test.h" using namespace sxt; using namespace sxt::prfsk2; diff --git a/sxt/proof/sumcheck2/constant.cc b/sxt/proof/sumcheck/constant.cc similarity index 94% rename from sxt/proof/sumcheck2/constant.cc rename to sxt/proof/sumcheck/constant.cc index babcea3f3..12625ca3b 100644 --- a/sxt/proof/sumcheck2/constant.cc +++ b/sxt/proof/sumcheck/constant.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/constant.h" +#include "sxt/proof/sumcheck/constant.h" diff --git a/sxt/proof/sumcheck2/constant.h b/sxt/proof/sumcheck/constant.h similarity index 100% rename from sxt/proof/sumcheck2/constant.h rename to sxt/proof/sumcheck/constant.h diff --git a/sxt/proof/sumcheck2/cpu_driver.cc b/sxt/proof/sumcheck/cpu_driver.cc similarity index 94% rename from sxt/proof/sumcheck2/cpu_driver.cc rename to sxt/proof/sumcheck/cpu_driver.cc index d504430dc..7fa91e0da 100644 --- a/sxt/proof/sumcheck2/cpu_driver.cc +++ b/sxt/proof/sumcheck/cpu_driver.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/cpu_driver.h" +#include "sxt/proof/sumcheck/cpu_driver.h" diff --git a/sxt/proof/sumcheck2/cpu_driver.h b/sxt/proof/sumcheck/cpu_driver.h similarity index 98% rename from sxt/proof/sumcheck2/cpu_driver.h rename to sxt/proof/sumcheck/cpu_driver.h index 508124647..ae25aa8ce 100644 --- a/sxt/proof/sumcheck2/cpu_driver.h +++ b/sxt/proof/sumcheck/cpu_driver.h @@ -21,8 +21,8 @@ #include "sxt/base/num/ceil_log2.h" #include "sxt/execution/async/coroutine.h" #include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck2/driver.h" -#include "sxt/proof/sumcheck2/polynomial_utility.h" +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/cpu_driver.t.cc b/sxt/proof/sumcheck/cpu_driver.t.cc similarity index 91% rename from sxt/proof/sumcheck2/cpu_driver.t.cc rename to sxt/proof/sumcheck/cpu_driver.t.cc index 4775cb0fb..47913ec41 100644 --- a/sxt/proof/sumcheck2/cpu_driver.t.cc +++ b/sxt/proof/sumcheck/cpu_driver.t.cc @@ -14,10 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/cpu_driver.h" +#include "sxt/proof/sumcheck/cpu_driver.h" #include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck2/driver_test.h" +#include "sxt/proof/sumcheck/driver_test.h" using namespace sxt; using namespace sxt::prfsk2; diff --git a/sxt/proof/sumcheck2/device_cache.cc b/sxt/proof/sumcheck/device_cache.cc similarity index 93% rename from sxt/proof/sumcheck2/device_cache.cc rename to sxt/proof/sumcheck/device_cache.cc index 5cafd3238..0fb5bce21 100644 --- a/sxt/proof/sumcheck2/device_cache.cc +++ b/sxt/proof/sumcheck/device_cache.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/device_cache.h" +#include "sxt/proof/sumcheck/device_cache.h" diff --git a/sxt/proof/sumcheck2/device_cache.h b/sxt/proof/sumcheck/device_cache.h similarity index 100% rename from sxt/proof/sumcheck2/device_cache.h rename to sxt/proof/sumcheck/device_cache.h diff --git a/sxt/proof/sumcheck2/device_cache.t.cc b/sxt/proof/sumcheck/device_cache.t.cc similarity index 98% rename from sxt/proof/sumcheck2/device_cache.t.cc rename to sxt/proof/sumcheck/device_cache.t.cc index 63a15b89a..9c28851c8 100644 --- a/sxt/proof/sumcheck2/device_cache.t.cc +++ b/sxt/proof/sumcheck/device_cache.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/device_cache.h" +#include "sxt/proof/sumcheck/device_cache.h" #include diff --git a/sxt/proof/sumcheck2/driver.cc b/sxt/proof/sumcheck/driver.cc similarity index 94% rename from sxt/proof/sumcheck2/driver.cc rename to sxt/proof/sumcheck/driver.cc index d28bdf310..075234c6c 100644 --- a/sxt/proof/sumcheck2/driver.cc +++ b/sxt/proof/sumcheck/driver.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/proof/sumcheck/driver.h" diff --git a/sxt/proof/sumcheck2/driver.h b/sxt/proof/sumcheck/driver.h similarity index 97% rename from sxt/proof/sumcheck2/driver.h rename to sxt/proof/sumcheck/driver.h index 23aeec4d5..8cd5b8742 100644 --- a/sxt/proof/sumcheck2/driver.h +++ b/sxt/proof/sumcheck/driver.h @@ -21,7 +21,7 @@ #include "sxt/base/container/span.h" #include "sxt/base/field/element.h" #include "sxt/execution/async/future_fwd.h" -#include "sxt/proof/sumcheck2/workspace.h" +#include "sxt/proof/sumcheck/workspace.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/driver_test.cc b/sxt/proof/sumcheck/driver_test.cc similarity index 99% rename from sxt/proof/sumcheck2/driver_test.cc rename to sxt/proof/sumcheck/driver_test.cc index b74f47ada..befd57857 100644 --- a/sxt/proof/sumcheck2/driver_test.cc +++ b/sxt/proof/sumcheck/driver_test.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/driver_test.h" +#include "sxt/proof/sumcheck/driver_test.h" #include diff --git a/sxt/proof/sumcheck2/driver_test.h b/sxt/proof/sumcheck/driver_test.h similarity index 96% rename from sxt/proof/sumcheck2/driver_test.h rename to sxt/proof/sumcheck/driver_test.h index 167554551..8294f5c0c 100644 --- a/sxt/proof/sumcheck2/driver_test.h +++ b/sxt/proof/sumcheck/driver_test.h @@ -16,7 +16,7 @@ */ #pragma once -#include "sxt/proof/sumcheck2/driver.h" +#include "sxt/proof/sumcheck/driver.h" #include "sxt/scalar25/realization/field.h" namespace sxt::prfsk2 { diff --git a/sxt/proof/sumcheck2/fold_gpu.cc b/sxt/proof/sumcheck/fold_gpu.cc similarity index 94% rename from sxt/proof/sumcheck2/fold_gpu.cc rename to sxt/proof/sumcheck/fold_gpu.cc index 332097aae..3ade12a7c 100644 --- a/sxt/proof/sumcheck2/fold_gpu.cc +++ b/sxt/proof/sumcheck/fold_gpu.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/fold_gpu.h" +#include "sxt/proof/sumcheck/fold_gpu.h" diff --git a/sxt/proof/sumcheck2/fold_gpu.h b/sxt/proof/sumcheck/fold_gpu.h similarity index 99% rename from sxt/proof/sumcheck2/fold_gpu.h rename to sxt/proof/sumcheck/fold_gpu.h index 8dc89d6a7..4a6581c4b 100644 --- a/sxt/proof/sumcheck2/fold_gpu.h +++ b/sxt/proof/sumcheck/fold_gpu.h @@ -35,7 +35,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck2/mle_utility.h" +#include "sxt/proof/sumcheck/mle_utility.h" #include "sxt/scalar25/operation/mul.h" #include "sxt/scalar25/operation/muladd.h" #include "sxt/scalar25/operation/sub.h" diff --git a/sxt/proof/sumcheck2/fold_gpu.t.cc b/sxt/proof/sumcheck/fold_gpu.t.cc similarity index 98% rename from sxt/proof/sumcheck2/fold_gpu.t.cc rename to sxt/proof/sumcheck/fold_gpu.t.cc index d8f82b1cd..17da2fd6c 100644 --- a/sxt/proof/sumcheck2/fold_gpu.t.cc +++ b/sxt/proof/sumcheck/fold_gpu.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/fold_gpu.h" +#include "sxt/proof/sumcheck/fold_gpu.h" #include diff --git a/sxt/proof/sumcheck2/gpu_driver.cc b/sxt/proof/sumcheck/gpu_driver.cc similarity index 94% rename from sxt/proof/sumcheck2/gpu_driver.cc rename to sxt/proof/sumcheck/gpu_driver.cc index 2951ac714..a9c0566ff 100644 --- a/sxt/proof/sumcheck2/gpu_driver.cc +++ b/sxt/proof/sumcheck/gpu_driver.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/gpu_driver.h" +#include "sxt/proof/sumcheck/gpu_driver.h" diff --git a/sxt/proof/sumcheck2/gpu_driver.h b/sxt/proof/sumcheck/gpu_driver.h similarity index 98% rename from sxt/proof/sumcheck2/gpu_driver.h rename to sxt/proof/sumcheck/gpu_driver.h index 537df4d3a..3fc9ca260 100644 --- a/sxt/proof/sumcheck2/gpu_driver.h +++ b/sxt/proof/sumcheck/gpu_driver.h @@ -27,8 +27,8 @@ #include "sxt/execution/device/synchronization.h" #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck2/driver.h" -#include "sxt/proof/sumcheck2/sum_gpu.h" +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/sum_gpu.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/gpu_driver.t.cc b/sxt/proof/sumcheck/gpu_driver.t.cc similarity index 91% rename from sxt/proof/sumcheck2/gpu_driver.t.cc rename to sxt/proof/sumcheck/gpu_driver.t.cc index 1a58e5823..fff764959 100644 --- a/sxt/proof/sumcheck2/gpu_driver.t.cc +++ b/sxt/proof/sumcheck/gpu_driver.t.cc @@ -14,10 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/gpu_driver.h" +#include "sxt/proof/sumcheck/gpu_driver.h" #include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck2/driver_test.h" +#include "sxt/proof/sumcheck/driver_test.h" using namespace sxt; using namespace sxt::prfsk2; diff --git a/sxt/proof/sumcheck2/mle_utility.cc b/sxt/proof/sumcheck/mle_utility.cc similarity index 93% rename from sxt/proof/sumcheck2/mle_utility.cc rename to sxt/proof/sumcheck/mle_utility.cc index c4bf89222..e548908cc 100644 --- a/sxt/proof/sumcheck2/mle_utility.cc +++ b/sxt/proof/sumcheck/mle_utility.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/mle_utility.h" +#include "sxt/proof/sumcheck/mle_utility.h" diff --git a/sxt/proof/sumcheck2/mle_utility.h b/sxt/proof/sumcheck/mle_utility.h similarity index 100% rename from sxt/proof/sumcheck2/mle_utility.h rename to sxt/proof/sumcheck/mle_utility.h diff --git a/sxt/proof/sumcheck2/mle_utility.t.cc b/sxt/proof/sumcheck/mle_utility.t.cc similarity index 98% rename from sxt/proof/sumcheck2/mle_utility.t.cc rename to sxt/proof/sumcheck/mle_utility.t.cc index b0b53d1b6..20e48bc07 100644 --- a/sxt/proof/sumcheck2/mle_utility.t.cc +++ b/sxt/proof/sumcheck/mle_utility.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/mle_utility.h" +#include "sxt/proof/sumcheck/mle_utility.h" #include diff --git a/sxt/proof/sumcheck2/polynomial_mapper.cc b/sxt/proof/sumcheck/polynomial_mapper.cc similarity index 93% rename from sxt/proof/sumcheck2/polynomial_mapper.cc rename to sxt/proof/sumcheck/polynomial_mapper.cc index cb2f5c790..a46364a4e 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.cc +++ b/sxt/proof/sumcheck/polynomial_mapper.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/polynomial_mapper.h" +#include "sxt/proof/sumcheck/polynomial_mapper.h" diff --git a/sxt/proof/sumcheck2/polynomial_mapper.h b/sxt/proof/sumcheck/polynomial_mapper.h similarity index 97% rename from sxt/proof/sumcheck2/polynomial_mapper.h rename to sxt/proof/sumcheck/polynomial_mapper.h index b8eab6ed6..0b4c9f50d 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.h +++ b/sxt/proof/sumcheck/polynomial_mapper.h @@ -18,7 +18,7 @@ #include "sxt/base/field/element.h" #include "sxt/base/macro/cuda_callable.h" -#include "sxt/proof/sumcheck2/polynomial_utility.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/polynomial_mapper.t.cc b/sxt/proof/sumcheck/polynomial_mapper.t.cc similarity index 97% rename from sxt/proof/sumcheck2/polynomial_mapper.t.cc rename to sxt/proof/sumcheck/polynomial_mapper.t.cc index 11d26d27e..55b44a88f 100644 --- a/sxt/proof/sumcheck2/polynomial_mapper.t.cc +++ b/sxt/proof/sumcheck/polynomial_mapper.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/polynomial_mapper.h" +#include "sxt/proof/sumcheck/polynomial_mapper.h" #include diff --git a/sxt/proof/sumcheck2/polynomial_reducer.cc b/sxt/proof/sumcheck/polynomial_reducer.cc similarity index 93% rename from sxt/proof/sumcheck2/polynomial_reducer.cc rename to sxt/proof/sumcheck/polynomial_reducer.cc index 47e62134d..cc61a64d1 100644 --- a/sxt/proof/sumcheck2/polynomial_reducer.cc +++ b/sxt/proof/sumcheck/polynomial_reducer.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/polynomial_reducer.h" +#include "sxt/proof/sumcheck/polynomial_reducer.h" diff --git a/sxt/proof/sumcheck2/polynomial_reducer.h b/sxt/proof/sumcheck/polynomial_reducer.h similarity index 100% rename from sxt/proof/sumcheck2/polynomial_reducer.h rename to sxt/proof/sumcheck/polynomial_reducer.h diff --git a/sxt/proof/sumcheck2/polynomial_utility.cc b/sxt/proof/sumcheck/polynomial_utility.cc similarity index 93% rename from sxt/proof/sumcheck2/polynomial_utility.cc rename to sxt/proof/sumcheck/polynomial_utility.cc index 701fe438d..46faf33bb 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.cc +++ b/sxt/proof/sumcheck/polynomial_utility.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/polynomial_utility.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" diff --git a/sxt/proof/sumcheck2/polynomial_utility.h b/sxt/proof/sumcheck/polynomial_utility.h similarity index 100% rename from sxt/proof/sumcheck2/polynomial_utility.h rename to sxt/proof/sumcheck/polynomial_utility.h diff --git a/sxt/proof/sumcheck2/polynomial_utility.t.cc b/sxt/proof/sumcheck/polynomial_utility.t.cc similarity index 98% rename from sxt/proof/sumcheck2/polynomial_utility.t.cc rename to sxt/proof/sumcheck/polynomial_utility.t.cc index 4081ed049..9d429dc8b 100644 --- a/sxt/proof/sumcheck2/polynomial_utility.t.cc +++ b/sxt/proof/sumcheck/polynomial_utility.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/polynomial_utility.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" #include diff --git a/sxt/proof/sumcheck2/proof_computation.cc b/sxt/proof/sumcheck/proof_computation.cc similarity index 93% rename from sxt/proof/sumcheck2/proof_computation.cc rename to sxt/proof/sumcheck/proof_computation.cc index a87c23c27..d999e8d39 100644 --- a/sxt/proof/sumcheck2/proof_computation.cc +++ b/sxt/proof/sumcheck/proof_computation.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/proof_computation.h" +#include "sxt/proof/sumcheck/proof_computation.h" diff --git a/sxt/proof/sumcheck2/proof_computation.h b/sxt/proof/sumcheck/proof_computation.h similarity index 96% rename from sxt/proof/sumcheck2/proof_computation.h rename to sxt/proof/sumcheck/proof_computation.h index 727542e4c..23539d665 100644 --- a/sxt/proof/sumcheck2/proof_computation.h +++ b/sxt/proof/sumcheck/proof_computation.h @@ -21,8 +21,8 @@ #include "sxt/base/num/ceil_log2.h" #include "sxt/execution/async/coroutine.h" #include "sxt/execution/async/future.h" -#include "sxt/proof/sumcheck2/driver.h" -#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/proof_computation.t.cc b/sxt/proof/sumcheck/proof_computation.t.cc similarity index 94% rename from sxt/proof/sumcheck2/proof_computation.t.cc rename to sxt/proof/sumcheck/proof_computation.t.cc index 022b501a2..8321698a7 100644 --- a/sxt/proof/sumcheck2/proof_computation.t.cc +++ b/sxt/proof/sumcheck/proof_computation.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/proof_computation.h" +#include "sxt/proof/sumcheck/proof_computation.h" #include #include @@ -25,14 +25,14 @@ #include "sxt/base/test/unit_test.h" #include "sxt/execution/async/future.h" #include "sxt/execution/schedule/scheduler.h" -#include "sxt/proof/sumcheck2/chunked_gpu_driver.h" -#include "sxt/proof/sumcheck2/cpu_driver.h" -#include "sxt/proof/sumcheck2/gpu_driver.h" -#include "sxt/proof/sumcheck2/mle_utility.h" -#include "sxt/proof/sumcheck2/polynomial_utility.h" -#include "sxt/proof/sumcheck2/reference_transcript.h" -#include "sxt/proof/sumcheck2/sumcheck_random.h" -#include "sxt/proof/sumcheck2/verification.h" +#include "sxt/proof/sumcheck/chunked_gpu_driver.h" +#include "sxt/proof/sumcheck/cpu_driver.h" +#include "sxt/proof/sumcheck/gpu_driver.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" +#include "sxt/proof/sumcheck/reference_transcript.h" +#include "sxt/proof/sumcheck/sumcheck_random.h" +#include "sxt/proof/sumcheck/verification.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/operation/overload.h" #include "sxt/scalar25/realization/field.h" diff --git a/sxt/proof/sumcheck2/reduction_gpu.cc b/sxt/proof/sumcheck/reduction_gpu.cc similarity index 93% rename from sxt/proof/sumcheck2/reduction_gpu.cc rename to sxt/proof/sumcheck/reduction_gpu.cc index 1b4232869..837b4c09d 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.cc +++ b/sxt/proof/sumcheck/reduction_gpu.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/reduction_gpu.h" +#include "sxt/proof/sumcheck/reduction_gpu.h" diff --git a/sxt/proof/sumcheck2/reduction_gpu.h b/sxt/proof/sumcheck/reduction_gpu.h similarity index 100% rename from sxt/proof/sumcheck2/reduction_gpu.h rename to sxt/proof/sumcheck/reduction_gpu.h diff --git a/sxt/proof/sumcheck2/reduction_gpu.t.cc b/sxt/proof/sumcheck/reduction_gpu.t.cc similarity index 97% rename from sxt/proof/sumcheck2/reduction_gpu.t.cc rename to sxt/proof/sumcheck/reduction_gpu.t.cc index 4d7659d53..d3da21f26 100644 --- a/sxt/proof/sumcheck2/reduction_gpu.t.cc +++ b/sxt/proof/sumcheck/reduction_gpu.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/reduction_gpu.h" +#include "sxt/proof/sumcheck/reduction_gpu.h" #include diff --git a/sxt/proof/sumcheck2/reference_transcript.cc b/sxt/proof/sumcheck/reference_transcript.cc similarity index 92% rename from sxt/proof/sumcheck2/reference_transcript.cc rename to sxt/proof/sumcheck/reference_transcript.cc index f49c8a511..1d1616dc0 100644 --- a/sxt/proof/sumcheck2/reference_transcript.cc +++ b/sxt/proof/sumcheck/reference_transcript.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/reference_transcript.h" +#include "sxt/proof/sumcheck/reference_transcript.h" diff --git a/sxt/proof/sumcheck2/reference_transcript.h b/sxt/proof/sumcheck/reference_transcript.h similarity index 97% rename from sxt/proof/sumcheck2/reference_transcript.h rename to sxt/proof/sumcheck/reference_transcript.h index be65c6dc4..cc23cb7f6 100644 --- a/sxt/proof/sumcheck2/reference_transcript.h +++ b/sxt/proof/sumcheck/reference_transcript.h @@ -16,7 +16,7 @@ */ #pragma once -#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/proof/transcript/transcript_utility.h" diff --git a/sxt/proof/sumcheck2/reference_transcript.t.cc b/sxt/proof/sumcheck/reference_transcript.t.cc similarity index 97% rename from sxt/proof/sumcheck2/reference_transcript.t.cc rename to sxt/proof/sumcheck/reference_transcript.t.cc index adac2b364..b97eaf6ac 100644 --- a/sxt/proof/sumcheck2/reference_transcript.t.cc +++ b/sxt/proof/sumcheck/reference_transcript.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/reference_transcript.h" +#include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/base/test/unit_test.h" #include "sxt/proof/transcript/transcript.h" diff --git a/sxt/proof/sumcheck2/sum_gpu.cc b/sxt/proof/sumcheck/sum_gpu.cc similarity index 94% rename from sxt/proof/sumcheck2/sum_gpu.cc rename to sxt/proof/sumcheck/sum_gpu.cc index 26d71a8af..e177b6f83 100644 --- a/sxt/proof/sumcheck2/sum_gpu.cc +++ b/sxt/proof/sumcheck/sum_gpu.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/sum_gpu.h" +#include "sxt/proof/sumcheck/sum_gpu.h" diff --git a/sxt/proof/sumcheck2/sum_gpu.h b/sxt/proof/sumcheck/sum_gpu.h similarity index 97% rename from sxt/proof/sumcheck2/sum_gpu.h rename to sxt/proof/sumcheck/sum_gpu.h index a5e8fc909..14adec7c3 100644 --- a/sxt/proof/sumcheck2/sum_gpu.h +++ b/sxt/proof/sumcheck/sum_gpu.h @@ -35,12 +35,12 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck2/constant.h" -#include "sxt/proof/sumcheck2/device_cache.h" -#include "sxt/proof/sumcheck2/mle_utility.h" -#include "sxt/proof/sumcheck2/polynomial_mapper.h" -#include "sxt/proof/sumcheck2/polynomial_reducer.h" -#include "sxt/proof/sumcheck2/reduction_gpu.h" +#include "sxt/proof/sumcheck/constant.h" +#include "sxt/proof/sumcheck/device_cache.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/proof/sumcheck/polynomial_mapper.h" +#include "sxt/proof/sumcheck/polynomial_reducer.h" +#include "sxt/proof/sumcheck/reduction_gpu.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/sum_gpu.t.cc b/sxt/proof/sumcheck/sum_gpu.t.cc similarity index 99% rename from sxt/proof/sumcheck2/sum_gpu.t.cc rename to sxt/proof/sumcheck/sum_gpu.t.cc index c6759ab61..48c022388 100644 --- a/sxt/proof/sumcheck2/sum_gpu.t.cc +++ b/sxt/proof/sumcheck/sum_gpu.t.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/sum_gpu.h" +#include "sxt/proof/sumcheck/sum_gpu.h" #include diff --git a/sxt/proof/sumcheck2/sumcheck_random.cc b/sxt/proof/sumcheck/sumcheck_random.cc similarity index 98% rename from sxt/proof/sumcheck2/sumcheck_random.cc rename to sxt/proof/sumcheck/sumcheck_random.cc index 17c32d430..04db4e7d5 100644 --- a/sxt/proof/sumcheck2/sumcheck_random.cc +++ b/sxt/proof/sumcheck/sumcheck_random.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/sumcheck_random.h" +#include "sxt/proof/sumcheck/sumcheck_random.h" #include diff --git a/sxt/proof/sumcheck2/sumcheck_random.h b/sxt/proof/sumcheck/sumcheck_random.h similarity index 97% rename from sxt/proof/sumcheck2/sumcheck_random.h rename to sxt/proof/sumcheck/sumcheck_random.h index 7bfec674f..5c88b944d 100644 --- a/sxt/proof/sumcheck2/sumcheck_random.h +++ b/sxt/proof/sumcheck/sumcheck_random.h @@ -19,7 +19,7 @@ #include #include -#include "sxt/proof/sumcheck2/constant.h" +#include "sxt/proof/sumcheck/constant.h" namespace sxt::s25t { class element; diff --git a/sxt/proof/sumcheck2/sumcheck_transcript.cc b/sxt/proof/sumcheck/sumcheck_transcript.cc similarity index 92% rename from sxt/proof/sumcheck2/sumcheck_transcript.cc rename to sxt/proof/sumcheck/sumcheck_transcript.cc index d04730afa..36b1ecd89 100644 --- a/sxt/proof/sumcheck2/sumcheck_transcript.cc +++ b/sxt/proof/sumcheck/sumcheck_transcript.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" diff --git a/sxt/proof/sumcheck2/sumcheck_transcript.h b/sxt/proof/sumcheck/sumcheck_transcript.h similarity index 100% rename from sxt/proof/sumcheck2/sumcheck_transcript.h rename to sxt/proof/sumcheck/sumcheck_transcript.h diff --git a/sxt/proof/sumcheck2/verification.cc b/sxt/proof/sumcheck/verification.cc similarity index 93% rename from sxt/proof/sumcheck2/verification.cc rename to sxt/proof/sumcheck/verification.cc index e1d4e844f..16233df78 100644 --- a/sxt/proof/sumcheck2/verification.cc +++ b/sxt/proof/sumcheck/verification.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/verification.h" +#include "sxt/proof/sumcheck/verification.h" diff --git a/sxt/proof/sumcheck2/verification.h b/sxt/proof/sumcheck/verification.h similarity index 96% rename from sxt/proof/sumcheck2/verification.h rename to sxt/proof/sumcheck/verification.h index 40166a92b..75ee6b428 100644 --- a/sxt/proof/sumcheck2/verification.h +++ b/sxt/proof/sumcheck/verification.h @@ -19,8 +19,8 @@ #include "sxt/base/container/span.h" #include "sxt/base/error/assert.h" #include "sxt/base/log/log.h" -#include "sxt/proof/sumcheck2/polynomial_utility.h" -#include "sxt/proof/sumcheck2/sumcheck_transcript.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" namespace sxt::prfsk2 { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck2/verification.t.cc b/sxt/proof/sumcheck/verification.t.cc similarity index 97% rename from sxt/proof/sumcheck2/verification.t.cc rename to sxt/proof/sumcheck/verification.t.cc index f48605338..e7b614f61 100644 --- a/sxt/proof/sumcheck2/verification.t.cc +++ b/sxt/proof/sumcheck/verification.t.cc @@ -14,12 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/verification.h" +#include "sxt/proof/sumcheck/verification.h" #include #include "sxt/base/test/unit_test.h" -#include "sxt/proof/sumcheck2/reference_transcript.h" +#include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/operation/overload.h" #include "sxt/scalar25/realization/field.h" diff --git a/sxt/proof/sumcheck2/workspace.cc b/sxt/proof/sumcheck/workspace.cc similarity index 94% rename from sxt/proof/sumcheck2/workspace.cc rename to sxt/proof/sumcheck/workspace.cc index 32a934d6b..997e37cd3 100644 --- a/sxt/proof/sumcheck2/workspace.cc +++ b/sxt/proof/sumcheck/workspace.cc @@ -14,4 +14,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "sxt/proof/sumcheck2/workspace.h" +#include "sxt/proof/sumcheck/workspace.h" diff --git a/sxt/proof/sumcheck2/workspace.h b/sxt/proof/sumcheck/workspace.h similarity index 100% rename from sxt/proof/sumcheck2/workspace.h rename to sxt/proof/sumcheck/workspace.h From 81603bdfc85d3bd7a9b457f2e0f1ecb5a999d22b Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 20:02:08 -0800 Subject: [PATCH 82/83] refactor sumcheck --- benchmark/sumcheck/benchmark.m.cc | 12 ++++++------ cbindings/sumcheck.t.cc | 8 ++++---- sxt/cbindings/backend/callback_sumcheck_transcript.h | 2 +- sxt/cbindings/backend/cpu_backend.cc | 6 +++--- sxt/cbindings/backend/gpu_backend.cc | 6 +++--- sxt/proof/sumcheck/chunked_gpu_driver.h | 4 ++-- sxt/proof/sumcheck/chunked_gpu_driver.t.cc | 2 +- sxt/proof/sumcheck/constant.h | 4 ++-- sxt/proof/sumcheck/cpu_driver.h | 4 ++-- sxt/proof/sumcheck/cpu_driver.t.cc | 2 +- sxt/proof/sumcheck/device_cache.h | 4 ++-- sxt/proof/sumcheck/device_cache.t.cc | 2 +- sxt/proof/sumcheck/driver.h | 4 ++-- sxt/proof/sumcheck/driver_test.cc | 4 ++-- sxt/proof/sumcheck/driver_test.h | 4 ++-- sxt/proof/sumcheck/fold_gpu.h | 4 ++-- sxt/proof/sumcheck/fold_gpu.t.cc | 2 +- sxt/proof/sumcheck/gpu_driver.h | 4 ++-- sxt/proof/sumcheck/gpu_driver.t.cc | 2 +- sxt/proof/sumcheck/mle_utility.h | 4 ++-- sxt/proof/sumcheck/mle_utility.t.cc | 2 +- sxt/proof/sumcheck/polynomial_mapper.h | 4 ++-- sxt/proof/sumcheck/polynomial_mapper.t.cc | 2 +- sxt/proof/sumcheck/polynomial_reducer.h | 4 ++-- sxt/proof/sumcheck/polynomial_utility.h | 4 ++-- sxt/proof/sumcheck/polynomial_utility.t.cc | 2 +- sxt/proof/sumcheck/proof_computation.h | 4 ++-- sxt/proof/sumcheck/proof_computation.t.cc | 2 +- sxt/proof/sumcheck/reduction_gpu.h | 4 ++-- sxt/proof/sumcheck/reduction_gpu.t.cc | 2 +- sxt/proof/sumcheck/reference_transcript.h | 4 ++-- sxt/proof/sumcheck/reference_transcript.t.cc | 2 +- sxt/proof/sumcheck/sum_gpu.h | 4 ++-- sxt/proof/sumcheck/sum_gpu.t.cc | 2 +- sxt/proof/sumcheck/sumcheck_random.cc | 4 ++-- sxt/proof/sumcheck/sumcheck_random.h | 4 ++-- sxt/proof/sumcheck/sumcheck_transcript.h | 4 ++-- sxt/proof/sumcheck/verification.h | 4 ++-- sxt/proof/sumcheck/verification.t.cc | 2 +- sxt/proof/sumcheck/workspace.h | 4 ++-- 40 files changed, 74 insertions(+), 74 deletions(-) diff --git a/benchmark/sumcheck/benchmark.m.cc b/benchmark/sumcheck/benchmark.m.cc index 48516953f..6b0e64fa7 100644 --- a/benchmark/sumcheck/benchmark.m.cc +++ b/benchmark/sumcheck/benchmark.m.cc @@ -111,13 +111,13 @@ int main(int argc, char* argv[]) { memmg::managed_array polynomials((p.degree + 1u) * num_rounds); memmg::managed_array evaluation_point(num_rounds); prft::transcript base_transcript{"abc123"}; - prfsk2::reference_transcript transcript{base_transcript}; - prfsk2::gpu_driver drv; + prfsk::reference_transcript transcript{base_transcript}; + prfsk::gpu_driver drv; // initial run { - auto fut = prfsk2::prove_sum(polynomials, evaluation_point, transcript, drv, - mles, product_table, product_terms, p.n); + auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, + product_table, product_terms, p.n); xens::get_scheduler().run(); } @@ -125,8 +125,8 @@ int main(int argc, char* argv[]) { double elapse = 0; for (unsigned i = 0; i < (p.num_samples + 1u); ++i) { auto t1 = std::chrono::steady_clock::now(); - auto fut = prfsk2::prove_sum(polynomials, evaluation_point, transcript, drv, - mles, product_table, product_terms, p.n); + auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, + product_table, product_terms, p.n); xens::get_scheduler().run(); auto t2 = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast(t2 - t1); diff --git a/cbindings/sumcheck.t.cc b/cbindings/sumcheck.t.cc index e21c33f73..2b0c739af 100644 --- a/cbindings/sumcheck.t.cc +++ b/cbindings/sumcheck.t.cc @@ -30,7 +30,7 @@ using s25t::operator""_s25; TEST_CASE("we can create sumcheck proofs") { prft::transcript base_transcript{"abc"}; - prfsk2::reference_transcript transcript{base_transcript}; + prfsk::reference_transcript transcript{base_transcript}; std::vector polynomials(2); std::vector evaluation_point(1); @@ -55,7 +55,7 @@ TEST_CASE("we can create sumcheck proofs") { auto f = [](s25t::element* r, void* context, const s25t::element* polynomial, unsigned polynomial_len) noexcept { - static_cast*>(context)->round_challenge( + static_cast*>(context)->round_challenge( *r, {polynomial, polynomial_len}); }; @@ -70,7 +70,7 @@ TEST_CASE("we can create sumcheck proofs") { REQUIRE(polynomials[1] == mles[1] - mles[0]); { prft::transcript base_transcript_p{"abc"}; - prfsk2::reference_transcript transcript_p{base_transcript_p}; + prfsk::reference_transcript transcript_p{base_transcript_p}; s25t::element r; transcript_p.round_challenge(r, polynomials); REQUIRE(evaluation_point[0] == r); @@ -88,7 +88,7 @@ TEST_CASE("we can create sumcheck proofs") { REQUIRE(polynomials[1] == mles[1] - mles[0]); { prft::transcript base_transcript_p{"abc"}; - prfsk2::reference_transcript transcript_p{base_transcript_p}; + prfsk::reference_transcript transcript_p{base_transcript_p}; s25t::element r; transcript_p.round_challenge(r, polynomials); REQUIRE(evaluation_point[0] == r); diff --git a/sxt/cbindings/backend/callback_sumcheck_transcript.h b/sxt/cbindings/backend/callback_sumcheck_transcript.h index 0c7da99f2..b2270cbef 100644 --- a/sxt/cbindings/backend/callback_sumcheck_transcript.h +++ b/sxt/cbindings/backend/callback_sumcheck_transcript.h @@ -23,7 +23,7 @@ namespace sxt::cbnbck { // callback_sumcheck_transcript //-------------------------------------------------------------------------------------------------- template -class callback_sumcheck_transcript final : public prfsk2::sumcheck_transcript { +class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript { public: using callback_t = void (*)(T* r, void* context, const T* polynomial, unsigned polynomial_len); diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index 755b26610..28296e537 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -105,10 +105,10 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi descriptor.product_terms, descriptor.num_product_terms, }; - prfsk2::cpu_driver drv; + prfsk::cpu_driver drv; auto fut = - prfsk2::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, - mles_span, product_table_span, product_terms_span, descriptor.n); + prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, + product_table_span, product_terms_span, descriptor.n); SXT_RELEASE_ASSERT(fut.ready()); }); } diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index eccb0c7a0..b07fa69ad 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -138,10 +138,10 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi descriptor.product_terms, descriptor.num_product_terms, }; - prfsk2::chunked_gpu_driver drv; + prfsk::chunked_gpu_driver drv; auto fut = - prfsk2::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, - mles_span, product_table_span, product_terms_span, descriptor.n); + prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, + product_table_span, product_terms_span, descriptor.n); xens::get_scheduler().run(); }); } diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.h b/sxt/proof/sumcheck/chunked_gpu_driver.h index 016775932..fbe687a13 100644 --- a/sxt/proof/sumcheck/chunked_gpu_driver.h +++ b/sxt/proof/sumcheck/chunked_gpu_driver.h @@ -31,7 +31,7 @@ #include "sxt/proof/sumcheck/mle_utility.h" #include "sxt/proof/sumcheck/sum_gpu.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // chunked_gpu_driver //-------------------------------------------------------------------------------------------------- @@ -134,4 +134,4 @@ template class chunked_gpu_driver final : public driver { private: double no_chunk_cutoff_; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc index 14bed91a7..5642a3a8c 100644 --- a/sxt/proof/sumcheck/chunked_gpu_driver.t.cc +++ b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc @@ -20,7 +20,7 @@ #include "sxt/proof/sumcheck/driver_test.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; TEST_CASE("we can perform the primitive operations for sumcheck proofs") { SECTION("we handle the case when only chunking is used") { diff --git a/sxt/proof/sumcheck/constant.h b/sxt/proof/sumcheck/constant.h index 0b8fc5348..dec9e8a33 100644 --- a/sxt/proof/sumcheck/constant.h +++ b/sxt/proof/sumcheck/constant.h @@ -16,11 +16,11 @@ */ #pragma once -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // max_degree_v //-------------------------------------------------------------------------------------------------- // the maximum degree of the round polynomial // used in sumcheck constexpr unsigned max_degree_v = 5u; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.h b/sxt/proof/sumcheck/cpu_driver.h index ae25aa8ce..ec82f3a5d 100644 --- a/sxt/proof/sumcheck/cpu_driver.h +++ b/sxt/proof/sumcheck/cpu_driver.h @@ -24,7 +24,7 @@ #include "sxt/proof/sumcheck/driver.h" #include "sxt/proof/sumcheck/polynomial_utility.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // cpu_driver //-------------------------------------------------------------------------------------------------- @@ -141,4 +141,4 @@ template class cpu_driver final : public driver { return xena::make_ready_future(); } }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.t.cc b/sxt/proof/sumcheck/cpu_driver.t.cc index 47913ec41..2544c1bc2 100644 --- a/sxt/proof/sumcheck/cpu_driver.t.cc +++ b/sxt/proof/sumcheck/cpu_driver.t.cc @@ -20,7 +20,7 @@ #include "sxt/proof/sumcheck/driver_test.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; TEST_CASE("we can perform the primitive operations for sumcheck proofs") { cpu_driver drv; diff --git a/sxt/proof/sumcheck/device_cache.h b/sxt/proof/sumcheck/device_cache.h index 951524f32..f2510501d 100644 --- a/sxt/proof/sumcheck/device_cache.h +++ b/sxt/proof/sumcheck/device_cache.h @@ -29,7 +29,7 @@ #include "sxt/memory/resource/device_resource.h" #include "sxt/scalar25/type/element.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // device_cache_data //-------------------------------------------------------------------------------------------------- @@ -86,4 +86,4 @@ template class device_cache { basct::cspan product_terms_; basdv::device_map>> data_; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/device_cache.t.cc b/sxt/proof/sumcheck/device_cache.t.cc index 9c28851c8..cad5ea828 100644 --- a/sxt/proof/sumcheck/device_cache.t.cc +++ b/sxt/proof/sumcheck/device_cache.t.cc @@ -26,7 +26,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can cache device values that don't change as a proof is computed") { diff --git a/sxt/proof/sumcheck/driver.h b/sxt/proof/sumcheck/driver.h index 8cd5b8742..3dcd2ae2b 100644 --- a/sxt/proof/sumcheck/driver.h +++ b/sxt/proof/sumcheck/driver.h @@ -23,7 +23,7 @@ #include "sxt/execution/async/future_fwd.h" #include "sxt/proof/sumcheck/workspace.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // driver //-------------------------------------------------------------------------------------------------- @@ -39,4 +39,4 @@ template class driver { virtual xena::future<> fold(workspace& ws, const T& r) const noexcept = 0; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/driver_test.cc b/sxt/proof/sumcheck/driver_test.cc index befd57857..58bac88a4 100644 --- a/sxt/proof/sumcheck/driver_test.cc +++ b/sxt/proof/sumcheck/driver_test.cc @@ -27,7 +27,7 @@ #include "sxt/scalar25/type/element.h" #include "sxt/scalar25/type/literal.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { using s25t::operator""_s25; //-------------------------------------------------------------------------------------------------- @@ -130,4 +130,4 @@ void exercise_driver(const driver& drv) { REQUIRE(p[1] == mles[1] - mles[0]); } } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/driver_test.h b/sxt/proof/sumcheck/driver_test.h index 8294f5c0c..f13db0b55 100644 --- a/sxt/proof/sumcheck/driver_test.h +++ b/sxt/proof/sumcheck/driver_test.h @@ -19,9 +19,9 @@ #include "sxt/proof/sumcheck/driver.h" #include "sxt/scalar25/realization/field.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // exercise_driver //-------------------------------------------------------------------------------------------------- void exercise_driver(const driver& drv); -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.h b/sxt/proof/sumcheck/fold_gpu.h index 4a6581c4b..c5498a21b 100644 --- a/sxt/proof/sumcheck/fold_gpu.h +++ b/sxt/proof/sumcheck/fold_gpu.h @@ -42,7 +42,7 @@ #include "sxt/scalar25/type/element.h" #include "sxt/scalar25/type/literal.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // fold_kernel //-------------------------------------------------------------------------------------------------- @@ -137,4 +137,4 @@ xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, unsigned n, }; co_await fold_gpu(mles_p, split_options, mles, n, r); } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.t.cc b/sxt/proof/sumcheck/fold_gpu.t.cc index 17da2fd6c..edb00aaa1 100644 --- a/sxt/proof/sumcheck/fold_gpu.t.cc +++ b/sxt/proof/sumcheck/fold_gpu.t.cc @@ -27,7 +27,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can fold scalars using the gpu") { diff --git a/sxt/proof/sumcheck/gpu_driver.h b/sxt/proof/sumcheck/gpu_driver.h index 3fc9ca260..1fd638b87 100644 --- a/sxt/proof/sumcheck/gpu_driver.h +++ b/sxt/proof/sumcheck/gpu_driver.h @@ -30,7 +30,7 @@ #include "sxt/proof/sumcheck/driver.h" #include "sxt/proof/sumcheck/sum_gpu.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // gpu_driver //-------------------------------------------------------------------------------------------------- @@ -149,4 +149,4 @@ template class gpu_driver final : public driver { work.mles = std::move(mles_p); } }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/gpu_driver.t.cc b/sxt/proof/sumcheck/gpu_driver.t.cc index fff764959..565ac6fa7 100644 --- a/sxt/proof/sumcheck/gpu_driver.t.cc +++ b/sxt/proof/sumcheck/gpu_driver.t.cc @@ -20,7 +20,7 @@ #include "sxt/proof/sumcheck/driver_test.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; TEST_CASE("we can perform the primitive operations for sumcheck proofs") { gpu_driver drv; diff --git a/sxt/proof/sumcheck/mle_utility.h b/sxt/proof/sumcheck/mle_utility.h index d5b2e3413..1fec67281 100644 --- a/sxt/proof/sumcheck/mle_utility.h +++ b/sxt/proof/sumcheck/mle_utility.h @@ -30,7 +30,7 @@ #include "sxt/base/num/divide_up.h" #include "sxt/memory/management/managed_array.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // copy_partial_mles //-------------------------------------------------------------------------------------------------- @@ -96,4 +96,4 @@ template double get_gpu_memory_fraction(basct::cspan mles auto total_memory = static_cast(basdv::get_total_device_memory()); return static_cast(mles.size() * sizeof(T)) / total_memory; } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/mle_utility.t.cc b/sxt/proof/sumcheck/mle_utility.t.cc index 20e48bc07..2494b2c57 100644 --- a/sxt/proof/sumcheck/mle_utility.t.cc +++ b/sxt/proof/sumcheck/mle_utility.t.cc @@ -27,7 +27,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; using T = s25t::element; diff --git a/sxt/proof/sumcheck/polynomial_mapper.h b/sxt/proof/sumcheck/polynomial_mapper.h index 0b4c9f50d..a44981847 100644 --- a/sxt/proof/sumcheck/polynomial_mapper.h +++ b/sxt/proof/sumcheck/polynomial_mapper.h @@ -20,7 +20,7 @@ #include "sxt/base/macro/cuda_callable.h" #include "sxt/proof/sumcheck/polynomial_utility.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // polynomial_mapper //-------------------------------------------------------------------------------------------------- @@ -48,4 +48,4 @@ template struct polynomial_mapper { unsigned split; unsigned n; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_mapper.t.cc b/sxt/proof/sumcheck/polynomial_mapper.t.cc index 55b44a88f..4f940c53f 100644 --- a/sxt/proof/sumcheck/polynomial_mapper.t.cc +++ b/sxt/proof/sumcheck/polynomial_mapper.t.cc @@ -24,7 +24,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; using T = s25t::element; diff --git a/sxt/proof/sumcheck/polynomial_reducer.h b/sxt/proof/sumcheck/polynomial_reducer.h index 565fb1f49..92fd0110d 100644 --- a/sxt/proof/sumcheck/polynomial_reducer.h +++ b/sxt/proof/sumcheck/polynomial_reducer.h @@ -19,7 +19,7 @@ #include "sxt/base/field/element.h" #include "sxt/base/macro/cuda_callable.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // polynomial_reducer //-------------------------------------------------------------------------------------------------- @@ -32,4 +32,4 @@ template struct polynomial_reducer { } } }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.h b/sxt/proof/sumcheck/polynomial_utility.h index 3ae6aa949..65c21608d 100644 --- a/sxt/proof/sumcheck/polynomial_utility.h +++ b/sxt/proof/sumcheck/polynomial_utility.h @@ -22,7 +22,7 @@ #include "sxt/base/field/element.h" #include "sxt/base/macro/cuda_callable.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // sum_polynomial_01 //-------------------------------------------------------------------------------------------------- @@ -135,4 +135,4 @@ CUDA_CALLABLE void partial_expand_products(basct::span p, const T* mles, unsi mul(p[i + 1u], c_prev, b); } } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.t.cc b/sxt/proof/sumcheck/polynomial_utility.t.cc index 9d429dc8b..821124e47 100644 --- a/sxt/proof/sumcheck/polynomial_utility.t.cc +++ b/sxt/proof/sumcheck/polynomial_utility.t.cc @@ -24,7 +24,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; using T = s25t::element; diff --git a/sxt/proof/sumcheck/proof_computation.h b/sxt/proof/sumcheck/proof_computation.h index 23539d665..b3e1328d3 100644 --- a/sxt/proof/sumcheck/proof_computation.h +++ b/sxt/proof/sumcheck/proof_computation.h @@ -24,7 +24,7 @@ #include "sxt/proof/sumcheck/driver.h" #include "sxt/proof/sumcheck/sumcheck_transcript.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // prove_sum //-------------------------------------------------------------------------------------------------- @@ -67,4 +67,4 @@ xena::future<> prove_sum(basct::span polynomials, basct::span evaluation_p } } } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/proof_computation.t.cc b/sxt/proof/sumcheck/proof_computation.t.cc index 8321698a7..11be3a42d 100644 --- a/sxt/proof/sumcheck/proof_computation.t.cc +++ b/sxt/proof/sumcheck/proof_computation.t.cc @@ -39,7 +39,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; using T = s25t::element; diff --git a/sxt/proof/sumcheck/reduction_gpu.h b/sxt/proof/sumcheck/reduction_gpu.h index 638c17802..ddb3995cf 100644 --- a/sxt/proof/sumcheck/reduction_gpu.h +++ b/sxt/proof/sumcheck/reduction_gpu.h @@ -32,7 +32,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // reduction_kernel //-------------------------------------------------------------------------------------------------- @@ -110,4 +110,4 @@ xena::future<> reduce_sums(basct::span p, basdv::stream& stream, } } } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reduction_gpu.t.cc b/sxt/proof/sumcheck/reduction_gpu.t.cc index d3da21f26..ce4a8e73e 100644 --- a/sxt/proof/sumcheck/reduction_gpu.t.cc +++ b/sxt/proof/sumcheck/reduction_gpu.t.cc @@ -28,7 +28,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can reduce sumcheck polynomials") { diff --git a/sxt/proof/sumcheck/reference_transcript.h b/sxt/proof/sumcheck/reference_transcript.h index cc23cb7f6..265ae1a34 100644 --- a/sxt/proof/sumcheck/reference_transcript.h +++ b/sxt/proof/sumcheck/reference_transcript.h @@ -20,7 +20,7 @@ #include "sxt/proof/transcript/transcript.h" #include "sxt/proof/transcript/transcript_utility.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // reference_transcript //-------------------------------------------------------------------------------------------------- @@ -42,4 +42,4 @@ template class reference_transcript final : public sumcheck_ private: prft::transcript& transcript_; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reference_transcript.t.cc b/sxt/proof/sumcheck/reference_transcript.t.cc index b97eaf6ac..373fea8c5 100644 --- a/sxt/proof/sumcheck/reference_transcript.t.cc +++ b/sxt/proof/sumcheck/reference_transcript.t.cc @@ -23,7 +23,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using sxt::s25t::operator""_s25; TEST_CASE("we provide an implementation of sumcheck transcript") { diff --git a/sxt/proof/sumcheck/sum_gpu.h b/sxt/proof/sumcheck/sum_gpu.h index 14adec7c3..8f8cfb82e 100644 --- a/sxt/proof/sumcheck/sum_gpu.h +++ b/sxt/proof/sumcheck/sum_gpu.h @@ -42,7 +42,7 @@ #include "sxt/proof/sumcheck/polynomial_reducer.h" #include "sxt/proof/sumcheck/reduction_gpu.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // sum_options //-------------------------------------------------------------------------------------------------- @@ -223,4 +223,4 @@ xena::future<> sum_gpu(basct::span p, basct::cspan mles, basdv::stream stream; co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sum_gpu.t.cc b/sxt/proof/sumcheck/sum_gpu.t.cc index 48c022388..5538698ca 100644 --- a/sxt/proof/sumcheck/sum_gpu.t.cc +++ b/sxt/proof/sumcheck/sum_gpu.t.cc @@ -28,7 +28,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can sum MLEs") { diff --git a/sxt/proof/sumcheck/sumcheck_random.cc b/sxt/proof/sumcheck/sumcheck_random.cc index 04db4e7d5..d3f164e3a 100644 --- a/sxt/proof/sumcheck/sumcheck_random.cc +++ b/sxt/proof/sumcheck/sumcheck_random.cc @@ -23,7 +23,7 @@ #include "sxt/scalar25/random/element.h" #include "sxt/scalar25/type/element.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // generate_random_sumcheck_problem //-------------------------------------------------------------------------------------------------- @@ -74,4 +74,4 @@ void generate_random_sumcheck_problem( term = mle_dist(rng_p); } } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sumcheck_random.h b/sxt/proof/sumcheck/sumcheck_random.h index 5c88b944d..f58cd0f8e 100644 --- a/sxt/proof/sumcheck/sumcheck_random.h +++ b/sxt/proof/sumcheck/sumcheck_random.h @@ -28,7 +28,7 @@ namespace sxt::basn { class fast_random_number_generator; } -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // random_sumcheck_descriptor //-------------------------------------------------------------------------------------------------- @@ -54,4 +54,4 @@ void generate_random_sumcheck_problem( std::vector>& product_table, std::vector& product_terms, unsigned& n, basn::fast_random_number_generator& rng, const random_sumcheck_descriptor& descriptor) noexcept; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sumcheck_transcript.h b/sxt/proof/sumcheck/sumcheck_transcript.h index 8447212e0..db871f62a 100644 --- a/sxt/proof/sumcheck/sumcheck_transcript.h +++ b/sxt/proof/sumcheck/sumcheck_transcript.h @@ -19,7 +19,7 @@ #include "sxt/base/container/span.h" #include "sxt/base/field/element.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // sumcheck_transcript //-------------------------------------------------------------------------------------------------- @@ -31,4 +31,4 @@ template class sumcheck_transcript { virtual void round_challenge(T& r, basct::cspan polynomial) noexcept = 0; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.h b/sxt/proof/sumcheck/verification.h index 75ee6b428..127725aaa 100644 --- a/sxt/proof/sumcheck/verification.h +++ b/sxt/proof/sumcheck/verification.h @@ -22,7 +22,7 @@ #include "sxt/proof/sumcheck/polynomial_utility.h" #include "sxt/proof/sumcheck/sumcheck_transcript.h" -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // verify_sumcheck_no_evaluation //-------------------------------------------------------------------------------------------------- @@ -74,4 +74,4 @@ bool verify_sumcheck_no_evaluation(T& expected_sum, basct::span evaluation_po return true; } -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.t.cc b/sxt/proof/sumcheck/verification.t.cc index e7b614f61..3c1bc60ef 100644 --- a/sxt/proof/sumcheck/verification.t.cc +++ b/sxt/proof/sumcheck/verification.t.cc @@ -26,7 +26,7 @@ #include "sxt/scalar25/type/literal.h" using namespace sxt; -using namespace sxt::prfsk2; +using namespace sxt::prfsk; using sxt::s25t::operator""_s25; TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { diff --git a/sxt/proof/sumcheck/workspace.h b/sxt/proof/sumcheck/workspace.h index 8fae4355d..2a46631e1 100644 --- a/sxt/proof/sumcheck/workspace.h +++ b/sxt/proof/sumcheck/workspace.h @@ -16,7 +16,7 @@ */ #pragma once -namespace sxt::prfsk2 { +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // workspace //-------------------------------------------------------------------------------------------------- @@ -24,4 +24,4 @@ class workspace { public: virtual ~workspace() noexcept = default; }; -} // namespace sxt::prfsk2 +} // namespace sxt::prfsk From 084caf6660322804aa4ab8add2abaaa3cabf14c1 Mon Sep 17 00:00:00 2001 From: rnburn Date: Tue, 25 Feb 2025 20:07:08 -0800 Subject: [PATCH 83/83] refactor --- sxt/cbindings/backend/BUILD | 2 -- sxt/cbindings/backend/cpu_backend.cc | 1 - sxt/cbindings/backend/gpu_backend.cc | 1 - sxt/cbindings/base/BUILD | 2 +- sxt/cbindings/base/field_id_utility.h | 2 +- 5 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 15b317d0a..a4a2e1b42 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -99,7 +99,6 @@ sxt_cc_component( "//sxt/proof/inner_product:gpu_driver", "//sxt/proof/sumcheck:chunked_gpu_driver", "//sxt/proof/sumcheck:proof_computation", - "//sxt/scalar25/realization:field", ], with_test = False, deps = [ @@ -158,7 +157,6 @@ sxt_cc_component( "//sxt/proof/inner_product:cpu_driver", "//sxt/proof/sumcheck:cpu_driver", "//sxt/proof/sumcheck:proof_computation", - "//sxt/scalar25/realization:field", ], with_test = False, deps = [ diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index 28296e537..676057963 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -64,7 +64,6 @@ #include "sxt/proof/transcript/transcript.h" #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" -#include "sxt/scalar25/realization/field.h" #include "sxt/seqcommit/generator/precomputed_generators.h" namespace sxt::cbnbck { diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index b07fa69ad..3f1c95bad 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -69,7 +69,6 @@ #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" #include "sxt/ristretto/type/literal.h" -#include "sxt/scalar25/realization/field.h" #include "sxt/seqcommit/generator/precomputed_generators.h" using sxt::rstt::operator""_rs; diff --git a/sxt/cbindings/base/BUILD b/sxt/cbindings/base/BUILD index 9d4137b40..b32dd7e8a 100644 --- a/sxt/cbindings/base/BUILD +++ b/sxt/cbindings/base/BUILD @@ -57,7 +57,7 @@ sxt_cc_component( deps = [ ":field_id", "//sxt/base/error:panic", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", ], ) diff --git a/sxt/cbindings/base/field_id_utility.h b/sxt/cbindings/base/field_id_utility.h index 497ff4964..6c305ba2e 100644 --- a/sxt/cbindings/base/field_id_utility.h +++ b/sxt/cbindings/base/field_id_utility.h @@ -20,7 +20,7 @@ #include "sxt/base/error/panic.h" #include "sxt/cbindings/base/field_id.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" namespace sxt::cbnb { //--------------------------------------------------------------------------------------------------