diff --git a/benchmark/sumcheck/BUILD b/benchmark/sumcheck/BUILD index 14d521e5f..22a847b72 100644 --- a/benchmark/sumcheck/BUILD +++ b/benchmark/sumcheck/BUILD @@ -19,6 +19,6 @@ sxt_cc_benchmark( "//sxt/proof/sumcheck:reference_transcript", "//sxt/proof/transcript", "//sxt/scalar25/random:element", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", ], ) diff --git a/benchmark/sumcheck/benchmark.m.cc b/benchmark/sumcheck/benchmark.m.cc index 348e43e4f..6b0e64fa7 100644 --- a/benchmark/sumcheck/benchmark.m.cc +++ b/benchmark/sumcheck/benchmark.m.cc @@ -32,7 +32,7 @@ #include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/random/element.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" using namespace sxt; @@ -111,13 +111,13 @@ int main(int argc, char* argv[]) { memmg::managed_array polynomials((p.degree + 1u) * num_rounds); memmg::managed_array evaluation_point(num_rounds); prft::transcript base_transcript{"abc123"}; - prfsk::reference_transcript transcript{base_transcript}; - prfsk::gpu_driver drv; + prfsk::reference_transcript transcript{base_transcript}; + prfsk::gpu_driver drv; // initial run { - auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, p.n); + auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, + product_table, product_terms, p.n); xens::get_scheduler().run(); } @@ -125,8 +125,8 @@ int main(int argc, char* argv[]) { double elapse = 0; for (unsigned i = 0; i < (p.num_samples + 1u); ++i) { auto t1 = std::chrono::steady_clock::now(); - auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, p.n); + auto fut = prfsk::prove_sum(polynomials, evaluation_point, transcript, drv, mles, + product_table, product_terms, p.n); xens::get_scheduler().run(); auto t2 = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast(t2 - t1); diff --git a/cbindings/BUILD b/cbindings/BUILD index b975cfa9d..0c49e4024 100644 --- a/cbindings/BUILD +++ b/cbindings/BUILD @@ -226,7 +226,7 @@ sxt_cc_component( "//sxt/base/test:unit_test", "//sxt/proof/sumcheck:reference_transcript", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ diff --git a/cbindings/sumcheck.t.cc b/cbindings/sumcheck.t.cc index 282204848..2b0c739af 100644 --- a/cbindings/sumcheck.t.cc +++ b/cbindings/sumcheck.t.cc @@ -22,7 +22,7 @@ #include "sxt/base/test/unit_test.h" #include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -30,7 +30,7 @@ using s25t::operator""_s25; TEST_CASE("we can create sumcheck proofs") { prft::transcript base_transcript{"abc"}; - prfsk::reference_transcript transcript{base_transcript}; + prfsk::reference_transcript transcript{base_transcript}; std::vector polynomials(2); std::vector evaluation_point(1); @@ -55,7 +55,7 @@ TEST_CASE("we can create sumcheck proofs") { auto f = [](s25t::element* r, void* context, const s25t::element* polynomial, unsigned polynomial_len) noexcept { - static_cast(context)->round_challenge( + static_cast*>(context)->round_challenge( *r, {polynomial, polynomial_len}); }; @@ -70,7 +70,7 @@ TEST_CASE("we can create sumcheck proofs") { REQUIRE(polynomials[1] == mles[1] - mles[0]); { prft::transcript base_transcript_p{"abc"}; - prfsk::reference_transcript transcript_p{base_transcript_p}; + prfsk::reference_transcript transcript_p{base_transcript_p}; s25t::element r; transcript_p.round_challenge(r, polynomials); REQUIRE(evaluation_point[0] == r); @@ -88,7 +88,7 @@ TEST_CASE("we can create sumcheck proofs") { REQUIRE(polynomials[1] == mles[1] - mles[0]); { prft::transcript base_transcript_p{"abc"}; - prfsk::reference_transcript transcript_p{base_transcript_p}; + prfsk::reference_transcript transcript_p{base_transcript_p}; s25t::element r; transcript_p.round_challenge(r, polynomials); REQUIRE(evaluation_point[0] == r); diff --git a/sxt/base/concept/BUILD b/sxt/base/concept/BUILD index 1e3e001ce..a6c9e77bf 100644 --- a/sxt/base/concept/BUILD +++ b/sxt/base/concept/BUILD @@ -9,3 +9,8 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], ) + +sxt_cc_component( + name = "field", + with_test = False, +) diff --git a/sxt/base/concept/field.cc b/sxt/base/concept/field.cc new file mode 100644 index 000000000..16a6731c6 --- /dev/null +++ b/sxt/base/concept/field.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/base/concept/field.h" diff --git a/sxt/base/concept/field.h b/sxt/base/concept/field.h new file mode 100644 index 000000000..aee6c1224 --- /dev/null +++ b/sxt/base/concept/field.h @@ -0,0 +1,31 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace sxt::bascpt { +//-------------------------------------------------------------------------------------------------- +// field +//-------------------------------------------------------------------------------------------------- +template +concept field = requires(T& res, const T& e) { + neg(res, e); + add(res, e, e); + sub(res, e, e); + mul(res, e, e); + muladd(res, e, e, e); +}; +} // namespace sxt::bascpt diff --git a/sxt/base/field/BUILD b/sxt/base/field/BUILD index 68838f458..7a2a89d23 100644 --- a/sxt/base/field/BUILD +++ b/sxt/base/field/BUILD @@ -14,3 +14,17 @@ sxt_cc_component( "//sxt/base/type:narrow_cast", ], ) + +sxt_cc_component( + name = "element", + with_test = False, +) + +sxt_cc_component( + name = "accumulator", + with_test = False, + deps = [ + ":element", + "//sxt/base/macro:cuda_callable", + ], +) diff --git a/sxt/base/field/accumulator.cc b/sxt/base/field/accumulator.cc new file mode 100644 index 000000000..fe0f80fa5 --- /dev/null +++ b/sxt/base/field/accumulator.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/base/field/accumulator.h" diff --git a/sxt/base/field/accumulator.h b/sxt/base/field/accumulator.h new file mode 100644 index 000000000..a693cd08a --- /dev/null +++ b/sxt/base/field/accumulator.h @@ -0,0 +1,31 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "sxt/base/field/element.h" +#include "sxt/base/macro/cuda_callable.h" + +namespace sxt::basfld { +//-------------------------------------------------------------------------------------------------- +// accumulator +//-------------------------------------------------------------------------------------------------- +template struct accumulator { + using value_type = T; + + CUDA_CALLABLE static void accumulate_inplace(T& res, T& e) noexcept { add(res, res, e); } +}; +} // namespace sxt::basfld diff --git a/sxt/base/field/element.cc b/sxt/base/field/element.cc new file mode 100644 index 000000000..0c91592ea --- /dev/null +++ b/sxt/base/field/element.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/base/field/element.h" diff --git a/sxt/base/field/element.h b/sxt/base/field/element.h new file mode 100644 index 000000000..f6feb08d3 --- /dev/null +++ b/sxt/base/field/element.h @@ -0,0 +1,35 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace sxt::basfld { +//-------------------------------------------------------------------------------------------------- +// element +//-------------------------------------------------------------------------------------------------- +template +concept element = requires(T& res, const T& e) { + neg(res, e); + add(res, e, e); + sub(res, e, e); + mul(res, e, e); + muladd(res, e, e, e); + { T::identity() } noexcept -> std::same_as; + { T::one() } noexcept -> std::same_as; +}; +} // namespace sxt::basfld diff --git a/sxt/cbindings/backend/callback_sumcheck_transcript.h b/sxt/cbindings/backend/callback_sumcheck_transcript.h index 8f4c9c908..b2270cbef 100644 --- a/sxt/cbindings/backend/callback_sumcheck_transcript.h +++ b/sxt/cbindings/backend/callback_sumcheck_transcript.h @@ -22,16 +22,16 @@ namespace sxt::cbnbck { //-------------------------------------------------------------------------------------------------- // callback_sumcheck_transcript //-------------------------------------------------------------------------------------------------- -class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript { +template +class callback_sumcheck_transcript final : public prfsk::sumcheck_transcript { public: - using callback_t = void (*)(s25t::element* r, void* context, const s25t::element* polynomial, - unsigned polynomial_len); + using callback_t = void (*)(T* r, void* context, const T* polynomial, unsigned polynomial_len); callback_sumcheck_transcript(callback_t f, void* context) noexcept : f_{f}, context_{context} {} void init(size_t /*num_variables*/, size_t /*round_degree*/) noexcept override {} - void round_challenge(s25t::element& r, basct::cspan polynomial) noexcept override { + void round_challenge(T& r, basct::cspan polynomial) noexcept override { f_(&r, context_, polynomial.data(), static_cast(polynomial.size())); } diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index 29d33d48e..676057963 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -64,17 +64,12 @@ #include "sxt/proof/transcript/transcript.h" #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" -#include "sxt/scalar25/type/element.h" #include "sxt/seqcommit/generator/precomputed_generators.h" namespace sxt::cbnbck { //-------------------------------------------------------------------------------------------------- // prove_sumcheck //-------------------------------------------------------------------------------------------------- -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-parameter" void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsigned field_id, const cbnb::sumcheck_descriptor& descriptor, void* transcript_callback, void* transcript_context) noexcept { @@ -83,8 +78,8 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi static_cast(field_id), [&](std::type_identity) noexcept { static_assert(std::same_as, "only support curve-255 right now"); // transcript - callback_sumcheck_transcript transcript{ - reinterpret_cast( + callback_sumcheck_transcript transcript{ + reinterpret_cast::callback_t>( const_cast(transcript_callback)), transcript_context}; @@ -109,14 +104,13 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi descriptor.product_terms, descriptor.num_product_terms, }; - prfsk::cpu_driver drv; + prfsk::cpu_driver drv; auto fut = - prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, - product_table_span, product_terms_span, descriptor.n); + prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, + product_table_span, product_terms_span, descriptor.n); SXT_RELEASE_ASSERT(fut.ready()); }); } -#pragma clang diagnostic pop //-------------------------------------------------------------------------------------------------- // compute_commitments diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index 6dc4208c4..3f1c95bad 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -69,7 +69,6 @@ #include "sxt/ristretto/operation/compression.h" #include "sxt/ristretto/type/compressed_element.h" #include "sxt/ristretto/type/literal.h" -#include "sxt/scalar25/type/element.h" #include "sxt/seqcommit/generator/precomputed_generators.h" using sxt::rstt::operator""_rs; @@ -112,8 +111,8 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi static_cast(field_id), [&](std::type_identity) noexcept { static_assert(std::same_as, "only support curve-255 right now"); // transcript - callback_sumcheck_transcript transcript{ - reinterpret_cast( + callback_sumcheck_transcript transcript{ + reinterpret_cast::callback_t>( const_cast(transcript_callback)), transcript_context}; @@ -138,10 +137,10 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi descriptor.product_terms, descriptor.num_product_terms, }; - prfsk::chunked_gpu_driver drv; + prfsk::chunked_gpu_driver drv; auto fut = - prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, - product_table_span, product_terms_span, descriptor.n); + prfsk::prove_sum(polynomials_span, evaluation_point_span, transcript, drv, mles_span, + product_table_span, product_terms_span, descriptor.n); xens::get_scheduler().run(); }); } diff --git a/sxt/cbindings/base/BUILD b/sxt/cbindings/base/BUILD index 9d4137b40..b32dd7e8a 100644 --- a/sxt/cbindings/base/BUILD +++ b/sxt/cbindings/base/BUILD @@ -57,7 +57,7 @@ sxt_cc_component( deps = [ ":field_id", "//sxt/base/error:panic", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", ], ) diff --git a/sxt/cbindings/base/field_id_utility.h b/sxt/cbindings/base/field_id_utility.h index 497ff4964..6c305ba2e 100644 --- a/sxt/cbindings/base/field_id_utility.h +++ b/sxt/cbindings/base/field_id_utility.h @@ -20,7 +20,7 @@ #include "sxt/base/error/panic.h" #include "sxt/cbindings/base/field_id.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" namespace sxt::cbnb { //-------------------------------------------------------------------------------------------------- diff --git a/sxt/proof/sumcheck/BUILD b/sxt/proof/sumcheck/BUILD index 6fdc8dbd3..47dfd00fe 100644 --- a/sxt/proof/sumcheck/BUILD +++ b/sxt/proof/sumcheck/BUILD @@ -6,125 +6,54 @@ load( sxt_cc_component( name = "constant", with_test = False, + deps = [ + ], ) sxt_cc_component( - name = "device_cache", - impl_deps = [ - "//sxt/base/device:memory_utility", - "//sxt/base/device:state", - "//sxt/base/device:stream", - "//sxt/memory/resource:device_resource", + name = "workspace", + with_test = False, + deps = [ ], +) + +sxt_cc_component( + name = "device_cache", test_deps = [ "//sxt/base/device:memory_utility", "//sxt/base/device:stream", "//sxt/base/device:synchronization", "//sxt/base/test:unit_test", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ "//sxt/base/container:span", "//sxt/base/device:device_map", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/type:element", - ], -) - -sxt_cc_component( - name = "fold_gpu", - impl_deps = [ - ":mle_utility", - "//sxt/algorithm/iteration:kernel_fit", - "//sxt/base/error:assert", - "//sxt/base/device:property", "//sxt/base/device:memory_utility", + "//sxt/base/device:state", "//sxt/base/device:stream", - "//sxt/base/iterator:split", - "//sxt/base/num:ceil_log2", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/operation:sub", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/execution/device:for_each", - "//sxt/execution/device:synchronization", - "//sxt/execution/kernel:kernel_dims", + "//sxt/base/field:element", "//sxt/memory/management:managed_array", - "//sxt/memory/resource:async_device_resource", "//sxt/memory/resource:device_resource", ], - test_deps = [ - "//sxt/base/iterator:split", - "//sxt/base/test:unit_test", - "//sxt/execution/async:future", - "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], - deps = [ - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], ) sxt_cc_component( - name = "mle_utility", - impl_deps = [ - "//sxt/base/container:span_utility", - "//sxt/base/device:memory_utility", - "//sxt/base/device:property", - "//sxt/base/num:divide_up", - "//sxt/base/num:ceil_log2", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/type:element", - ], - test_deps = [ - "//sxt/base/device:stream", - "//sxt/base/device:synchronization", - "//sxt/base/test:unit_test", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:managed_device_resource", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", - ], + name = "driver", + with_test = False, deps = [ + ":workspace", "//sxt/base/container:span", - "//sxt/memory/management:managed_array_fwd", + "//sxt/base/field:element", + "//sxt/execution/async:future_fwd", ], ) sxt_cc_component( - name = "sum_gpu", + name = "driver_test", impl_deps = [ - ":constant", - ":device_cache", - ":mle_utility", - ":polynomial_mapper", - ":reduction_gpu", - "//sxt/algorithm/reduction:kernel_fit", - "//sxt/algorithm/reduction:thread_reduction", - "//sxt/base/device:memory_utility", - "//sxt/base/device:stream", - "//sxt/base/device:state", - "//sxt/base/iterator:split", - "//sxt/base/num:ceil_log2", - "//sxt/base/num:constexpr_switch", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/execution/device:for_each", - "//sxt/execution/kernel:kernel_dims", - "//sxt/memory/resource:device_resource", - "//sxt/memory/resource:async_device_resource", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/type:element", - ], - test_deps = [ - ":device_cache", - "//sxt/base/iterator:split", + ":workspace", "//sxt/base/test:unit_test", "//sxt/execution/async:future", "//sxt/execution/schedule:scheduler", @@ -132,199 +61,174 @@ sxt_cc_component( "//sxt/scalar25/type:element", "//sxt/scalar25/type:literal", ], - deps = [ - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", - ], -) - -sxt_cc_component( - name = "sumcheck_transcript", with_test = False, deps = [ - "//sxt/base/container:span", + ":driver", + "//sxt/scalar25/realization:field", ], ) sxt_cc_component( - name = "reference_transcript", - impl_deps = [ - "//sxt/scalar25/type:element", - "//sxt/proof/transcript:transcript_utility", - ], + name = "cpu_driver", test_deps = [ + ":driver_test", "//sxt/base/test:unit_test", - "//sxt/scalar25/type:literal", ], deps = [ - ":sumcheck_transcript", - "//sxt/proof/transcript", - ], -) - -sxt_cc_component( - name = "sumcheck_random", - impl_deps = [ + ":driver", + ":polynomial_utility", + "//sxt/base/container:stack_array", "//sxt/base/error:assert", - "//sxt/base/num:fast_random_number_generator", - "//sxt/scalar25/random:element", - "//sxt/scalar25/type:element", - ], - with_test = False, - deps = [ - ":constant", - ], -) - -sxt_cc_component( - name = "workspace", - with_test = False, -) - -sxt_cc_component( - name = "driver", - with_test = False, - deps = [ - ":workspace", - "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/memory/management:managed_array", ], ) sxt_cc_component( - name = "driver_test", - impl_deps = [ - ":driver", - ":workspace", + name = "fold_gpu", + test_deps = [ + "//sxt/base/iterator:split", "//sxt/base/test:unit_test", "//sxt/execution/async:future", - "//sxt/execution/schedule:scheduler", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], - with_test = False, -) - -sxt_cc_component( - name = "cpu_driver", - impl_deps = [ - ":polynomial_utility", - "//sxt/base/container:stack_array", - "//sxt/base/error:panic", + deps = [ + ":mle_utility", + "//sxt/algorithm/iteration:kernel_fit", + "//sxt/base/container:span", + "//sxt/base/device:memory_utility", + "//sxt/base/device:property", + "//sxt/base/device:stream", + "//sxt/base/error:assert", + "//sxt/base/field:element", + "//sxt/base/iterator:split", "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", "//sxt/execution/async:future", + "//sxt/execution/device:for_each", + "//sxt/execution/device:synchronization", + "//sxt/execution/kernel:kernel_dims", "//sxt/memory/management:managed_array", + "//sxt/memory/resource:async_device_resource", + "//sxt/memory/resource:device_resource", "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:sub", "//sxt/scalar25/operation:muladd", + "//sxt/scalar25/operation:sub", "//sxt/scalar25/type:element", "//sxt/scalar25/type:literal", ], +) + +sxt_cc_component( + name = "gpu_driver", test_deps = [ ":driver_test", "//sxt/base/test:unit_test", ], deps = [ ":driver", - ":workspace", - ], -) - -sxt_cc_component( - name = "gpu_driver", - impl_deps = [ - ":constant", - ":polynomial_mapper", ":sum_gpu", "//sxt/algorithm/iteration:for_each", - "//sxt/algorithm/reduction:reduction", - "//sxt/base/container:stack_array", - "//sxt/base/device:stream", "//sxt/base/device:memory_utility", - "//sxt/base/error:panic", + "//sxt/base/device:stream", + "//sxt/base/error:assert", "//sxt/base/num:ceil_log2", - "//sxt/base/num:constexpr_switch", "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", "//sxt/execution/device:synchronization", "//sxt/memory/management:managed_array", - "//sxt/memory/resource:async_device_resource", "//sxt/memory/resource:device_resource", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", ], +) + +sxt_cc_component( + name = "chunked_gpu_driver", test_deps = [ ":driver_test", "//sxt/base/test:unit_test", ], deps = [ + ":device_cache", ":driver", + ":fold_gpu", + ":gpu_driver", + ":mle_utility", + ":sum_gpu", + "//sxt/algorithm/iteration:transform", + "//sxt/base/error:assert", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + "//sxt/memory/management:managed_array", ], ) sxt_cc_component( - name = "polynomial_utility", + name = "mle_utility", impl_deps = [ - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/operation:neg", - "//sxt/scalar25/type:element", + "//sxt/base/container:span_utility", + "//sxt/base/device:memory_utility", + "//sxt/base/device:property", + "//sxt/base/container:span", + "//sxt/base/field:element", + "//sxt/base/num:divide_up", + "//sxt/base/num:ceil_log2", + "//sxt/memory/management:managed_array", ], + test_deps = [ + "//sxt/base/device:stream", + "//sxt/base/device:synchronization", + "//sxt/base/test:unit_test", + "//sxt/memory/management:managed_array", + "//sxt/memory/resource:managed_device_resource", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", + ], +) + +sxt_cc_component( + name = "polynomial_mapper", test_deps = [ "//sxt/base/test:unit_test", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ - "//sxt/base/container:span", + ":polynomial_utility", + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", ], ) sxt_cc_component( - name = "chunked_gpu_driver", - impl_deps = [ - ":device_cache", - ":fold_gpu", - ":gpu_driver", - ":mle_utility", - ":sum_gpu", - "//sxt/algorithm/iteration:transform", - "//sxt/base/error:assert", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:coroutine", - "//sxt/execution/async:future", - "//sxt/memory/management:managed_array", - "//sxt/scalar25/operation:sub", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", + name = "polynomial_reducer", + with_test = False, + deps = [ + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", ], +) + +sxt_cc_component( + name = "polynomial_utility", test_deps = [ - ":driver_test", "//sxt/base/test:unit_test", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", ], deps = [ - ":driver", + "//sxt/base/container:span", + "//sxt/base/field:element", + "//sxt/base/macro:cuda_callable", ], ) sxt_cc_component( name = "proof_computation", - impl_deps = [ - ":driver", - ":sumcheck_transcript", - "//sxt/execution/async:future", - "//sxt/base/error:assert", - "//sxt/base/num:ceil_log2", - "//sxt/execution/async:coroutine", - "//sxt/scalar25/type:element", - ], test_deps = [ ":chunked_gpu_driver", ":cpu_driver", @@ -338,64 +242,126 @@ sxt_cc_component( "//sxt/execution/schedule:scheduler", "//sxt/proof/transcript", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", ], deps = [ + ":driver", + ":sumcheck_transcript", "//sxt/base/container:span", - "//sxt/execution/async:future_fwd", + "//sxt/base/error:assert", + "//sxt/base/field:element", + "//sxt/base/num:ceil_log2", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", ], ) sxt_cc_component( - name = "verification", + name = "sumcheck_transcript", + with_test = False, + deps = [ + "//sxt/base/container:span", + "//sxt/base/field:element", + ], +) + +sxt_cc_component( + name = "sumcheck_random", impl_deps = [ - ":polynomial_utility", - ":sumcheck_transcript", "//sxt/base/error:assert", - "//sxt/base/log:log", - "//sxt/scalar25/operation:overload", + "//sxt/base/num:fast_random_number_generator", + "//sxt/scalar25/random:element", "//sxt/scalar25/type:element", ], + with_test = False, + deps = [ + ":constant", + ], +) + +sxt_cc_component( + name = "reference_transcript", test_deps = [ - ":reference_transcript", "//sxt/base/test:unit_test", - "//sxt/scalar25/type:element", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ + ":sumcheck_transcript", "//sxt/base/container:span", + "//sxt/base/field:element", + "//sxt/proof/transcript:transcript_utility", ], ) sxt_cc_component( name = "reduction_gpu", - impl_deps = [ + test_deps = [ + "//sxt/base/device:property", + "//sxt/base/device:stream", + "//sxt/base/test:unit_test", + "//sxt/execution/async:future", + "//sxt/execution/schedule:scheduler", + "//sxt/memory/resource:managed_device_resource", + "//sxt/scalar25/operation:overload", + "//sxt/scalar25/realization:field", + "//sxt/scalar25/type:literal", + ], + deps = [ "//sxt/algorithm/base:identity_mapper", "//sxt/algorithm/reduction:kernel_fit", "//sxt/algorithm/reduction:thread_reduction", + "//sxt/base/container:span", "//sxt/base/device:memory_utility", "//sxt/base/device:stream", "//sxt/base/error:assert", + "//sxt/base/field:accumulator", + "//sxt/base/field:element", "//sxt/execution/async:coroutine", "//sxt/execution/async:future", + "//sxt/execution/async:future_fwd", "//sxt/execution/device:synchronization", "//sxt/execution/kernel:kernel_dims", "//sxt/execution/kernel:launch", "//sxt/memory/management:managed_array", "//sxt/memory/resource:async_device_resource", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:accumulator", - "//sxt/scalar25/type:element", ], - test_deps = [ - "//sxt/base/device:property", +) + +sxt_cc_component( + name = "sum_gpu", + impl_deps = [ + ":constant", + ":device_cache", + ":mle_utility", + ":polynomial_mapper", + ":polynomial_reducer", + ":reduction_gpu", + "//sxt/algorithm/reduction:kernel_fit", + "//sxt/algorithm/reduction:thread_reduction", + "//sxt/base/device:memory_utility", "//sxt/base/device:stream", + "//sxt/base/device:state", + "//sxt/base/field:element", + "//sxt/base/iterator:split", + "//sxt/base/num:ceil_log2", + "//sxt/base/num:constexpr_switch", + "//sxt/execution/async:coroutine", + "//sxt/execution/async:future", + "//sxt/execution/device:for_each", + "//sxt/execution/kernel:kernel_dims", + "//sxt/memory/resource:device_resource", + "//sxt/memory/resource:async_device_resource", + ], + test_deps = [ + # ":device_cache", + "//sxt/base/iterator:split", "//sxt/base/test:unit_test", "//sxt/execution/async:future", "//sxt/execution/schedule:scheduler", - "//sxt/memory/resource:managed_device_resource", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ @@ -405,21 +371,24 @@ sxt_cc_component( ) sxt_cc_component( - name = "polynomial_mapper", + name = "verification", + # impl_deps = [ + # ":sumcheck_transcript", + # "//sxt/base/error:assert", + # "//sxt/base/log:log", + # "//sxt/scalar25/type:element", + # ], test_deps = [ - "//sxt/algorithm/base:mapper", + ":reference_transcript", "//sxt/base/test:unit_test", "//sxt/scalar25/operation:overload", - "//sxt/scalar25/type:element", + "//sxt/scalar25/realization:field", "//sxt/scalar25/type:literal", ], deps = [ ":polynomial_utility", - "//sxt/base/macro:cuda_callable", - "//sxt/scalar25/operation:add", - "//sxt/scalar25/operation:mul", - "//sxt/scalar25/operation:muladd", - "//sxt/scalar25/type:element", - "//sxt/scalar25/type:literal", + ":sumcheck_transcript", + "//sxt/base/container:span", + "//sxt/base/log", ], ) diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.cc b/sxt/proof/sumcheck/chunked_gpu_driver.cc index 1358c4cc7..4a652d2fc 100644 --- a/sxt/proof/sumcheck/chunked_gpu_driver.cc +++ b/sxt/proof/sumcheck/chunked_gpu_driver.cc @@ -15,140 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/chunked_gpu_driver.h" - -#include - -#include "sxt/algorithm/iteration/transform.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck/device_cache.h" -#include "sxt/proof/sumcheck/fold_gpu.h" -#include "sxt/proof/sumcheck/gpu_driver.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/proof/sumcheck/sum_gpu.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// chunked_gpu_workspace -//-------------------------------------------------------------------------------------------------- -namespace { -struct chunked_gpu_workspace final : public workspace { - std::unique_ptr single_gpu_workspace; - - device_cache cache; - memmg::managed_array mles_data; - basct::cspan mles; - unsigned n; - unsigned num_variables; - - chunked_gpu_workspace(basct::cspan> product_table, - basct::cspan product_terms) noexcept - : cache{product_table, product_terms} {} -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// try_make_single_gpu_workspace -//-------------------------------------------------------------------------------------------------- -static xena::future<> try_make_single_gpu_workspace(chunked_gpu_workspace& work, - double no_chunk_cutoff) noexcept { - auto gpu_memory_fraction = get_gpu_memory_fraction(work.mles); - if (gpu_memory_fraction > no_chunk_cutoff) { - co_return; - } - - // construct single gpu workspace - auto cache_data = work.cache.clear(); - gpu_driver drv; - work.single_gpu_workspace = - co_await drv.make_workspace(work.mles, std::move(cache_data->product_table), - std::move(cache_data->product_terms), work.n); - - // free data we no longer need - work.mles_data.reset(); - work.mles = {}; -} - -//-------------------------------------------------------------------------------------------------- -// constructor -//-------------------------------------------------------------------------------------------------- -chunked_gpu_driver::chunked_gpu_driver(double no_chunk_cutoff) noexcept - : no_chunk_cutoff_{no_chunk_cutoff} { - SXT_RELEASE_ASSERT(0 <= no_chunk_cutoff_ && no_chunk_cutoff_ <= 1.0); -} - -//-------------------------------------------------------------------------------------------------- -// make_workspace -//-------------------------------------------------------------------------------------------------- -xena::future> -chunked_gpu_driver::make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, - unsigned n) const noexcept { - auto res = std::make_unique(product_table, product_terms); - res->mles = mles; - res->n = n; - res->num_variables = std::max(basn::ceil_log2(n), 1); - auto gpu_memory_fraction = get_gpu_memory_fraction(mles); - if (gpu_memory_fraction <= no_chunk_cutoff_) { - gpu_driver drv; - res->single_gpu_workspace = co_await drv.make_workspace(mles, product_table, product_terms, n); - } - co_return std::unique_ptr(std::move(res)); -} - -//-------------------------------------------------------------------------------------------------- -// sum -//-------------------------------------------------------------------------------------------------- -xena::future<> chunked_gpu_driver::sum(basct::span polynomial, - workspace& ws) const noexcept { - auto& work = static_cast(ws); - if (work.single_gpu_workspace) { - gpu_driver drv; - co_return co_await drv.sum(polynomial, *work.single_gpu_workspace); - } - co_await sum_gpu(polynomial, work.cache, work.mles, work.n); -} - -//-------------------------------------------------------------------------------------------------- -// fold -//-------------------------------------------------------------------------------------------------- -xena::future<> chunked_gpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { - using s25t::operator""_s25; - auto& work = static_cast(ws); - if (work.single_gpu_workspace) { - gpu_driver drv; - co_return co_await drv.fold(*work.single_gpu_workspace, r); - } - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); - - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - // fold - memmg::managed_array mles_p(num_mles * mid); - co_await fold_gpu(mles_p, work.mles, n, r); - - // update - work.n = mid; - --work.num_variables; - work.mles_data = std::move(mles_p); - work.mles = work.mles_data; - - // check if we should fall back to single gpu workspace - co_await try_make_single_gpu_workspace(work, no_chunk_cutoff_); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.h b/sxt/proof/sumcheck/chunked_gpu_driver.h index ed93decbe..fbe687a13 100644 --- a/sxt/proof/sumcheck/chunked_gpu_driver.h +++ b/sxt/proof/sumcheck/chunked_gpu_driver.h @@ -16,25 +16,120 @@ */ #pragma once +#include + +#include "sxt/algorithm/iteration/transform.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/proof/sumcheck/device_cache.h" #include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/fold_gpu.h" +#include "sxt/proof/sumcheck/gpu_driver.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/proof/sumcheck/sum_gpu.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // chunked_gpu_driver //-------------------------------------------------------------------------------------------------- -class chunked_gpu_driver final : public driver { +template class chunked_gpu_driver final : public driver { + struct chunked_gpu_workspace final : public workspace { + std::unique_ptr single_gpu_workspace; + + device_cache cache; + memmg::managed_array mles_data; + basct::cspan mles; + unsigned n; + unsigned num_variables; + + chunked_gpu_workspace(basct::cspan> product_table, + basct::cspan product_terms) noexcept + : cache{product_table, product_terms} {} + }; + + static xena::future<> try_make_single_gpu_workspace(chunked_gpu_workspace& work, + double no_chunk_cutoff) noexcept { + auto gpu_memory_fraction = get_gpu_memory_fraction(work.mles); + if (gpu_memory_fraction > no_chunk_cutoff) { + co_return; + } + + // construct single gpu workspace + auto cache_data = work.cache.clear(); + gpu_driver drv; + work.single_gpu_workspace = + co_await drv.make_workspace(work.mles, std::move(cache_data->product_table), + std::move(cache_data->product_terms), work.n); + + // free data we no longer need + work.mles_data.reset(); + work.mles = {}; + } + public: - explicit chunked_gpu_driver(double no_chunk_cutoff = 0.5) noexcept; + explicit chunked_gpu_driver(double no_chunk_cutoff = 0.5) noexcept + : no_chunk_cutoff_{no_chunk_cutoff} {} // driver xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept override; + make_workspace(basct::cspan mles, basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override { + auto res = std::make_unique(product_table, product_terms); + res->mles = mles; + res->n = n; + res->num_variables = std::max(basn::ceil_log2(n), 1); + auto gpu_memory_fraction = get_gpu_memory_fraction(mles); + if (gpu_memory_fraction <= no_chunk_cutoff_) { + gpu_driver drv; + res->single_gpu_workspace = + co_await drv.make_workspace(mles, product_table, product_terms, n); + } + co_return std::unique_ptr(std::move(res)); + } + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { + auto& work = static_cast(ws); + if (work.single_gpu_workspace) { + gpu_driver drv; + co_return co_await drv.sum(polynomial, *work.single_gpu_workspace); + } + co_await sum_gpu(polynomial, work.cache, work.mles, work.n); + } + + xena::future<> fold(workspace& ws, const T& r) const noexcept override { + auto& work = static_cast(ws); + if (work.single_gpu_workspace) { + gpu_driver drv; + co_return co_await drv.fold(*work.single_gpu_workspace, r); + } + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + auto one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + + // fold + memmg::managed_array mles_p(num_mles * mid); + co_await fold_gpu(mles_p, work.mles, n, r); - xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; + // update + work.n = mid; + --work.num_variables; + work.mles_data = std::move(mles_p); + work.mles = work.mles_data; - xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; + // check if we should fall back to single gpu workspace + co_await try_make_single_gpu_workspace(work, no_chunk_cutoff_); + } private: double no_chunk_cutoff_; diff --git a/sxt/proof/sumcheck/chunked_gpu_driver.t.cc b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc index 8fb4a3aac..5642a3a8c 100644 --- a/sxt/proof/sumcheck/chunked_gpu_driver.t.cc +++ b/sxt/proof/sumcheck/chunked_gpu_driver.t.cc @@ -24,12 +24,12 @@ using namespace sxt::prfsk; TEST_CASE("we can perform the primitive operations for sumcheck proofs") { SECTION("we handle the case when only chunking is used") { - chunked_gpu_driver drv{0.0}; + chunked_gpu_driver drv{0.0}; exercise_driver(drv); } SECTION("we handle the case when the chunked driver falls back to the single gpu driver") { - chunked_gpu_driver drv{1.0}; + chunked_gpu_driver drv{1.0}; exercise_driver(drv); } } diff --git a/sxt/proof/sumcheck/constant.cc b/sxt/proof/sumcheck/constant.cc index 396197800..12625ca3b 100644 --- a/sxt/proof/sumcheck/constant.cc +++ b/sxt/proof/sumcheck/constant.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/sxt/proof/sumcheck/constant.h b/sxt/proof/sumcheck/constant.h index aa70ad9a0..dec9e8a33 100644 --- a/sxt/proof/sumcheck/constant.h +++ b/sxt/proof/sumcheck/constant.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/sxt/proof/sumcheck/cpu_driver.cc b/sxt/proof/sumcheck/cpu_driver.cc index cc8d69192..7fa91e0da 100644 --- a/sxt/proof/sumcheck/cpu_driver.cc +++ b/sxt/proof/sumcheck/cpu_driver.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,147 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/cpu_driver.h" - -#include - -#include "sxt/base/container/stack_array.h" -#include "sxt/base/error/panic.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/future.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// cpu_workspace -//-------------------------------------------------------------------------------------------------- -namespace { -struct cpu_workspace final : public workspace { - memmg::managed_array mles; - basct::cspan> product_table; - basct::cspan product_terms; - unsigned n; - unsigned num_variables; -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// make_workspace -//-------------------------------------------------------------------------------------------------- -xena::future> -cpu_driver::make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept { - auto res = std::make_unique(); - res->mles = memmg::managed_array{mles.begin(), mles.end()}; - res->product_table = product_table; - res->product_terms = product_terms; - res->n = n; - res->num_variables = std::max(basn::ceil_log2(n), 1); - return xena::make_ready_future>(std::move(res)); -} - -//-------------------------------------------------------------------------------------------------- -// sum -//-------------------------------------------------------------------------------------------------- -xena::future<> cpu_driver::sum(basct::span polynomial, - workspace& ws) const noexcept { - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - SXT_RELEASE_ASSERT(work.n >= mid); - - auto mles = work.mles.data(); - auto product_table = work.product_table; - auto product_terms = work.product_terms; - - for (auto& val : polynomial) { - val = {}; - } - - // expand paired terms - auto n1 = work.n - mid; - for (unsigned i = 0; i < n1; ++i) { - unsigned term_first = 0; - for (auto [mult, num_terms] : product_table) { - SXT_RELEASE_ASSERT(num_terms < polynomial.size()); - auto terms = product_terms.subspan(term_first, num_terms); - SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element); - expand_products(p, mles + i, n, mid, terms); - for (unsigned term_index = 0; term_index < p.size(); ++term_index) { - s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); - } - term_first += num_terms; - } - } - - // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) - for (unsigned i = n1; i < mid; ++i) { - unsigned term_first = 0; - for (auto [mult, num_terms] : product_table) { - auto terms = product_terms.subspan(term_first, num_terms); - SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element); - partial_expand_products(p, mles + i, n, terms); - for (unsigned term_index = 0; term_index < p.size(); ++term_index) { - s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); - } - term_first += num_terms; - } - } - - return xena::make_ready_future(); -} - -//-------------------------------------------------------------------------------------------------- -// fold -//-------------------------------------------------------------------------------------------------- -xena::future<> cpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { - using s25t::operator""_s25; - - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); - - auto mles = work.mles.data(); - memmg::managed_array mles_p(num_mles * mid); - - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - auto n1 = work.n - mid; - for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { - auto data = mles + n * mle_index; - auto data_p = mles_p.data() + mid * mle_index; - - // fold paired terms - for (unsigned i = 0; i < n1; ++i) { - auto val = data[i]; - s25o::mul(val, val, one_m_r); - s25o::muladd(val, r, data[mid + i], val); - data_p[i] = val; - } - - // fold terms paired with zero - for (unsigned i = n1; i < mid; ++i) { - auto val = data[i]; - s25o::mul(val, val, one_m_r); - data_p[i] = val; - } - } - - work.n = mid; - --work.num_variables; - work.mles = std::move(mles_p); - return xena::make_ready_future(); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.h b/sxt/proof/sumcheck/cpu_driver.h index b017396d1..ec82f3a5d 100644 --- a/sxt/proof/sumcheck/cpu_driver.h +++ b/sxt/proof/sumcheck/cpu_driver.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,22 +16,129 @@ */ #pragma once +#include "sxt/base/container/stack_array.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/memory/management/managed_array.h" #include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // cpu_driver //-------------------------------------------------------------------------------------------------- -class cpu_driver final : public driver { +template class cpu_driver final : public driver { + struct cpu_workspace final : public workspace { + memmg::managed_array mles; + basct::cspan> product_table; + basct::cspan product_terms; + unsigned n; + unsigned num_variables; + }; + public: // driver xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept override; + make_workspace(basct::cspan mles, basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override { + auto res = std::make_unique(); + res->mles = memmg::managed_array{mles.begin(), mles.end()}; + res->product_table = product_table; + res->product_terms = product_terms; + res->n = n; + res->num_variables = std::max(basn::ceil_log2(n), 1); + return xena::make_ready_future>(std::move(res)); + } + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + SXT_RELEASE_ASSERT(work.n >= mid); + + auto mles = work.mles.data(); + auto product_table = work.product_table; + auto product_terms = work.product_terms; + + for (auto& val : polynomial) { + val = {}; + } + + // expand paired terms + auto n1 = work.n - mid; + for (unsigned i = 0; i < n1; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + SXT_RELEASE_ASSERT(num_terms < polynomial.size()); + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, T); + expand_products(p, mles + i, n, mid, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; + } + } + + // expand terms where the corresponding pair is zero (i.e. n is not a power of 2) + for (unsigned i = n1; i < mid; ++i) { + unsigned term_first = 0; + for (auto [mult, num_terms] : product_table) { + auto terms = product_terms.subspan(term_first, num_terms); + SXT_STACK_ARRAY(p, num_terms + 1u, T); + partial_expand_products(p, mles + i, n, terms); + for (unsigned term_index = 0; term_index < p.size(); ++term_index) { + muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]); + } + term_first += num_terms; + } + } + + return xena::make_ready_future(); + } + + xena::future<> fold(workspace& ws, const T& r) const noexcept override { + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + auto mles = work.mles.data(); + memmg::managed_array mles_p(num_mles * mid); + + T one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + auto n1 = work.n - mid; + for (auto mle_index = 0; mle_index < num_mles; ++mle_index) { + auto data = mles + n * mle_index; + auto data_p = mles_p.data() + mid * mle_index; + + // fold paired terms + for (unsigned i = 0; i < n1; ++i) { + auto val = data[i]; + mul(val, val, one_m_r); + muladd(val, r, data[mid + i], val); + data_p[i] = val; + } - xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; + // fold terms paired with zero + for (unsigned i = n1; i < mid; ++i) { + auto val = data[i]; + mul(val, val, one_m_r); + data_p[i] = val; + } + } - xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; + work.n = mid; + --work.num_variables; + work.mles = std::move(mles_p); + return xena::make_ready_future(); + } }; } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/cpu_driver.t.cc b/sxt/proof/sumcheck/cpu_driver.t.cc index 42ae1e7c4..2544c1bc2 100644 --- a/sxt/proof/sumcheck/cpu_driver.t.cc +++ b/sxt/proof/sumcheck/cpu_driver.t.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,6 @@ */ #include "sxt/proof/sumcheck/cpu_driver.h" -#include - #include "sxt/base/test/unit_test.h" #include "sxt/proof/sumcheck/driver_test.h" @@ -25,6 +23,6 @@ using namespace sxt; using namespace sxt::prfsk; TEST_CASE("we can perform the primitive operations for sumcheck proofs") { - cpu_driver drv; + cpu_driver drv; exercise_driver(drv); } diff --git a/sxt/proof/sumcheck/device_cache.cc b/sxt/proof/sumcheck/device_cache.cc index 2daaadc4b..0fb5bce21 100644 --- a/sxt/proof/sumcheck/device_cache.cc +++ b/sxt/proof/sumcheck/device_cache.cc @@ -15,56 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/device_cache.h" - -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/state.h" -#include "sxt/base/device/stream.h" -#include "sxt/memory/resource/device_resource.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// make_device_copy -//-------------------------------------------------------------------------------------------------- -static std::unique_ptr -make_device_copy(basct::cspan> product_table, - basct::cspan product_terms, basdv::stream& stream) noexcept { - device_cache_data res{ - .product_table{product_table.size(), memr::get_device_resource()}, - .product_terms{product_terms.size(), memr::get_device_resource()}, - }; - basdv::async_copy_host_to_device(res.product_table, product_table, stream); - basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); - return std::make_unique(std::move(res)); -} - -//-------------------------------------------------------------------------------------------------- -// constructor -//-------------------------------------------------------------------------------------------------- -device_cache::device_cache(basct::cspan> product_table, - basct::cspan product_terms) noexcept - : product_table_{product_table}, product_terms_{product_terms} {} - -//-------------------------------------------------------------------------------------------------- -// lookup -//-------------------------------------------------------------------------------------------------- -void device_cache::lookup(basct::cspan>& product_table, - basct::cspan& product_terms, basdv::stream& stream) noexcept { - auto& ptr = data_[basdv::get_device()]; - if (ptr == nullptr) { - ptr = make_device_copy(product_table_, product_terms_, stream); - } - product_table = ptr->product_table; - product_terms = ptr->product_terms; -} - -//-------------------------------------------------------------------------------------------------- -// clear -//-------------------------------------------------------------------------------------------------- -std::unique_ptr device_cache::clear() noexcept { - auto res{std::move(data_[basdv::get_device()])}; - for (auto& ptr : data_) { - ptr.reset(); - } - return res; -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/device_cache.h b/sxt/proof/sumcheck/device_cache.h index c61038392..f2510501d 100644 --- a/sxt/proof/sumcheck/device_cache.h +++ b/sxt/proof/sumcheck/device_cache.h @@ -21,38 +21,69 @@ #include "sxt/base/container/span.h" #include "sxt/base/device/device_map.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/state.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/field/element.h" #include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/device_resource.h" #include "sxt/scalar25/type/element.h" -namespace sxt::basdv { -class stream; -} - namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // device_cache_data //-------------------------------------------------------------------------------------------------- -struct device_cache_data { - memmg::managed_array> product_table; +template struct device_cache_data { + memmg::managed_array> product_table; memmg::managed_array product_terms; }; +//-------------------------------------------------------------------------------------------------- +// make_device_copy +//-------------------------------------------------------------------------------------------------- +template +std::unique_ptr> +make_device_copy(basct::cspan> product_table, + basct::cspan product_terms, basdv::stream& stream) noexcept { + device_cache_data res{ + .product_table{product_table.size(), memr::get_device_resource()}, + .product_terms{product_terms.size(), memr::get_device_resource()}, + }; + basdv::async_copy_host_to_device(res.product_table, product_table, stream); + basdv::async_copy_host_to_device(res.product_terms, product_terms, stream); + return std::make_unique>(std::move(res)); +} + //-------------------------------------------------------------------------------------------------- // device_cache //-------------------------------------------------------------------------------------------------- -class device_cache { +template class device_cache { public: - device_cache(basct::cspan> product_table, - basct::cspan product_terms) noexcept; + device_cache(basct::cspan> product_table, + basct::cspan product_terms) noexcept + : product_table_{product_table}, product_terms_{product_terms} {} - void lookup(basct::cspan>& product_table, - basct::cspan& product_terms, basdv::stream& stream) noexcept; + void lookup(basct::cspan>& product_table, + basct::cspan& product_terms, basdv::stream& stream) noexcept { + auto& ptr = data_[basdv::get_device()]; + if (ptr == nullptr) { + ptr = make_device_copy(product_table_, product_terms_, stream); + } + product_table = ptr->product_table; + product_terms = ptr->product_terms; + } - std::unique_ptr clear() noexcept; + std::unique_ptr> clear() noexcept { + auto res{std::move(data_[basdv::get_device()])}; + for (auto& ptr : data_) { + ptr.reset(); + } + return res; + } private: - basct::cspan> product_table_; + basct::cspan> product_table_; basct::cspan product_terms_; - basdv::device_map> data_; + basdv::device_map>> data_; }; } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/device_cache.t.cc b/sxt/proof/sumcheck/device_cache.t.cc index 008440faa..cad5ea828 100644 --- a/sxt/proof/sumcheck/device_cache.t.cc +++ b/sxt/proof/sumcheck/device_cache.t.cc @@ -22,6 +22,7 @@ #include "sxt/base/device/stream.h" #include "sxt/base/device/synchronization.h" #include "sxt/base/test/unit_test.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -29,6 +30,7 @@ using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can cache device values that don't change as a proof is computed") { + using T = s25t::element; std::vector> product_table; std::vector product_terms; @@ -40,7 +42,7 @@ TEST_CASE("we can cache device values that don't change as a proof is computed") SECTION("we can access values from device memory") { product_table = {{0x123_s25, 0}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; cache.lookup(product_table_dev, product_terms_dev, stream); std::vector> product_table_p(product_table.size()); @@ -57,7 +59,7 @@ TEST_CASE("we can cache device values that don't change as a proof is computed") SECTION("we can clear the device cache") { product_table = {{0x123_s25, 0}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; cache.lookup(product_table_dev, product_terms_dev, stream); std::vector> product_table_p(product_table.size()); diff --git a/sxt/proof/sumcheck/driver.cc b/sxt/proof/sumcheck/driver.cc index 6e46927f6..075234c6c 100644 --- a/sxt/proof/sumcheck/driver.cc +++ b/sxt/proof/sumcheck/driver.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/sxt/proof/sumcheck/driver.h b/sxt/proof/sumcheck/driver.h index 422bb88ad..3dcd2ae2b 100644 --- a/sxt/proof/sumcheck/driver.h +++ b/sxt/proof/sumcheck/driver.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,29 +19,24 @@ #include #include "sxt/base/container/span.h" +#include "sxt/base/field/element.h" #include "sxt/execution/async/future_fwd.h" #include "sxt/proof/sumcheck/workspace.h" -namespace sxt::s25t { -class element; -} - namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // driver //-------------------------------------------------------------------------------------------------- -class driver { +template class driver { public: virtual ~driver() noexcept = default; virtual xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, + make_workspace(basct::cspan mles, basct::cspan> product_table, basct::cspan product_terms, unsigned n) const noexcept = 0; - virtual xena::future<> sum(basct::span polynomial, - workspace& ws) const noexcept = 0; + virtual xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept = 0; - virtual xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept = 0; + virtual xena::future<> fold(workspace& ws, const T& r) const noexcept = 0; }; } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/driver_test.cc b/sxt/proof/sumcheck/driver_test.cc index 97fa80df1..58bac88a4 100644 --- a/sxt/proof/sumcheck/driver_test.cc +++ b/sxt/proof/sumcheck/driver_test.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ using s25t::operator""_s25; //-------------------------------------------------------------------------------------------------- // exercise_driver //-------------------------------------------------------------------------------------------------- -void exercise_driver(const driver& drv) { +void exercise_driver(const driver& drv) { std::vector mles; std::vector> product_table{ {0x1_s25, 1}, diff --git a/sxt/proof/sumcheck/driver_test.h b/sxt/proof/sumcheck/driver_test.h index 69c0a53ea..f13db0b55 100644 --- a/sxt/proof/sumcheck/driver_test.h +++ b/sxt/proof/sumcheck/driver_test.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,12 @@ */ #pragma once -namespace sxt::prfsk { -class driver; +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/scalar25/realization/field.h" +namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // exercise_driver //-------------------------------------------------------------------------------------------------- -void exercise_driver(const driver& drv); +void exercise_driver(const driver& drv); } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.cc b/sxt/proof/sumcheck/fold_gpu.cc index 5e55e226e..3ade12a7c 100644 --- a/sxt/proof/sumcheck/fold_gpu.cc +++ b/sxt/proof/sumcheck/fold_gpu.cc @@ -15,142 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/fold_gpu.h" - -#include - -#include "sxt/algorithm/iteration/kernel_fit.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/property.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/iterator/split.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/for_each.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/execution/kernel/kernel_dims.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// fold_kernel -//-------------------------------------------------------------------------------------------------- -static __global__ void fold_kernel(s25t::element* __restrict__ mles, unsigned np, unsigned split, - s25t::element r, s25t::element one_m_r) noexcept { - auto thread_index = threadIdx.x; - auto block_index = blockIdx.x; - auto block_size = blockDim.x; - auto k = basn::divide_up(split, gridDim.x * block_size) * block_size; - auto block_first = block_index * k; - assert(block_first < split && "every block should be active"); - auto m = umin(block_first + k, split); - - // adjust mles - mles += np * blockIdx.y; - - // fold - auto index = block_first + thread_index; - for (; index < m; index += block_size) { - auto x = mles[index]; - s25o::mul(x, x, one_m_r); - auto index_p = split + index; - if (index_p < np) { - s25o::muladd(x, mles[index_p], r, x); - } - mles[index] = x; - } -} - -//-------------------------------------------------------------------------------------------------- -// fold_impl -//-------------------------------------------------------------------------------------------------- -static xena::future<> fold_impl(basct::span mles_p, basct::cspan mles, - unsigned n, unsigned mid, unsigned a, unsigned b, - const s25t::element& r, const s25t::element one_m_r) noexcept { - auto num_mles = mles.size() / n; - auto split = b - a; - - // copy MLEs to device - basdv::stream stream; - memr::async_device_resource resource{stream}; - memmg::managed_array mles_dev{&resource}; - copy_partial_mles(mles_dev, stream, mles, n, a, b); - - // fold - auto np = mles_dev.size() / num_mles; - auto dims = algi::fit_iteration_kernel(split); - fold_kernel<<(dims.block_size), 0, - stream>>>(mles_dev.data(), np, split, r, one_m_r); - - // copy results back - copy_folded_mles(mles_p, stream, mles_dev, mid, a, b); - - co_await xendv::await_stream(stream); -} - -//-------------------------------------------------------------------------------------------------- -// fold_gpu -//-------------------------------------------------------------------------------------------------- -xena::future<> fold_gpu(basct::span mles_p, - const basit::split_options& split_options, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept { - using s25t::operator""_s25; - auto num_mles = mles.size() / n; - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - SXT_DEBUG_ASSERT( - // clang-format off - n > 1 && mles.size() == num_mles * n - // clang-format on - ); - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - // split - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, split_options); - - // fold - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { - co_await fold_impl(mles_p, mles, n, mid, rng.a(), rng.b(), r, one_m_r); - }); -} - -xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept { - using s25t::operator""_s25; - auto num_mles = mles.size() / n; - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - SXT_DEBUG_ASSERT( - // clang-format off - n > 1 && mles.size() == num_mles * n - // clang-format on - ); - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - // split - basit::split_options split_options{ - .min_chunk_size = 1024u * 128u, - .max_chunk_size = 1024u * 256u, - .split_factor = basdv::get_num_devices(), - }; - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, split_options); - - // fold - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { - co_await fold_impl(mles_p, mles, n, mid, rng.a(), rng.b(), r, one_m_r); - }); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.h b/sxt/proof/sumcheck/fold_gpu.h index 611d37522..c5498a21b 100644 --- a/sxt/proof/sumcheck/fold_gpu.h +++ b/sxt/proof/sumcheck/fold_gpu.h @@ -16,25 +16,125 @@ */ #pragma once +#include + +#include "sxt/algorithm/iteration/kernel_fit.h" #include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/property.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" +#include "sxt/base/iterator/split.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/device/for_each.h" +#include "sxt/execution/device/synchronization.h" +#include "sxt/execution/kernel/kernel_dims.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/async_device_resource.h" +#include "sxt/memory/resource/device_resource.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/scalar25/operation/mul.h" +#include "sxt/scalar25/operation/muladd.h" +#include "sxt/scalar25/operation/sub.h" +#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/type/literal.h" + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// fold_kernel +//-------------------------------------------------------------------------------------------------- +template +__global__ void fold_kernel(T* __restrict__ mles, unsigned np, unsigned split, T r, + T one_m_r) noexcept { + auto thread_index = threadIdx.x; + auto block_index = blockIdx.x; + auto block_size = blockDim.x; + auto k = basn::divide_up(split, gridDim.x * block_size) * block_size; + auto block_first = block_index * k; + assert(block_first < split && "every block should be active"); + auto m = umin(block_first + k, split); -namespace sxt::s25t { -class element; + // adjust mles + mles += np * blockIdx.y; + + // fold + auto index = block_first + thread_index; + for (; index < m; index += block_size) { + auto x = mles[index]; + mul(x, x, one_m_r); + auto index_p = split + index; + if (index_p < np) { + muladd(x, mles[index_p], r, x); + } + mles[index] = x; + } } -namespace sxt::basit { -struct split_options; +//-------------------------------------------------------------------------------------------------- +// fold_impl +//-------------------------------------------------------------------------------------------------- +template +xena::future<> fold_impl(basct::span mles_p, basct::cspan mles, unsigned n, unsigned mid, + unsigned a, unsigned b, const T& r, const T& one_m_r) noexcept { + auto num_mles = mles.size() / n; + auto split = b - a; + + // copy MLEs to device + basdv::stream stream; + memr::async_device_resource resource{stream}; + memmg::managed_array mles_dev{&resource}; + copy_partial_mles(mles_dev, stream, mles, n, a, b); + + // fold + auto np = mles_dev.size() / num_mles; + auto dims = algi::fit_iteration_kernel(split); + fold_kernel<<(dims.block_size), 0, + stream>>>(mles_dev.data(), np, split, r, one_m_r); + + // copy results back + copy_folded_mles(mles_p, stream, mles_dev, mid, a, b); + + co_await xendv::await_stream(stream); } -namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // fold_gpu //-------------------------------------------------------------------------------------------------- -xena::future<> fold_gpu(basct::span mles_p, - const basit::split_options& split_options, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept; +template +xena::future<> fold_gpu(basct::span mles_p, const basit::split_options& split_options, + basct::cspan mles, unsigned n, const T& r) noexcept { + auto num_mles = mles.size() / n; + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + SXT_DEBUG_ASSERT( + // clang-format off + n > 1 && mles.size() == num_mles * n + // clang-format on + ); + auto one_m_r = T::one(); + sub(one_m_r, one_m_r, r); -xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, - unsigned n, const s25t::element& r) noexcept; + // split + auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, split_options); + + // fold + co_await xendv::concurrent_for_each( + chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { + co_await fold_impl(mles_p, mles, n, mid, rng.a(), rng.b(), r, one_m_r); + }); +} + +template +xena::future<> fold_gpu(basct::span mles_p, basct::cspan mles, unsigned n, + const T& r) noexcept { + basit::split_options split_options{ + .min_chunk_size = 1024u * 128u, + .max_chunk_size = 1024u * 256u, + .split_factor = basdv::get_num_devices(), + }; + co_await fold_gpu(mles_p, split_options, mles, n, r); +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/fold_gpu.t.cc b/sxt/proof/sumcheck/fold_gpu.t.cc index c6d6ae693..edb00aaa1 100644 --- a/sxt/proof/sumcheck/fold_gpu.t.cc +++ b/sxt/proof/sumcheck/fold_gpu.t.cc @@ -23,7 +23,7 @@ #include "sxt/execution/async/future.h" #include "sxt/execution/schedule/scheduler.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -31,6 +31,7 @@ using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can fold scalars using the gpu") { + using T = s25t::element; std::vector mles, mles_p, expected; auto r = 0xabc123_s25; @@ -39,7 +40,7 @@ TEST_CASE("we can fold scalars using the gpu") { SECTION("we can fold a single mle with n=2") { mles = {0x1_s25, 0x2_s25}; mles_p.resize(1); - auto fut = fold_gpu(mles_p, mles, 2, r); + auto fut = fold_gpu(mles_p, mles, 2, r); xens::get_scheduler().run(); REQUIRE(fut.ready()); expected = { @@ -51,7 +52,7 @@ TEST_CASE("we can fold scalars using the gpu") { SECTION("we can fold a single mle with n=3") { mles = {0x123_s25, 0x456_s25, 0x789_s25}; mles_p.resize(2); - auto fut = fold_gpu(mles_p, mles, 3, r); + auto fut = fold_gpu(mles_p, mles, 3, r); xens::get_scheduler().run(); REQUIRE(fut.ready()); expected = { @@ -69,7 +70,7 @@ TEST_CASE("we can fold scalars using the gpu") { }; mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x101112_s25}; mles_p.resize(2); - auto fut = fold_gpu(mles_p, split_options, mles, 4, r); + auto fut = fold_gpu(mles_p, split_options, mles, 4, r); xens::get_scheduler().run(); REQUIRE(fut.ready()); expected = { diff --git a/sxt/proof/sumcheck/gpu_driver.cc b/sxt/proof/sumcheck/gpu_driver.cc index de996b487..a9c0566ff 100644 --- a/sxt/proof/sumcheck/gpu_driver.cc +++ b/sxt/proof/sumcheck/gpu_driver.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,186 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/gpu_driver.h" - -#include -#include - -#include "sxt/algorithm/iteration/for_each.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/device/synchronization.h" -#include "sxt/base/error/panic.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/constexpr_switch.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/constant.h" -#include "sxt/proof/sumcheck/sum_gpu.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" -#include "sxt/scalar25/type/literal.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// gpu_workspace -//-------------------------------------------------------------------------------------------------- -namespace { -struct gpu_workspace final : public workspace { - memmg::managed_array mles; - memmg::managed_array> product_table; - memmg::managed_array product_terms; - unsigned n; - unsigned num_variables; - - gpu_workspace() noexcept - : mles{memr::get_device_resource()}, product_table{memr::get_device_resource()}, - product_terms{memr::get_device_resource()} {} -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// make_workspace -//-------------------------------------------------------------------------------------------------- -xena::future> gpu_driver::make_workspace( - basct::cspan mles, - memmg::managed_array>&& product_table_dev, - memmg::managed_array&& product_terms_dev, unsigned n) const noexcept { - auto ws = std::make_unique(); - - // dimensions - ws->n = n; - ws->num_variables = std::max(basn::ceil_log2(n), 1); - - // mles - ws->mles = memmg::managed_array{ - mles.size(), - memr::get_device_resource(), - }; - basdv::stream mle_stream; - basdv::async_copy_host_to_device(ws->mles, mles, mle_stream); - - // product_table - ws->product_table = std::move(product_table_dev); - - // product_terms - ws->product_terms = std::move(product_terms_dev); - - // await - co_await xendv::await_stream(mle_stream); - co_return ws; -} - -xena::future> -gpu_driver::make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept { - auto ws = std::make_unique(); - - // dimensions - ws->n = n; - ws->num_variables = std::max(basn::ceil_log2(n), 1); - - // mles - ws->mles = memmg::managed_array{ - mles.size(), - memr::get_device_resource(), - }; - basdv::stream mle_stream; - basdv::async_copy_host_to_device(ws->mles, mles, mle_stream); - - // product_table - ws->product_table = memmg::managed_array>{ - product_table.size(), - memr::get_device_resource(), - }; - basdv::stream product_table_stream; - basdv::async_copy_host_to_device(ws->product_table, product_table, product_table_stream); - - // product_terms - ws->product_terms = memmg::managed_array{ - product_terms.size(), - memr::get_device_resource(), - }; - basdv::stream product_terms_stream; - basdv::async_copy_host_to_device(ws->product_terms, product_terms, product_terms_stream); - - // await - co_await xendv::await_stream(mle_stream); - co_await xendv::await_stream(product_table_stream); - co_await xendv::await_stream(product_terms_stream); - co_return ws; -} - -//-------------------------------------------------------------------------------------------------- -// sum -//-------------------------------------------------------------------------------------------------- -xena::future<> gpu_driver::sum(basct::span polynomial, - workspace& ws) const noexcept { - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && - polynomial.size() - 1u <= max_degree_v - // clang-format on - ); - co_await sum_gpu(polynomial, work.mles, work.product_table, work.product_terms, n); -} - -//-------------------------------------------------------------------------------------------------- -// fold -//-------------------------------------------------------------------------------------------------- -xena::future<> gpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept { - using s25t::operator""_s25; - auto& work = static_cast(ws); - auto n = work.n; - auto mid = 1u << (work.num_variables - 1u); - auto num_mles = work.mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - work.n >= mid && work.mles.size() % n == 0 - // clang-format on - ); - - s25t::element one_m_r = 0x1_s25; - s25o::sub(one_m_r, one_m_r, r); - - memmg::managed_array mles_p{num_mles * mid, memr::get_device_resource()}; - - auto f = [ - // clang-format off - mles_p = mles_p.data(), - mles = work.mles.data(), - n = n, - num_mles = num_mles, - r = r, - one_m_r = one_m_r - // clang-format on - ] __device__ - __host__(unsigned mid, unsigned i) noexcept { - for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { - auto val = mles[i + mle_index * n]; - s25o::mul(val, val, one_m_r); - if (mid + i < n) { - s25o::muladd(val, r, mles[mid + i + mle_index * n], val); - } - mles_p[i + mle_index * mid] = val; - } - }; - auto fut = algi::for_each(f, mid); - - // complete - co_await std::move(fut); - - work.n = mid; - --work.num_variables; - work.mles = std::move(mles_p); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/gpu_driver.h b/sxt/proof/sumcheck/gpu_driver.h index 891acb583..1fd638b87 100644 --- a/sxt/proof/sumcheck/gpu_driver.h +++ b/sxt/proof/sumcheck/gpu_driver.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,28 +16,137 @@ */ #pragma once -#include "sxt/memory/management/managed_array_fwd.h" +#include + +#include "sxt/algorithm/iteration/for_each.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/device/synchronization.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/device_resource.h" #include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/sum_gpu.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // gpu_driver //-------------------------------------------------------------------------------------------------- -class gpu_driver final : public driver { +template class gpu_driver final : public driver { public: - xena::future> - make_workspace(basct::cspan mles, - memmg::managed_array>&& product_table_dev, - memmg::managed_array&& product_terms_dev, unsigned n) const noexcept; + struct gpu_workspace final : public workspace { + memmg::managed_array mles; + memmg::managed_array> product_table; + memmg::managed_array product_terms; + unsigned n; + unsigned num_variables; + + gpu_workspace() noexcept + : mles{memr::get_device_resource()}, product_table{memr::get_device_resource()}, + product_terms{memr::get_device_resource()} {} + }; // driver xena::future> - make_workspace(basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) const noexcept override; + make_workspace(basct::cspan mles, basct::cspan> product_table, + basct::cspan product_terms, unsigned n) const noexcept override { + auto ws = std::make_unique(); + + // dimensions + ws->n = n; + ws->num_variables = std::max(basn::ceil_log2(n), 1); + + // mles + ws->mles = memmg::managed_array{ + mles.size(), + memr::get_device_resource(), + }; + basdv::stream mle_stream; + basdv::async_copy_host_to_device(ws->mles, mles, mle_stream); + + // product_table + ws->product_table = memmg::managed_array>{ + product_table.size(), + memr::get_device_resource(), + }; + basdv::stream product_table_stream; + basdv::async_copy_host_to_device(ws->product_table, product_table, product_table_stream); + + // product_terms + ws->product_terms = memmg::managed_array{ + product_terms.size(), + memr::get_device_resource(), + }; + basdv::stream product_terms_stream; + basdv::async_copy_host_to_device(ws->product_terms, product_terms, product_terms_stream); + + // await + co_await xendv::await_stream(mle_stream); + co_await xendv::await_stream(product_table_stream); + co_await xendv::await_stream(product_terms_stream); + co_return ws; + } + + xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override { + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && + polynomial.size() - 1u <= max_degree_v + // clang-format on + ); + co_await sum_gpu(polynomial, work.mles, work.product_table, work.product_terms, n); + } + + xena::future<> fold(workspace& ws, const T& r) const noexcept override { + auto& work = static_cast(ws); + auto n = work.n; + auto mid = 1u << (work.num_variables - 1u); + auto num_mles = work.mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + work.n >= mid && work.mles.size() % n == 0 + // clang-format on + ); + + T one_m_r = T::one(); + sub(one_m_r, one_m_r, r); + + memmg::managed_array mles_p{num_mles * mid, memr::get_device_resource()}; + + auto f = + [ + // clang-format off + mles_p = mles_p.data(), + mles = work.mles.data(), + n = n, + num_mles = num_mles, + r = r, + one_m_r = one_m_r + // clang-format on + ] __device__ + __host__(unsigned mid, unsigned i) noexcept { + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + auto val = mles[i + mle_index * n]; + mul(val, val, one_m_r); + if (mid + i < n) { + muladd(val, r, mles[mid + i + mle_index * n], val); + } + mles_p[i + mle_index * mid] = val; + } + }; + auto fut = algi::for_each(f, mid); - xena::future<> sum(basct::span polynomial, workspace& ws) const noexcept override; + // complete + co_await std::move(fut); - xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override; + work.n = mid; + --work.num_variables; + work.mles = std::move(mles_p); + } }; } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/gpu_driver.t.cc b/sxt/proof/sumcheck/gpu_driver.t.cc index cf0acbba0..565ac6fa7 100644 --- a/sxt/proof/sumcheck/gpu_driver.t.cc +++ b/sxt/proof/sumcheck/gpu_driver.t.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,6 @@ using namespace sxt; using namespace sxt::prfsk; TEST_CASE("we can perform the primitive operations for sumcheck proofs") { - gpu_driver drv; + gpu_driver drv; exercise_driver(drv); } diff --git a/sxt/proof/sumcheck/mle_utility.cc b/sxt/proof/sumcheck/mle_utility.cc index 8bcf33f36..e548908cc 100644 --- a/sxt/proof/sumcheck/mle_utility.cc +++ b/sxt/proof/sumcheck/mle_utility.cc @@ -15,84 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/mle_utility.h" - -#include -#include - -#include "sxt/base/container/span_utility.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/property.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/divide_up.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// copy_partial_mles -//-------------------------------------------------------------------------------------------------- -void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, - basct::cspan mles, unsigned n, unsigned a, - unsigned b) noexcept { - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - auto num_mles = mles.size() / n; - auto part1_size = b - a; - SXT_DEBUG_ASSERT(a < b && b <= n); - auto ap = std::min(mid + a, n); - auto bp = std::min(mid + b, n); - auto part2_size = bp - ap; - - // resize array - auto partial_length = part1_size + part2_size; - partial_mles.resize(partial_length * num_mles); - - // copy data - for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { - // first part - auto src = mles.subspan(n * mle_index + a, part1_size); - auto dst = basct::subspan(partial_mles, partial_length * mle_index, part1_size); - basdv::async_copy_host_to_device(dst, src, stream); - - // second part - src = mles.subspan(n * mle_index + ap, part2_size); - dst = basct::subspan(partial_mles, partial_length * mle_index + part1_size, part2_size); - if (!src.empty()) { - basdv::async_copy_host_to_device(dst, src, stream); - } - } -} - -//-------------------------------------------------------------------------------------------------- -// copy_folded_mles -//-------------------------------------------------------------------------------------------------- -void copy_folded_mles(basct::span host_mles, basdv::stream& stream, - basct::cspan device_mles, unsigned np, unsigned a, - unsigned b) noexcept { - auto num_mles = host_mles.size() / np; - auto slice_n = device_mles.size() / num_mles; - auto slice_np = b - a; - SXT_DEBUG_ASSERT( - // clang-format off - host_mles.size() == num_mles * np && - device_mles.size() == num_mles * slice_n && - b <= np - // clang-format on - ); - for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { - auto src = device_mles.subspan(mle_index * slice_n, slice_np); - auto dst = host_mles.subspan(mle_index * np + a, slice_np); - basdv::async_copy_device_to_host(dst, src, stream); - } -} - -//-------------------------------------------------------------------------------------------------- -// get_gpu_memory_fraction -//-------------------------------------------------------------------------------------------------- -double get_gpu_memory_fraction(basct::cspan mles) noexcept { - auto total_memory = static_cast(basdv::get_total_device_memory()); - return static_cast(mles.size() * sizeof(s25t::element)) / total_memory; -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/mle_utility.h b/sxt/proof/sumcheck/mle_utility.h index 9f3285a89..1fec67281 100644 --- a/sxt/proof/sumcheck/mle_utility.h +++ b/sxt/proof/sumcheck/mle_utility.h @@ -16,33 +16,84 @@ */ #pragma once -#include "sxt/base/container/span.h" -#include "sxt/memory/management/managed_array_fwd.h" +#include +#include -namespace sxt::basdv { -class stream; -} -namespace sxt::s25t { -class element; -} +#include "sxt/base/container/span.h" +#include "sxt/base/container/span_utility.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/property.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/base/num/divide_up.h" +#include "sxt/memory/management/managed_array.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // copy_partial_mles //-------------------------------------------------------------------------------------------------- -void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, - basct::cspan mles, unsigned n, unsigned a, - unsigned b) noexcept; +template +void copy_partial_mles(memmg::managed_array& partial_mles, basdv::stream& stream, + basct::cspan mles, unsigned n, unsigned a, unsigned b) noexcept { + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + auto num_mles = mles.size() / n; + auto part1_size = b - a; + SXT_DEBUG_ASSERT(a < b && b <= n); + auto ap = std::min(mid + a, n); + auto bp = std::min(mid + b, n); + auto part2_size = bp - ap; + + // resize array + auto partial_length = part1_size + part2_size; + partial_mles.resize(partial_length * num_mles); + + // copy data + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + // first part + auto src = mles.subspan(n * mle_index + a, part1_size); + auto dst = basct::subspan(partial_mles, partial_length * mle_index, part1_size); + basdv::async_copy_host_to_device(dst, src, stream); + + // second part + src = mles.subspan(n * mle_index + ap, part2_size); + dst = basct::subspan(partial_mles, partial_length * mle_index + part1_size, part2_size); + if (!src.empty()) { + basdv::async_copy_host_to_device(dst, src, stream); + } + } +} //-------------------------------------------------------------------------------------------------- // copy_folded_mles //-------------------------------------------------------------------------------------------------- -void copy_folded_mles(basct::span host_mles, basdv::stream& stream, - basct::cspan device_mles, unsigned np, unsigned a, - unsigned b) noexcept; +template +void copy_folded_mles(basct::span host_mles, basdv::stream& stream, basct::cspan device_mles, + unsigned np, unsigned a, unsigned b) noexcept { + auto num_mles = host_mles.size() / np; + auto slice_n = device_mles.size() / num_mles; + auto slice_np = b - a; + SXT_DEBUG_ASSERT( + // clang-format off + host_mles.size() == num_mles * np && + device_mles.size() == num_mles * slice_n && + b <= np + // clang-format on + ); + for (unsigned mle_index = 0; mle_index < num_mles; ++mle_index) { + auto src = device_mles.subspan(mle_index * slice_n, slice_np); + auto dst = host_mles.subspan(mle_index * np + a, slice_np); + basdv::async_copy_device_to_host(dst, src, stream); + } +} //-------------------------------------------------------------------------------------------------- // get_gpu_memory_fraction //-------------------------------------------------------------------------------------------------- -double get_gpu_memory_fraction(basct::cspan mles) noexcept; +template double get_gpu_memory_fraction(basct::cspan mles) noexcept { + auto total_memory = static_cast(basdv::get_total_device_memory()); + return static_cast(mles.size() * sizeof(T)) / total_memory; +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/mle_utility.t.cc b/sxt/proof/sumcheck/mle_utility.t.cc index f930943dd..2494b2c57 100644 --- a/sxt/proof/sumcheck/mle_utility.t.cc +++ b/sxt/proof/sumcheck/mle_utility.t.cc @@ -23,13 +23,15 @@ #include "sxt/base/test/unit_test.h" #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/managed_device_resource.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; using namespace sxt::prfsk; using s25t::operator""_s25; +using T = s25t::element; + TEST_CASE("we can copy a slice of mles to device memory") { std::pmr::vector mles{memr::get_managed_device_resource()}; memmg::managed_array partial_mles{memr::get_managed_device_resource()}; @@ -38,7 +40,7 @@ TEST_CASE("we can copy a slice of mles to device memory") { SECTION("we can copy an mle with a single element") { mles = {0x123_s25}; - copy_partial_mles(partial_mles, stream, mles, 1, 0, 1); + copy_partial_mles(partial_mles, stream, mles, 1, 0, 1); basdv::synchronize_stream(stream); memmg::managed_array expected = {0x123_s25}; REQUIRE(partial_mles == expected); @@ -46,7 +48,7 @@ TEST_CASE("we can copy a slice of mles to device memory") { SECTION("we can copy a slice of MLEs") { mles = {0x1_s25, 0x2_s25, 0x3_s25, 0x4_s25, 0x5_s25, 0x6_s25}; - copy_partial_mles(partial_mles, stream, mles, 3, 0, 1); + copy_partial_mles(partial_mles, stream, mles, 3, 0, 1); basdv::synchronize_stream(stream); memmg::managed_array expected = {0x1_s25, 0x3_s25, 0x4_s25, 0x6_s25}; REQUIRE(partial_mles == expected); @@ -62,7 +64,7 @@ TEST_CASE("we can copy partially folded MLEs to the host") { SECTION("we can copy a single element") { device_mles = {0x123_s25}; host_mles.resize(1); - copy_folded_mles(host_mles, stream, device_mles, 1, 0, 1); + copy_folded_mles(host_mles, stream, device_mles, 1, 0, 1); basdv::synchronize_stream(stream); std::vector expected = {0x123_s25}; REQUIRE(host_mles == expected); @@ -71,7 +73,7 @@ TEST_CASE("we can copy partially folded MLEs to the host") { SECTION("we can copy partially folded MLEs") { device_mles = {0x123_s25, 0x456_s25}; host_mles.resize(4); - copy_folded_mles(host_mles, stream, device_mles, 2, 0, 1); + copy_folded_mles(host_mles, stream, device_mles, 2, 0, 1); basdv::synchronize_stream(stream); std::vector expected = {0x123_s25, 0x0_s25, 0x456_s25, 0x0_s25}; REQUIRE(host_mles == expected); @@ -81,14 +83,14 @@ TEST_CASE("we can copy partially folded MLEs to the host") { TEST_CASE("we can query the fraction of device memory taken by MLEs") { std::vector mles; - SECTION("we handle the zero case") { REQUIRE(get_gpu_memory_fraction(mles) == 0.0); } + SECTION("we handle the zero case") { REQUIRE(get_gpu_memory_fraction(mles) == 0.0); } SECTION("the fractions doubles if the length of mles doubles") { mles.resize(1); - auto f1 = get_gpu_memory_fraction(mles); + auto f1 = get_gpu_memory_fraction(mles); REQUIRE(f1 > 0); mles.resize(2); - auto f2 = get_gpu_memory_fraction(mles); + auto f2 = get_gpu_memory_fraction(mles); REQUIRE(f2 == Catch::Approx(2 * f1)); } } diff --git a/sxt/proof/sumcheck/polynomial_mapper.h b/sxt/proof/sumcheck/polynomial_mapper.h index a10dc935a..a44981847 100644 --- a/sxt/proof/sumcheck/polynomial_mapper.h +++ b/sxt/proof/sumcheck/polynomial_mapper.h @@ -16,16 +16,16 @@ */ #pragma once +#include "sxt/base/field/element.h" #include "sxt/base/macro/cuda_callable.h" #include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/scalar25/type/element.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // polynomial_mapper //-------------------------------------------------------------------------------------------------- -template struct polynomial_mapper { - using value_type = std::array; +template struct polynomial_mapper { + using value_type = std::array; CUDA_CALLABLE value_type map_index(unsigned index) const noexcept { @@ -37,13 +37,13 @@ template struct polynomial_mapper { CUDA_CALLABLE void map_index(value_type& p, unsigned index) const noexcept { if (index + split < n) { - expand_products(p, mles + index, n, split, {product_terms, Degree}); + expand_products(p, mles + index, n, split, {product_terms, Degree}); } else { - partial_expand_products(p, mles + index, n, {product_terms, Degree}); + partial_expand_products(p, mles + index, n, {product_terms, Degree}); } } - const s25t::element* __restrict__ mles; + const T* __restrict__ mles; const unsigned* __restrict__ product_terms; unsigned split; unsigned n; diff --git a/sxt/proof/sumcheck/polynomial_mapper.t.cc b/sxt/proof/sumcheck/polynomial_mapper.t.cc index 6c0d85827..4f940c53f 100644 --- a/sxt/proof/sumcheck/polynomial_mapper.t.cc +++ b/sxt/proof/sumcheck/polynomial_mapper.t.cc @@ -20,13 +20,15 @@ #include "sxt/base/test/unit_test.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; using namespace sxt::prfsk; using s25t::operator""_s25; +using T = s25t::element; + TEST_CASE("we can map indexes to expanded polynomials") { std::vector mles; std::vector product_terms; @@ -34,7 +36,7 @@ TEST_CASE("we can map indexes to expanded polynomials") { SECTION("we can map a single element mle") { mles = {0x123_s25}; product_terms = {0}; - polynomial_mapper<1> m{ + polynomial_mapper<1, T> m{ .mles = mles.data(), .product_terms = product_terms.data(), .split = 1, @@ -48,7 +50,7 @@ TEST_CASE("we can map indexes to expanded polynomials") { SECTION("we can map an mle with two elements") { mles = {0x123_s25, 0x456_s25}; product_terms = {0}; - polynomial_mapper<1> m{ + polynomial_mapper<1, T> m{ .mles = mles.data(), .product_terms = product_terms.data(), .split = 1, diff --git a/sxt/proof/sumcheck/polynomial_reducer.cc b/sxt/proof/sumcheck/polynomial_reducer.cc new file mode 100644 index 000000000..cc61a64d1 --- /dev/null +++ b/sxt/proof/sumcheck/polynomial_reducer.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/proof/sumcheck/polynomial_reducer.h" diff --git a/sxt/proof/sumcheck/polynomial_reducer.h b/sxt/proof/sumcheck/polynomial_reducer.h new file mode 100644 index 000000000..92fd0110d --- /dev/null +++ b/sxt/proof/sumcheck/polynomial_reducer.h @@ -0,0 +1,35 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "sxt/base/field/element.h" +#include "sxt/base/macro/cuda_callable.h" + +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// polynomial_reducer +//-------------------------------------------------------------------------------------------------- +template struct polynomial_reducer { + using value_type = std::array; + + CUDA_CALLABLE static void accumulate_inplace(value_type& res, const value_type& e) noexcept { + for (unsigned i = 0; i < res.size(); ++i) { + add(res[i], res[i], e[i]); + } + } +}; +} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.cc b/sxt/proof/sumcheck/polynomial_utility.cc index 4ba3d9169..46faf33bb 100644 --- a/sxt/proof/sumcheck/polynomial_utility.cc +++ b/sxt/proof/sumcheck/polynomial_utility.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,123 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/polynomial_utility.h" - -#include - -#include "sxt/scalar25/operation/add.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/operation/muladd.h" -#include "sxt/scalar25/operation/neg.h" -#include "sxt/scalar25/operation/sub.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// sum_polynomial_01 -//-------------------------------------------------------------------------------------------------- -void sum_polynomial_01(s25t::element& e, basct::cspan polynomial) noexcept { - if (polynomial.empty()) { - e = s25t::element{}; - return; - } - e = polynomial[0]; - for (unsigned i = 0; i < polynomial.size(); ++i) { - s25o::add(e, e, polynomial[i]); - } -} - -//-------------------------------------------------------------------------------------------------- -// evaluate_polynomial -//-------------------------------------------------------------------------------------------------- -void evaluate_polynomial(s25t::element& e, basct::cspan polynomial, - const s25t::element& x) noexcept { - if (polynomial.empty()) { - e = s25t::element{}; - return; - } - auto i = polynomial.size(); - --i; - e = polynomial[i]; - while (i > 0) { - --i; - s25o::muladd(e, e, x, polynomial[i]); - } -} - -//-------------------------------------------------------------------------------------------------- -// expand_products -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void expand_products(basct::span p, const s25t::element* mles, unsigned n, - unsigned step, basct::cspan terms) noexcept { - auto num_terms = terms.size(); - assert( - // clang-format off - num_terms > 0 && - n > step && - p.size() == num_terms + 1u - // clang-format on - ); - s25t::element a, b; - auto mle_index = terms[0]; - a = *(mles + mle_index * n); - b = *(mles + mle_index * n + step); - s25o::sub(b, b, a); - p[0] = a; - p[1] = b; - - for (unsigned i = 1; i < num_terms; ++i) { - auto mle_index = terms[i]; - a = *(mles + mle_index * n); - b = *(mles + mle_index * n + step); - s25o::sub(b, b, a); - - auto c_prev = p[0]; - s25o::mul(p[0], c_prev, a); - for (unsigned pow = 1u; pow < i + 1u; ++pow) { - auto c = p[pow]; - s25o::mul(p[pow], c, a); - s25o::muladd(p[pow], c_prev, b, p[pow]); - c_prev = c; - } - s25o::mul(p[i + 1u], c_prev, b); - } -} - -//-------------------------------------------------------------------------------------------------- -// partial_expand_products -//-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void partial_expand_products(basct::span p, const s25t::element* mles, unsigned n, - basct::cspan terms) noexcept { - auto num_terms = terms.size(); - assert( - // clang-format off - num_terms > 0 && - p.size() == num_terms + 1u - // clang-format on - ); - s25t::element a, b; - auto mle_index = terms[0]; - a = *(mles + mle_index * n); - s25o::neg(b, a); - p[0] = a; - p[1] = b; - - for (unsigned i = 1; i < num_terms; ++i) { - auto mle_index = terms[i]; - a = *(mles + mle_index * n); - s25o::neg(b, a); - - auto c_prev = p[0]; - s25o::mul(p[0], c_prev, a); - for (unsigned pow = 1u; pow < i + 1u; ++pow) { - auto c = p[pow]; - s25o::mul(p[pow], c, a); - s25o::muladd(p[pow], c_prev, b, p[pow]); - c_prev = c; - } - s25o::mul(p[i + 1u], c_prev, b); - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.h b/sxt/proof/sumcheck/polynomial_utility.h index 6754c4d74..65c21608d 100644 --- a/sxt/proof/sumcheck/polynomial_utility.h +++ b/sxt/proof/sumcheck/polynomial_utility.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,12 @@ */ #pragma once +#include + #include "sxt/base/container/span.h" +#include "sxt/base/field/element.h" #include "sxt/base/macro/cuda_callable.h" -namespace sxt::s25t { -class element; -} - namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // sum_polynomial_01 @@ -31,25 +30,109 @@ namespace sxt::prfsk { // f_a(X) = a[0] + a[1] * X + a[2] * X^2 + ... // compute the sum // f_a(0) + f_a(1) -void sum_polynomial_01(s25t::element& e, basct::cspan polynomial) noexcept; +template void sum_polynomial_01(T& e, basct::cspan polynomial) noexcept { + if (polynomial.empty()) { + e = T{}; + return; + } + e = polynomial[0]; + for (unsigned i = 0; i < polynomial.size(); ++i) { + add(e, e, polynomial[i]); + } +} //-------------------------------------------------------------------------------------------------- // evaluate_polynomial //-------------------------------------------------------------------------------------------------- -void evaluate_polynomial(s25t::element& e, basct::cspan polynomial, - const s25t::element& x) noexcept; +template +void evaluate_polynomial(T& e, basct::cspan polynomial, const T& x) noexcept { + if (polynomial.empty()) { + e = T{}; + return; + } + auto i = polynomial.size(); + --i; + e = polynomial[i]; + while (i > 0) { + --i; + muladd(e, e, x, polynomial[i]); + } +} //-------------------------------------------------------------------------------------------------- // expand_products //-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void expand_products(basct::span p, const s25t::element* mles, unsigned n, - unsigned step, basct::cspan terms) noexcept; +template +CUDA_CALLABLE void expand_products(basct::span p, const T* mles, unsigned n, unsigned step, + basct::cspan terms) noexcept { + auto num_terms = terms.size(); + assert( + // clang-format off + num_terms > 0 && + n > step && + p.size() == num_terms + 1u + // clang-format on + ); + T a, b; + auto mle_index = terms[0]; + a = *(mles + mle_index * n); + b = *(mles + mle_index * n + step); + sub(b, b, a); + p[0] = a; + p[1] = b; + + for (unsigned i = 1; i < num_terms; ++i) { + auto mle_index = terms[i]; + a = *(mles + mle_index * n); + b = *(mles + mle_index * n + step); + sub(b, b, a); + + auto c_prev = p[0]; + mul(p[0], c_prev, a); + for (unsigned pow = 1u; pow < i + 1u; ++pow) { + auto c = p[pow]; + mul(p[pow], c, a); + muladd(p[pow], c_prev, b, p[pow]); + c_prev = c; + } + mul(p[i + 1u], c_prev, b); + } +} //-------------------------------------------------------------------------------------------------- // partial_expand_products //-------------------------------------------------------------------------------------------------- -CUDA_CALLABLE -void partial_expand_products(basct::span p, const s25t::element* mles, unsigned n, - basct::cspan terms) noexcept; +template +CUDA_CALLABLE void partial_expand_products(basct::span p, const T* mles, unsigned n, + basct::cspan terms) noexcept { + auto num_terms = terms.size(); + assert( + // clang-format off + num_terms > 0 && + p.size() == num_terms + 1u + // clang-format on + ); + T a, b; + auto mle_index = terms[0]; + a = *(mles + mle_index * n); + neg(b, a); + p[0] = a; + p[1] = b; + + for (unsigned i = 1; i < num_terms; ++i) { + auto mle_index = terms[i]; + a = *(mles + mle_index * n); + neg(b, a); + + auto c_prev = p[0]; + mul(p[0], c_prev, a); + for (unsigned pow = 1u; pow < i + 1u; ++pow) { + auto c = p[pow]; + mul(p[pow], c, a); + muladd(p[pow], c_prev, b, p[pow]); + c_prev = c; + } + mul(p[i + 1u], c_prev, b); + } +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/polynomial_utility.t.cc b/sxt/proof/sumcheck/polynomial_utility.t.cc index 281e65e3a..821124e47 100644 --- a/sxt/proof/sumcheck/polynomial_utility.t.cc +++ b/sxt/proof/sumcheck/polynomial_utility.t.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,55 +20,58 @@ #include "sxt/base/test/unit_test.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; using namespace sxt::prfsk; using s25t::operator""_s25; +using T = s25t::element; + TEST_CASE("we perform basic operations on polynomials") { + s25t::element e; std::vector p; SECTION("we can compute the 0-1 sum of a zero polynomials") { - sum_polynomial_01(e, p); + sum_polynomial_01(e, p); REQUIRE(e == 0x0_s25); } SECTION("we can compute the 0-1 sum of a constant polynomial") { p = {0x123_s25}; - sum_polynomial_01(e, p); + sum_polynomial_01(e, p); REQUIRE(e == 0x246_s25); } SECTION("we can compute the 0-1 sum of a 1 degree polynomial") { p = {0x123_s25, 0x456_s25}; - sum_polynomial_01(e, p); + sum_polynomial_01(e, p); REQUIRE(e == 0x246_s25 + 0x456_s25); } SECTION("we can evaluate the zero polynomial") { - evaluate_polynomial(e, p, 0x123_s25); + evaluate_polynomial(e, p, 0x123_s25); REQUIRE(e == 0x0_s25); } SECTION("we can evaluate a constant polynomial") { p = {0x123_s25}; - evaluate_polynomial(e, p, 0x321_s25); + evaluate_polynomial(e, p, 0x321_s25); REQUIRE(e == 0x123_s25); } SECTION("we can evaluate a polynomial of degree 1") { p = {0x123_s25, 0x456_s25}; - evaluate_polynomial(e, p, 0x321_s25); + evaluate_polynomial(e, p, 0x321_s25); REQUIRE(e == 0x123_s25 + 0x456_s25 * 0x321_s25); } SECTION("we can evaluate a polynomial of degree 2") { p = {0x123_s25, 0x456_s25, 0x789_s25}; - evaluate_polynomial(e, p, 0x321_s25); + evaluate_polynomial(e, p, 0x321_s25); REQUIRE(e == 0x123_s25 + 0x456_s25 * 0x321_s25 + 0x789_s25 * 0x321_s25 * 0x321_s25); } } @@ -82,7 +85,7 @@ TEST_CASE("we can expand a product of MLEs") { p.resize(2); mles = {0x123_s25, 0x456_s25}; terms = {0}; - expand_products(p, mles.data(), 2, 1, terms); + expand_products(p, mles.data(), 2, 1, terms); REQUIRE(p[0] == mles[0]); REQUIRE(p[1] == mles[1] - mles[0]); } @@ -91,10 +94,10 @@ TEST_CASE("we can expand a product of MLEs") { mles = {0x123_s25, 0x0_s25}; p.resize(2); terms = {0}; - partial_expand_products(p, mles.data(), 1, terms); + partial_expand_products(p, mles.data(), 1, terms); std::vector expected(2); - expand_products(expected, mles.data(), 2, 1, terms); + expand_products(expected, mles.data(), 2, 1, terms); REQUIRE(p == expected); } @@ -102,7 +105,7 @@ TEST_CASE("we can expand a product of MLEs") { p.resize(3); mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25}; terms = {0, 1}; - expand_products(p, mles.data(), 2, 1, terms); + expand_products(p, mles.data(), 2, 1, terms); auto a1 = mles[0]; auto a2 = mles[1] - mles[0]; auto b1 = mles[2]; @@ -116,7 +119,7 @@ TEST_CASE("we can expand a product of MLEs") { p.resize(4); mles = {0x123_s25, 0x456_s25, 0x1122_s25, 0x4455_s25, 0x2233_s25, 0x5566_s25}; terms = {0, 1, 2}; - expand_products(p, mles.data(), 2, 1, terms); + expand_products(p, mles.data(), 2, 1, terms); auto a1 = mles[0]; auto a2 = mles[1] - mles[0]; auto b1 = mles[2]; diff --git a/sxt/proof/sumcheck/proof_computation.cc b/sxt/proof/sumcheck/proof_computation.cc index d77311d79..d999e8d39 100644 --- a/sxt/proof/sumcheck/proof_computation.cc +++ b/sxt/proof/sumcheck/proof_computation.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,57 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/proof_computation.h" - -#include "sxt/base/error/assert.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/proof/sumcheck/driver.h" -#include "sxt/proof/sumcheck/sumcheck_transcript.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// prove_sum -//-------------------------------------------------------------------------------------------------- -xena::future<> prove_sum(basct::span polynomials, - basct::span evaluation_point, - sumcheck_transcript& transcript, const driver& drv, - basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept { - SXT_RELEASE_ASSERT(0 < n); - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto polynomial_length = polynomials.size() / num_variables; - auto num_mles = mles.size() / n; - SXT_RELEASE_ASSERT( - // clang-format off - polynomial_length > 1 && - evaluation_point.size() == num_variables && - polynomials.size() == num_variables * polynomial_length && - mles.size() == n * num_mles - // clang-format on - ); - - transcript.init(num_variables, polynomial_length - 1); - - auto ws = co_await drv.make_workspace(mles, product_table, product_terms, n); - - for (unsigned round_index = 0; round_index < num_variables; ++round_index) { - auto polynomial = polynomials.subspan(round_index * polynomial_length, polynomial_length); - - // compute the round polynomial - co_await drv.sum(polynomial, *ws); - - // draw the next random challenge - s25t::element r; - transcript.round_challenge(r, polynomial); - evaluation_point[round_index] = r; - - // fold the polynomial - if (round_index < num_variables - 1u) { - co_await drv.fold(*ws, r); - } - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/proof_computation.h b/sxt/proof/sumcheck/proof_computation.h index 2569d7744..b3e1328d3 100644 --- a/sxt/proof/sumcheck/proof_computation.h +++ b/sxt/proof/sumcheck/proof_computation.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,29 +16,55 @@ */ #pragma once -#include - -#include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" - -namespace sxt::prft { -class transcript; -} -namespace sxt::s25t { -class element; -} +#include "sxt/base/error/assert.h" +#include "sxt/base/field/element.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/proof/sumcheck/driver.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" namespace sxt::prfsk { -class driver; -class sumcheck_transcript; - //-------------------------------------------------------------------------------------------------- // prove_sum //-------------------------------------------------------------------------------------------------- -xena::future<> prove_sum(basct::span polynomials, - basct::span evaluation_point, - sumcheck_transcript& transcript, const driver& drv, - basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept; +template +xena::future<> prove_sum(basct::span polynomials, basct::span evaluation_point, + sumcheck_transcript& transcript, const driver& drv, + basct::cspan mles, basct::cspan> product_table, + basct::cspan product_terms, unsigned n) noexcept { + SXT_RELEASE_ASSERT(0 < n); + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto polynomial_length = polynomials.size() / num_variables; + auto num_mles = mles.size() / n; + SXT_RELEASE_ASSERT( + // clang-format off + polynomial_length > 1 && + evaluation_point.size() == num_variables && + polynomials.size() == num_variables * polynomial_length && + mles.size() == n * num_mles + // clang-format on + ); + + transcript.init(num_variables, polynomial_length - 1); + + auto ws = co_await drv.make_workspace(mles, product_table, product_terms, n); + + for (unsigned round_index = 0; round_index < num_variables; ++round_index) { + auto polynomial = polynomials.subspan(round_index * polynomial_length, polynomial_length); + + // compute the round polynomial + co_await drv.sum(polynomial, *ws); + + // draw the next random challenge + T r; + transcript.round_challenge(r, polynomial); + evaluation_point[round_index] = r; + + // fold the polynomial + if (round_index < num_variables - 1u) { + co_await drv.fold(*ws, r); + } + } +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/proof_computation.t.cc b/sxt/proof/sumcheck/proof_computation.t.cc index b820e0a2b..11be3a42d 100644 --- a/sxt/proof/sumcheck/proof_computation.t.cc +++ b/sxt/proof/sumcheck/proof_computation.t.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,30 +35,32 @@ #include "sxt/proof/sumcheck/verification.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; using namespace sxt::prfsk; using s25t::operator""_s25; -static void test_proof(const driver& drv) noexcept { +using T = s25t::element; + +static void test_proof(const driver& drv) noexcept { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; + reference_transcript transcript{base_transcript}; std::vector polynomials(2); std::vector evaluation_point(1); std::vector mles = { 0x8_s25, 0x3_s25, }; - std::vector> product_table = { + std::vector> product_table = { {0x1_s25, 1}, }; std::vector product_terms = {0}; SECTION("we can prove a sum with n=1") { - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 1); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 1); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == mles[0]); @@ -66,8 +68,8 @@ static void test_proof(const driver& drv) noexcept { } SECTION("we can prove a sum with a single variable") { - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == mles[0]); @@ -80,8 +82,8 @@ static void test_proof(const driver& drv) noexcept { }; product_terms = {0, 0}; polynomials.resize(3); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == mles[0] * mles[0]); @@ -97,8 +99,8 @@ static void test_proof(const driver& drv) noexcept { polynomials.resize(3); mles.push_back(0x7_s25); mles.push_back(0x10_s25); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == mles[0] * mles[2]); @@ -108,8 +110,8 @@ static void test_proof(const driver& drv) noexcept { SECTION("we can prove a sum where the term multiplier is different from one") { product_table[0].first = 0x2_s25; - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 2); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 2); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == 0x2_s25 * mles[0]); @@ -121,8 +123,8 @@ static void test_proof(const driver& drv) noexcept { mles.push_back(0x7_s25); polynomials.resize(4); evaluation_point.resize(2); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 4); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 4); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == mles[0] + mles[1]); @@ -140,8 +142,8 @@ static void test_proof(const driver& drv) noexcept { mles.push_back(0x4_s25); polynomials.resize(4); evaluation_point.resize(2); - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, 3); + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, 3); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(polynomials[0] == mles[0] + mles[1]); @@ -175,32 +177,32 @@ static void test_proof(const driver& drv) noexcept { // prove { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, - product_terms, n); + reference_transcript transcript{base_transcript}; + auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table, + product_terms, n); xens::get_scheduler().run(); } // we can verify { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; + reference_transcript transcript{base_transcript}; s25t::element expected_sum; - sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); - auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - polynomials, polynomial_length - 1u); + sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); + auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + polynomials, polynomial_length - 1u); REQUIRE(valid); } // verification fails if we break the proof { prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; + reference_transcript transcript{base_transcript}; s25t::element expected_sum; - sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); + sum_polynomial_01(expected_sum, basct::subspan(polynomials, 0, polynomial_length)); polynomials[polynomials.size() - 1] = polynomials[0] + polynomials[1]; - auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - polynomials, polynomial_length - 1u); + auto valid = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + polynomials, polynomial_length - 1u); REQUIRE(!valid); } } @@ -209,24 +211,24 @@ static void test_proof(const driver& drv) noexcept { TEST_CASE("we can create a sumcheck proof") { SECTION("we can prove with the cpu driver") { - cpu_driver drv; + cpu_driver drv; test_proof(drv); } SECTION("we can prove with the gpu driver") { - gpu_driver drv; + gpu_driver drv; test_proof(drv); } SECTION("we can prove with the chunked gpu driver") { - chunked_gpu_driver drv{0.0}; + chunked_gpu_driver drv{0.0}; test_proof(drv); } SECTION("we can prove with a chunked driver that switches over to the single gpu driver") { std::vector mles(4); - auto fraction = get_gpu_memory_fraction(mles); - chunked_gpu_driver drv{fraction}; + auto fraction = get_gpu_memory_fraction(mles); + chunked_gpu_driver drv{fraction}; test_proof(drv); } } diff --git a/sxt/proof/sumcheck/reduction_gpu.cc b/sxt/proof/sumcheck/reduction_gpu.cc index e03c12d62..837b4c09d 100644 --- a/sxt/proof/sumcheck/reduction_gpu.cc +++ b/sxt/proof/sumcheck/reduction_gpu.cc @@ -15,100 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/reduction_gpu.h" - -#include "sxt/algorithm/base/identity_mapper.h" -#include "sxt/algorithm/reduction/kernel_fit.h" -#include "sxt/algorithm/reduction/thread_reduction.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/error/assert.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/execution/kernel/kernel_dims.h" -#include "sxt/execution/kernel/launch.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/scalar25/operation/accumulator.h" -#include "sxt/scalar25/operation/add.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// reduction_kernel -//-------------------------------------------------------------------------------------------------- -template -__global__ static void reduction_kernel(s25t::element* __restrict__ out, - const s25t::element* __restrict__ partials, - unsigned n) noexcept { - auto thread_index = threadIdx.x; - auto block_index = blockIdx.x; - auto coefficient_index = blockIdx.y; - auto index = block_index * (BlockSize * 2) + thread_index; - auto step = BlockSize * 2 * gridDim.x; - __shared__ s25t::element shared_data[2 * BlockSize]; - - // coefficient adjustment - out += coefficient_index; - partials += coefficient_index * n; - - // mapper - algb::identity_mapper mapper{partials}; - - // reduce - algr::thread_reduce(out + block_index, shared_data, mapper, n, step, - thread_index, index); -} - -//-------------------------------------------------------------------------------------------------- -// reduce_sums -//-------------------------------------------------------------------------------------------------- -xena::future<> reduce_sums(basct::span p, basdv::stream& stream, - basct::cspan partial_terms) noexcept { - auto num_coefficients = p.size(); - auto n = partial_terms.size() / num_coefficients; - SXT_DEBUG_ASSERT( - // clang-format off - n > 0 && - partial_terms.size() == num_coefficients * n && - basdv::is_host_pointer(p.data()) && - basdv::is_active_device_pointer(partial_terms.data()) - // clang-format on - ); - auto dims = algr::fit_reduction_kernel(n); - - // p_dev - memr::async_device_resource resource{stream}; - memmg::managed_array p_dev{num_coefficients * dims.num_blocks, &resource}; - - // launch kernel - xenk::launch_kernel(dims.block_size, [&]( - std::integral_constant) noexcept { - reduction_kernel - <<>>( - p_dev.data(), partial_terms.data(), n); - }); - - // copy polynomial to host - memmg::managed_array p_host_data; - basct::span p_host = p; - if (dims.num_blocks > 1) { - p_host_data.resize(p_dev.size()); - p_host = p_host_data; - } - basdv::async_copy_device_to_host(p_host, p_dev, stream); - co_await xendv::await_stream(stream); - - // complete reduction on host if necessary - if (dims.num_blocks == 1) { - co_return; - } - for (unsigned coefficient_index = 0; coefficient_index < num_coefficients; ++coefficient_index) { - p[coefficient_index] = p_host[coefficient_index * dims.num_blocks]; - for (unsigned block_index = 1; block_index < dims.num_blocks; ++block_index) { - s25o::add(p[coefficient_index], p[coefficient_index], - p_host[coefficient_index * dims.num_blocks + block_index]); - } - } -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reduction_gpu.h b/sxt/proof/sumcheck/reduction_gpu.h index eb89835c5..ddb3995cf 100644 --- a/sxt/proof/sumcheck/reduction_gpu.h +++ b/sxt/proof/sumcheck/reduction_gpu.h @@ -16,20 +16,98 @@ */ #pragma once -#include "sxt/base/container/span.h" -#include "sxt/execution/async/future_fwd.h" +#include "sxt/algorithm/base/identity_mapper.h" +#include "sxt/algorithm/reduction/kernel_fit.h" +#include "sxt/algorithm/reduction/thread_reduction.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/error/assert.h" +#include "sxt/base/field/accumulator.h" +#include "sxt/base/field/element.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/device/synchronization.h" +#include "sxt/execution/kernel/kernel_dims.h" +#include "sxt/execution/kernel/launch.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/async_device_resource.h" -namespace sxt::basdv { -class stream; -} -namespace sxt::s25t { -class element; +namespace sxt::prfsk { +//-------------------------------------------------------------------------------------------------- +// reduction_kernel +//-------------------------------------------------------------------------------------------------- +template +__global__ static void reduction_kernel(T* __restrict__ out, const T* __restrict__ partials, + unsigned n) noexcept { + auto thread_index = threadIdx.x; + auto block_index = blockIdx.x; + auto coefficient_index = blockIdx.y; + auto index = block_index * (BlockSize * 2) + thread_index; + auto step = BlockSize * 2 * gridDim.x; + __shared__ T shared_data[2 * BlockSize]; + + // coefficient adjustment + out += coefficient_index; + partials += coefficient_index * n; + + // mapper + algb::identity_mapper mapper{partials}; + + // reduce + algr::thread_reduce, BlockSize>(out + block_index, shared_data, mapper, n, + step, thread_index, index); } -namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // reduce_sums //-------------------------------------------------------------------------------------------------- -xena::future<> reduce_sums(basct::span p, basdv::stream& stream, - basct::cspan partial_terms) noexcept; +template +xena::future<> reduce_sums(basct::span p, basdv::stream& stream, + basct::cspan partial_terms) noexcept { + auto num_coefficients = p.size(); + auto n = partial_terms.size() / num_coefficients; + SXT_DEBUG_ASSERT( + // clang-format off + n > 0 && + partial_terms.size() == num_coefficients * n && + basdv::is_host_pointer(p.data()) && + basdv::is_active_device_pointer(partial_terms.data()) + // clang-format on + ); + auto dims = algr::fit_reduction_kernel(n); + + // p_dev + memr::async_device_resource resource{stream}; + memmg::managed_array p_dev{num_coefficients * dims.num_blocks, &resource}; + + // launch kernel + xenk::launch_kernel(dims.block_size, [&]( + std::integral_constant) noexcept { + reduction_kernel + <<>>( + p_dev.data(), partial_terms.data(), n); + }); + + // copy polynomial to host + memmg::managed_array p_host_data; + basct::span p_host = p; + if (dims.num_blocks > 1) { + p_host_data.resize(p_dev.size()); + p_host = p_host_data; + } + basdv::async_copy_device_to_host(p_host, p_dev, stream); + co_await xendv::await_stream(stream); + + // complete reduction on host if necessary + if (dims.num_blocks == 1) { + co_return; + } + for (unsigned coefficient_index = 0; coefficient_index < num_coefficients; ++coefficient_index) { + p[coefficient_index] = p_host[coefficient_index * dims.num_blocks]; + for (unsigned block_index = 1; block_index < dims.num_blocks; ++block_index) { + add(p[coefficient_index], p[coefficient_index], + p_host[coefficient_index * dims.num_blocks + block_index]); + } + } +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reduction_gpu.t.cc b/sxt/proof/sumcheck/reduction_gpu.t.cc index 218d8f1b6..ce4a8e73e 100644 --- a/sxt/proof/sumcheck/reduction_gpu.t.cc +++ b/sxt/proof/sumcheck/reduction_gpu.t.cc @@ -24,7 +24,7 @@ #include "sxt/execution/schedule/scheduler.h" #include "sxt/memory/resource/managed_device_resource.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -32,6 +32,7 @@ using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can reduce sumcheck polynomials") { + using T = s25t::element; std::vector p; std::pmr::vector partial_terms{memr::get_managed_device_resource()}; @@ -40,7 +41,7 @@ TEST_CASE("we can reduce sumcheck polynomials") { SECTION("we can reduce a sum with a single term") { p.resize(1); partial_terms = {0x123_s25}; - auto fut = reduce_sums(p, stream, partial_terms); + auto fut = reduce_sums(p, stream, partial_terms); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == 0x123_s25); @@ -49,7 +50,7 @@ TEST_CASE("we can reduce sumcheck polynomials") { SECTION("we can reduce two terms") { p.resize(1); partial_terms = {0x123_s25, 0x456_s25}; - auto fut = reduce_sums(p, stream, partial_terms); + auto fut = reduce_sums(p, stream, partial_terms); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == 0x123_s25 + 0x456_s25); @@ -58,7 +59,7 @@ TEST_CASE("we can reduce sumcheck polynomials") { SECTION("we can reduce multiple coefficients") { p.resize(2); partial_terms = {0x123_s25, 0x456_s25}; - auto fut = reduce_sums(p, stream, partial_terms); + auto fut = reduce_sums(p, stream, partial_terms); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == 0x123_s25); diff --git a/sxt/proof/sumcheck/reference_transcript.cc b/sxt/proof/sumcheck/reference_transcript.cc index fcbeb6b5b..1d1616dc0 100644 --- a/sxt/proof/sumcheck/reference_transcript.cc +++ b/sxt/proof/sumcheck/reference_transcript.cc @@ -15,32 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/reference_transcript.h" - -#include "sxt/proof/transcript/transcript_utility.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// constructor -//-------------------------------------------------------------------------------------------------- -reference_transcript::reference_transcript(prft::transcript& transcript) noexcept - : transcript_{transcript} {} - -//-------------------------------------------------------------------------------------------------- -// init -//-------------------------------------------------------------------------------------------------- -void reference_transcript::init(size_t num_variables, size_t round_degree) noexcept { - prft::set_domain(transcript_, "sumcheck proof v1"); - prft::append_value(transcript_, "n", num_variables); - prft::append_value(transcript_, "k", round_degree); -} - -//-------------------------------------------------------------------------------------------------- -// round_challenge -//-------------------------------------------------------------------------------------------------- -void reference_transcript::round_challenge(s25t::element& r, - basct::cspan polynomial) noexcept { - prft::append_values(transcript_, "P", polynomial); - prft::challenge_value(r, transcript_, "R"); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/reference_transcript.h b/sxt/proof/sumcheck/reference_transcript.h index 3ee41ed0b..265ae1a34 100644 --- a/sxt/proof/sumcheck/reference_transcript.h +++ b/sxt/proof/sumcheck/reference_transcript.h @@ -18,18 +18,26 @@ #include "sxt/proof/sumcheck/sumcheck_transcript.h" #include "sxt/proof/transcript/transcript.h" +#include "sxt/proof/transcript/transcript_utility.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // reference_transcript //-------------------------------------------------------------------------------------------------- -class reference_transcript final : public sumcheck_transcript { +template class reference_transcript final : public sumcheck_transcript { public: - explicit reference_transcript(prft::transcript& transcript) noexcept; + explicit reference_transcript(prft::transcript& transcript) noexcept : transcript_{transcript} {} - void init(size_t num_variables, size_t round_degree) noexcept override; + void init(size_t num_variables, size_t round_degree) noexcept { + prft::set_domain(transcript_, "sumcheck proof v1"); + prft::append_value(transcript_, "n", num_variables); + prft::append_value(transcript_, "k", round_degree); + } - void round_challenge(s25t::element& r, basct::cspan polynomial) noexcept override; + void round_challenge(T& r, basct::cspan polynomial) noexcept { + prft::append_values(transcript_, "P", polynomial); + prft::challenge_value(r, transcript_, "R"); + } private: prft::transcript& transcript_; diff --git a/sxt/proof/sumcheck/reference_transcript.t.cc b/sxt/proof/sumcheck/reference_transcript.t.cc index 890312c51..373fea8c5 100644 --- a/sxt/proof/sumcheck/reference_transcript.t.cc +++ b/sxt/proof/sumcheck/reference_transcript.t.cc @@ -16,20 +16,21 @@ */ #include "sxt/proof/sumcheck/reference_transcript.h" -#include - #include "sxt/base/test/unit_test.h" #include "sxt/proof/transcript/transcript.h" +#include "sxt/scalar25/operation/overload.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; using namespace sxt::prfsk; -using s25t::operator""_s25; +using sxt::s25t::operator""_s25; TEST_CASE("we provide an implementation of sumcheck transcript") { + using T = s25t::element; prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - std::vector p = {0x123_s25}; + reference_transcript transcript{base_transcript}; + std::vector p = {0x123_s25}; s25t::element r, rp; SECTION("we don't draw the same challenge from a transcript") { @@ -38,7 +39,7 @@ TEST_CASE("we provide an implementation of sumcheck transcript") { REQUIRE(r != rp); prft::transcript base_transcript_p{"abc"}; - reference_transcript transcript_p{base_transcript_p}; + reference_transcript transcript_p{base_transcript_p}; p[0] = 0x456_s25; transcript_p.round_challenge(rp, p); REQUIRE(r != rp); @@ -49,7 +50,7 @@ TEST_CASE("we provide an implementation of sumcheck transcript") { transcript.round_challenge(r, p); prft::transcript base_transcript_p{"abc"}; - reference_transcript transcript_p{base_transcript_p}; + reference_transcript transcript_p{base_transcript_p}; transcript.init(2, 1); transcript.round_challenge(rp, p); diff --git a/sxt/proof/sumcheck/sum_gpu.cc b/sxt/proof/sumcheck/sum_gpu.cc index 8066fd249..e177b6f83 100644 --- a/sxt/proof/sumcheck/sum_gpu.cc +++ b/sxt/proof/sumcheck/sum_gpu.cc @@ -15,217 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/sum_gpu.h" - -#include - -#include "sxt/algorithm/reduction/kernel_fit.h" -#include "sxt/algorithm/reduction/thread_reduction.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/device/state.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/iterator/split.h" -#include "sxt/base/num/ceil_log2.h" -#include "sxt/base/num/constexpr_switch.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/async/future.h" -#include "sxt/execution/device/for_each.h" -#include "sxt/execution/kernel/kernel_dims.h" -#include "sxt/execution/kernel/launch.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" -#include "sxt/memory/resource/device_resource.h" -#include "sxt/proof/sumcheck/constant.h" -#include "sxt/proof/sumcheck/device_cache.h" -#include "sxt/proof/sumcheck/mle_utility.h" -#include "sxt/proof/sumcheck/polynomial_mapper.h" -#include "sxt/proof/sumcheck/reduction_gpu.h" -#include "sxt/scalar25/operation/add.h" -#include "sxt/scalar25/operation/mul.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// polynomial_reducer -//-------------------------------------------------------------------------------------------------- -namespace { -template struct polynomial_reducer { - using value_type = std::array; - - CUDA_CALLABLE static void accumulate_inplace(value_type& res, const value_type& e) noexcept { - for (unsigned i = 0; i < res.size(); ++i) { - s25o::add(res[i], res[i], e[i]); - } - } -}; -} // namespace - -//-------------------------------------------------------------------------------------------------- -// partial_sum_kernel_impl -//-------------------------------------------------------------------------------------------------- -template -__device__ static void partial_sum_kernel_impl(s25t::element* __restrict__ shared_data, - const s25t::element* __restrict__ mles, - const unsigned* __restrict__ product_terms, - unsigned split, unsigned n) noexcept { - using Mapper = polynomial_mapper; - using Reducer = polynomial_reducer; - using T = Mapper::value_type; - Mapper mapper{ - .mles = mles, - .product_terms = product_terms, - .split = split, - .n = n, - }; - auto index = blockIdx.x * (BlockSize * 2) + threadIdx.x; - auto step = BlockSize * 2 * gridDim.x; - algr::thread_reduce(reinterpret_cast(shared_data), mapper, split, step, - threadIdx.x, index); -} - -//-------------------------------------------------------------------------------------------------- -// partial_sum_kernel -//-------------------------------------------------------------------------------------------------- -template -__global__ static void -partial_sum_kernel(s25t::element* __restrict__ out, const s25t::element* __restrict__ mles, - const std::pair* __restrict__ product_table, - const unsigned* __restrict__ product_terms, unsigned num_coefficients, - unsigned split, unsigned n) noexcept { - auto product_index = blockIdx.y; - auto num_terms = product_table[product_index].second; - auto thread_index = threadIdx.x; - auto output_step = gridDim.x * gridDim.y; - - // shared data for reduction - __shared__ s25t::element shared_data[2 * BlockSize * (max_degree_v + 1u)]; - - // adjust pointers - out += blockIdx.x; - out += gridDim.x * product_index; - for (unsigned i = 0; i < product_index; ++i) { - product_terms += product_table[i].second; - } - - // sum - basn::constexpr_switch<1, max_degree_v + 1u>( - num_terms, [&](std::integral_constant) noexcept { - partial_sum_kernel_impl(shared_data, mles, product_terms, split, n); - }); - - // write out result - auto mult = product_table[product_index].first; - for (unsigned i = thread_index; i < num_coefficients; i += BlockSize) { - auto output_index = output_step * i; - if (i < num_terms + 1u) { - s25o::mul(out[output_index], mult, shared_data[i]); - } else { - out[output_index] = s25t::element{}; - } - } -} - -//-------------------------------------------------------------------------------------------------- -// partial_sum -//-------------------------------------------------------------------------------------------------- -static xena::future<> partial_sum(basct::span p, basdv::stream& stream, - basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned split, - unsigned n) noexcept { - auto num_coefficients = p.size(); - auto num_products = product_table.size(); - auto dims = algr::fit_reduction_kernel(split); - memr::async_device_resource resource{stream}; - - // partials - memmg::managed_array partials{num_coefficients * dims.num_blocks * num_products, - &resource}; - xenk::launch_kernel(dims.block_size, [&]( - std::integral_constant) noexcept { - partial_sum_kernel<<>>( - partials.data(), mles.data(), product_table.data(), product_terms.data(), num_coefficients, - split, n); - }); - - // reduce partials - co_await reduce_sums(p, stream, partials); -} - -//-------------------------------------------------------------------------------------------------- -// sum_gpu -//-------------------------------------------------------------------------------------------------- -xena::future<> sum_gpu(basct::span p, device_cache& cache, - const basit::split_options& options, basct::cspan mles, - unsigned n) noexcept { - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - auto num_mles = mles.size() / n; - auto num_coefficients = p.size(); - - // split - auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, options); - - // sum - size_t counter = 0; - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { - basdv::stream stream; - memr::async_device_resource resource{stream}; - - // copy partial mles to device - memmg::managed_array partial_mles{&resource}; - copy_partial_mles(partial_mles, stream, mles, n, rng.a(), rng.b()); - auto split = rng.b() - rng.a(); - auto np = partial_mles.size() / num_mles; - - // lookup problem descriptor - basct::cspan> product_table; - basct::cspan product_terms; - cache.lookup(product_table, product_terms, stream); - - // compute - memmg::managed_array partial_p(num_coefficients); - co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, split, - np); - - // fill in the result - if (counter == 0) { - for (unsigned i = 0; i < num_coefficients; ++i) { - p[i] = partial_p[i]; - } - } else { - for (unsigned i = 0; i < num_coefficients; ++i) { - s25o::add(p[i], p[i], partial_p[i]); - } - } - ++counter; - }); -} - -xena::future<> sum_gpu(basct::span p, device_cache& cache, - basct::cspan mles, unsigned n) noexcept { - basit::split_options options{ - .min_chunk_size = 100'000u, - .max_chunk_size = 200'000u, - .split_factor = basdv::get_num_devices(), - }; - co_await sum_gpu(p, cache, options, mles, n); -} - -xena::future<> sum_gpu(basct::span p, basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept { - auto num_variables = std::max(basn::ceil_log2(n), 1); - auto mid = 1u << (num_variables - 1u); - SXT_DEBUG_ASSERT( - // clang-format off - basdv::is_host_pointer(p.data()) && - basdv::is_active_device_pointer(mles.data()) && - basdv::is_active_device_pointer(product_table.data()) && - basdv::is_active_device_pointer(product_terms.data()) - // clang-format on - ); - basdv::stream stream; - co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sum_gpu.h b/sxt/proof/sumcheck/sum_gpu.h index a693c7da3..8f8cfb82e 100644 --- a/sxt/proof/sumcheck/sum_gpu.h +++ b/sxt/proof/sumcheck/sum_gpu.h @@ -16,23 +16,33 @@ */ #pragma once -#include +#include -#include "sxt/base/container/span.h" -#include "sxt/base/device/property.h" -#include "sxt/execution/async/future_fwd.h" - -namespace sxt::s25t { -class element; -} - -namespace sxt::basit { -struct split_options; -} +#include "sxt/algorithm/reduction/kernel_fit.h" +#include "sxt/algorithm/reduction/thread_reduction.h" +#include "sxt/base/device/memory_utility.h" +#include "sxt/base/device/state.h" +#include "sxt/base/device/stream.h" +#include "sxt/base/field/element.h" +#include "sxt/base/iterator/split.h" +#include "sxt/base/num/ceil_log2.h" +#include "sxt/base/num/constexpr_switch.h" +#include "sxt/execution/async/coroutine.h" +#include "sxt/execution/async/future.h" +#include "sxt/execution/device/for_each.h" +#include "sxt/execution/kernel/kernel_dims.h" +#include "sxt/execution/kernel/launch.h" +#include "sxt/memory/management/managed_array.h" +#include "sxt/memory/resource/async_device_resource.h" +#include "sxt/memory/resource/device_resource.h" +#include "sxt/proof/sumcheck/constant.h" +#include "sxt/proof/sumcheck/device_cache.h" +#include "sxt/proof/sumcheck/mle_utility.h" +#include "sxt/proof/sumcheck/polynomial_mapper.h" +#include "sxt/proof/sumcheck/polynomial_reducer.h" +#include "sxt/proof/sumcheck/reduction_gpu.h" namespace sxt::prfsk { -class device_cache; - //-------------------------------------------------------------------------------------------------- // sum_options //-------------------------------------------------------------------------------------------------- @@ -42,17 +52,175 @@ struct sum_options { unsigned split_factor = unsigned(basdv::get_num_devices()); }; +//-------------------------------------------------------------------------------------------------- +// partial_sum_kernel_impl +//-------------------------------------------------------------------------------------------------- +template +__device__ static void partial_sum_kernel_impl(T* __restrict__ shared_data, + const T* __restrict__ mles, + const unsigned* __restrict__ product_terms, + unsigned split, unsigned n) noexcept { + using Mapper = polynomial_mapper; + using Reducer = polynomial_reducer; + using U = Mapper::value_type; + Mapper mapper{ + .mles = mles, + .product_terms = product_terms, + .split = split, + .n = n, + }; + auto index = blockIdx.x * (BlockSize * 2) + threadIdx.x; + auto step = BlockSize * 2 * gridDim.x; + algr::thread_reduce(reinterpret_cast(shared_data), mapper, split, step, + threadIdx.x, index); +} + +//-------------------------------------------------------------------------------------------------- +// partial_sum_kernel +//-------------------------------------------------------------------------------------------------- +template +__global__ static void partial_sum_kernel(T* __restrict__ out, const T* __restrict__ mles, + const std::pair* __restrict__ product_table, + const unsigned* __restrict__ product_terms, + unsigned num_coefficients, unsigned split, + unsigned n) noexcept { + auto product_index = blockIdx.y; + auto num_terms = product_table[product_index].second; + auto thread_index = threadIdx.x; + auto output_step = gridDim.x * gridDim.y; + + // shared data for reduction + __shared__ T shared_data[2 * BlockSize * (max_degree_v + 1u)]; + + // adjust pointers + out += blockIdx.x; + out += gridDim.x * product_index; + for (unsigned i = 0; i < product_index; ++i) { + product_terms += product_table[i].second; + } + + // sum + basn::constexpr_switch<1, max_degree_v + 1u>( + num_terms, [&](std::integral_constant) noexcept { + partial_sum_kernel_impl(shared_data, mles, product_terms, split, n); + }); + + // write out result + auto mult = product_table[product_index].first; + for (unsigned i = thread_index; i < num_coefficients; i += BlockSize) { + auto output_index = output_step * i; + if (i < num_terms + 1u) { + mul(out[output_index], mult, shared_data[i]); + } else { + out[output_index] = T{}; + } + } +} + +//-------------------------------------------------------------------------------------------------- +// partial_sum +//-------------------------------------------------------------------------------------------------- +template +static xena::future<> partial_sum(basct::span p, basdv::stream& stream, basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned split, + unsigned n) noexcept { + auto num_coefficients = p.size(); + auto num_products = product_table.size(); + auto dims = algr::fit_reduction_kernel(split); + memr::async_device_resource resource{stream}; + + // partials + memmg::managed_array partials{num_coefficients * dims.num_blocks * num_products, &resource}; + xenk::launch_kernel(dims.block_size, [&]( + std::integral_constant) noexcept { + partial_sum_kernel<<>>( + partials.data(), mles.data(), product_table.data(), product_terms.data(), num_coefficients, + split, n); + }); + + // reduce partials + co_await reduce_sums(p, stream, partials); +} + //-------------------------------------------------------------------------------------------------- // sum_gpu //-------------------------------------------------------------------------------------------------- -xena::future<> sum_gpu(basct::span p, device_cache& cache, - const basit::split_options& options, basct::cspan mles, - unsigned n) noexcept; +template +xena::future<> sum_gpu(basct::span p, device_cache& cache, + const basit::split_options& options, basct::cspan mles, + unsigned n) noexcept { + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + auto num_mles = mles.size() / n; + auto num_coefficients = p.size(); + + // split + auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, mid}, options); -xena::future<> sum_gpu(basct::span p, device_cache& cache, - basct::cspan mles, unsigned n) noexcept; + // sum + size_t counter = 0; + co_await xendv::concurrent_for_each( + chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> { + basdv::stream stream; + memr::async_device_resource resource{stream}; -xena::future<> sum_gpu(basct::span p, basct::cspan mles, - basct::cspan> product_table, - basct::cspan product_terms, unsigned n) noexcept; + // copy partial mles to device + memmg::managed_array partial_mles{&resource}; + copy_partial_mles(partial_mles, stream, mles, n, rng.a(), rng.b()); + auto split = rng.b() - rng.a(); + auto np = partial_mles.size() / num_mles; + + // lookup problem descriptor + basct::cspan> product_table; + basct::cspan product_terms; + cache.lookup(product_table, product_terms, stream); + + // compute + memmg::managed_array partial_p(num_coefficients); + co_await partial_sum(partial_p, stream, partial_mles, product_table, product_terms, + split, np); + + // fill in the result + if (counter == 0) { + for (unsigned i = 0; i < num_coefficients; ++i) { + p[i] = partial_p[i]; + } + } else { + for (unsigned i = 0; i < num_coefficients; ++i) { + add(p[i], p[i], partial_p[i]); + } + } + ++counter; + }); +} + +template +xena::future<> sum_gpu(basct::span p, device_cache& cache, basct::cspan mles, + unsigned n) noexcept { + basit::split_options options{ + .min_chunk_size = 100'000u, + .max_chunk_size = 200'000u, + .split_factor = basdv::get_num_devices(), + }; + co_await sum_gpu(p, cache, options, mles, n); +} + +template +xena::future<> sum_gpu(basct::span p, basct::cspan mles, + basct::cspan> product_table, + basct::cspan product_terms, unsigned n) noexcept { + auto num_variables = std::max(basn::ceil_log2(n), 1); + auto mid = 1u << (num_variables - 1u); + SXT_DEBUG_ASSERT( + // clang-format off + basdv::is_host_pointer(p.data()) && + basdv::is_active_device_pointer(mles.data()) && + basdv::is_active_device_pointer(product_table.data()) && + basdv::is_active_device_pointer(product_terms.data()) + // clang-format on + ); + basdv::stream stream; + co_await partial_sum(p, stream, mles, product_table, product_terms, mid, n); +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/sum_gpu.t.cc b/sxt/proof/sumcheck/sum_gpu.t.cc index 35807fd55..5538698ca 100644 --- a/sxt/proof/sumcheck/sum_gpu.t.cc +++ b/sxt/proof/sumcheck/sum_gpu.t.cc @@ -24,7 +24,7 @@ #include "sxt/execution/schedule/scheduler.h" #include "sxt/proof/sumcheck/device_cache.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -32,17 +32,19 @@ using namespace sxt::prfsk; using s25t::operator""_s25; TEST_CASE("we can sum MLEs") { - std::vector> product_table; + using T = s25t::element; + + std::vector> product_table; std::vector product_terms; - std::vector mles; - std::vector p(2); + std::vector mles; + std::vector p(2); SECTION("we can sum an MLE with a single term and n=1") { product_table = {{0x1_s25, 1}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x123_s25}; - auto fut = sum_gpu(p, cache, mles, 1); + auto fut = sum_gpu(p, cache, mles, 1); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == mles[0]); @@ -52,9 +54,9 @@ TEST_CASE("we can sum MLEs") { SECTION("we can sum an MLE with a single term, n=1, and a non-unity multiplier") { product_table = {{0x2_s25, 1}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x123_s25}; - auto fut = sum_gpu(p, cache, mles, 1); + auto fut = sum_gpu(p, cache, mles, 1); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == product_table[0].first * mles[0]); @@ -64,9 +66,9 @@ TEST_CASE("we can sum MLEs") { SECTION("we can sum an MLE with a single term and n=2") { product_table = {{0x1_s25, 1}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x123_s25, 0x456_s25}; - auto fut = sum_gpu(p, cache, mles, 2); + auto fut = sum_gpu(p, cache, mles, 2); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == mles[0]); @@ -77,9 +79,9 @@ TEST_CASE("we can sum MLEs") { p.resize(3); product_table = {{0x1_s25, 2}}; product_terms = {0, 1}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x123_s25, 0x456_s25}; - auto fut = sum_gpu(p, cache, mles, 1); + auto fut = sum_gpu(p, cache, mles, 1); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == mles[0] * mles[1]); @@ -93,9 +95,9 @@ TEST_CASE("we can sum MLEs") { {0x1_s25, 1}, }; product_terms = {0, 1}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x123_s25, 0x456_s25}; - auto fut = sum_gpu(p, cache, mles, 1); + auto fut = sum_gpu(p, cache, mles, 1); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == mles[0] + mles[1]); @@ -105,14 +107,14 @@ TEST_CASE("we can sum MLEs") { SECTION("we can chunk sums with n=4") { product_table = {{0x1_s25, 1}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x123_s25, 0x456_s25, 0x789_s25, 0x91011_s25}; basit::split_options options{ .min_chunk_size = 1, .max_chunk_size = 1, .split_factor = 2, }; - auto fut = sum_gpu(p, cache, options, mles, 4); + auto fut = sum_gpu(p, cache, options, mles, 4); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == mles[0] + mles[1]); @@ -122,14 +124,14 @@ TEST_CASE("we can sum MLEs") { SECTION("we can chunk sums with n=4") { product_table = {{0x1_s25, 1}}; product_terms = {0}; - device_cache cache{product_table, product_terms}; + device_cache cache{product_table, product_terms}; mles = {0x2_s25, 0x4_s25, 0x7_s25, 0x9_s25}; basit::split_options options{ .min_chunk_size = 16, .max_chunk_size = 16, .split_factor = 2, }; - auto fut = sum_gpu(p, cache, options, mles, 4); + auto fut = sum_gpu(p, cache, options, mles, 4); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(p[0] == mles[0] + mles[1]); diff --git a/sxt/proof/sumcheck/sumcheck_transcript.h b/sxt/proof/sumcheck/sumcheck_transcript.h index 2d690ded2..db871f62a 100644 --- a/sxt/proof/sumcheck/sumcheck_transcript.h +++ b/sxt/proof/sumcheck/sumcheck_transcript.h @@ -16,25 +16,19 @@ */ #pragma once -#include - #include "sxt/base/container/span.h" - -namespace sxt::s25t { -class element; -} +#include "sxt/base/field/element.h" namespace sxt::prfsk { //-------------------------------------------------------------------------------------------------- // sumcheck_transcript //-------------------------------------------------------------------------------------------------- -class sumcheck_transcript { +template class sumcheck_transcript { public: virtual ~sumcheck_transcript() noexcept = default; virtual void init(size_t num_variables, size_t round_degree) noexcept = 0; - virtual void round_challenge(s25t::element& r, - basct::cspan polynomial) noexcept = 0; + virtual void round_challenge(T& r, basct::cspan polynomial) noexcept = 0; }; } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.cc b/sxt/proof/sumcheck/verification.cc index 96e0eb66f..16233df78 100644 --- a/sxt/proof/sumcheck/verification.cc +++ b/sxt/proof/sumcheck/verification.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,63 +15,3 @@ * limitations under the License. */ #include "sxt/proof/sumcheck/verification.h" - -#include "sxt/base/error/assert.h" -#include "sxt/base/log/log.h" -#include "sxt/proof/sumcheck/polynomial_utility.h" -#include "sxt/proof/sumcheck/sumcheck_transcript.h" -#include "sxt/scalar25/type/element.h" - -namespace sxt::prfsk { -//-------------------------------------------------------------------------------------------------- -// verify_sumcheck_no_evaluation -//-------------------------------------------------------------------------------------------------- -bool verify_sumcheck_no_evaluation(s25t::element& expected_sum, - basct::span evaluation_point, - sumcheck_transcript& transcript, - basct::cspan round_polynomials, - unsigned round_degree) noexcept { - auto num_variables = evaluation_point.size(); - SXT_RELEASE_ASSERT( - // clang-format off - num_variables > 0 && round_degree > 0 - // clang-format on - ); - - basl::info("verifying sumcheck of {} variables and round degree {}", num_variables, round_degree); - - // check dimensions - if (auto expected_count = (round_degree + 1u) * num_variables; - round_polynomials.size() != expected_count) { - basl::info("sumcheck verification failed: expected {} scalars for round_polynomials but got {}", - expected_count, round_polynomials.size()); - return false; - } - - transcript.init(num_variables, round_degree); - - // go through sumcheck rounds - for (unsigned round_index = 0; round_index < num_variables; ++round_index) { - auto polynomial = - round_polynomials.subspan((round_degree + 1u) * round_index, round_degree + 1u); - - // check sum - s25t::element sum; - sum_polynomial_01(sum, polynomial); - if (expected_sum != sum) { - basl::info("sumcheck verification failed on round {}", round_index + 1); - return false; - } - - // draw a random scalar - s25t::element r; - transcript.round_challenge(r, polynomial); - evaluation_point[round_index] = r; - - // evaluate at random point - evaluate_polynomial(expected_sum, polynomial, r); - } - - return true; -} -} // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.h b/sxt/proof/sumcheck/verification.h index 3ceea1819..127725aaa 100644 --- a/sxt/proof/sumcheck/verification.h +++ b/sxt/proof/sumcheck/verification.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,23 +17,61 @@ #pragma once #include "sxt/base/container/span.h" - -namespace sxt::prft { -class transcript; -} -namespace sxt::s25t { -class element; -} +#include "sxt/base/error/assert.h" +#include "sxt/base/log/log.h" +#include "sxt/proof/sumcheck/polynomial_utility.h" +#include "sxt/proof/sumcheck/sumcheck_transcript.h" namespace sxt::prfsk { -class sumcheck_transcript; - //-------------------------------------------------------------------------------------------------- // verify_sumcheck_no_evaluation //-------------------------------------------------------------------------------------------------- -bool verify_sumcheck_no_evaluation(s25t::element& expected_sum, - basct::span evaluation_point, - sumcheck_transcript& transcript, - basct::cspan round_polynomials, - unsigned round_degree) noexcept; +template +bool verify_sumcheck_no_evaluation(T& expected_sum, basct::span evaluation_point, + sumcheck_transcript& transcript, + basct::cspan round_polynomials, + unsigned round_degree) noexcept { + auto num_variables = evaluation_point.size(); + SXT_RELEASE_ASSERT( + // clang-format off + num_variables > 0 && round_degree > 0 + // clang-format on + ); + + basl::info("verifying sumcheck of {} variables and round degree {}", num_variables, round_degree); + + // check dimensions + if (auto expected_count = (round_degree + 1u) * num_variables; + round_polynomials.size() != expected_count) { + basl::info("sumcheck verification failed: expected {} scalars for round_polynomials but got {}", + expected_count, round_polynomials.size()); + return false; + } + + transcript.init(num_variables, round_degree); + + // go through sumcheck rounds + for (unsigned round_index = 0; round_index < num_variables; ++round_index) { + auto polynomial = + round_polynomials.subspan((round_degree + 1u) * round_index, round_degree + 1u); + + // check sum + T sum; + sum_polynomial_01(sum, polynomial); + if (expected_sum != sum) { + basl::info("sumcheck verification failed on round {}", round_index + 1); + return false; + } + + // draw a random scalar + T r; + transcript.round_challenge(r, polynomial); + evaluation_point[round_index] = r; + + // evaluate at random point + evaluate_polynomial(expected_sum, polynomial, r); + } + + return true; +} } // namespace sxt::prfsk diff --git a/sxt/proof/sumcheck/verification.t.cc b/sxt/proof/sumcheck/verification.t.cc index b2c2ccfe5..3c1bc60ef 100644 --- a/sxt/proof/sumcheck/verification.t.cc +++ b/sxt/proof/sumcheck/verification.t.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ #include "sxt/proof/sumcheck/reference_transcript.h" #include "sxt/proof/transcript/transcript.h" #include "sxt/scalar25/operation/overload.h" -#include "sxt/scalar25/type/element.h" +#include "sxt/scalar25/realization/field.h" #include "sxt/scalar25/type/literal.h" using namespace sxt; @@ -30,29 +30,30 @@ using namespace sxt::prfsk; using sxt::s25t::operator""_s25; TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { + using T = s25t::element; s25t::element expected_sum = 0x0_s25; - std::vector evaluation_point = {0x0_s25}; + std::vector evaluation_point = {0x0_s25}; prft::transcript base_transcript{"abc"}; - reference_transcript transcript{base_transcript}; - std::vector round_polynomials = {0x0_s25, 0x0_s25}; + reference_transcript transcript{base_transcript}; + std::vector round_polynomials = {0x0_s25, 0x0_s25}; SECTION("verification fails if dimensions don't match") { - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 2); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 2); REQUIRE(!res); } SECTION("we can verify a single round") { - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); REQUIRE(res); REQUIRE(evaluation_point[0] != 0x0_s25); } SECTION("verification fails if the round polynomial doesn't match the sum") { round_polynomials[1] = 0x1_s25; - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); REQUIRE(!res); } @@ -69,9 +70,9 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { s25t::element r; { prft::transcript base_transcript_p{"abc"}; - reference_transcript transcript_p{base_transcript_p}; + reference_transcript transcript_p{base_transcript_p}; transcript_p.init(2, 1); - transcript_p.round_challenge(r, basct::span{round_polynomials}.subspan(0, 2)); + transcript_p.round_challenge(r, basct::span{round_polynomials}.subspan(0, 2)); } // round 2 @@ -81,8 +82,8 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { // prove evaluation_point.resize(2); - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); REQUIRE(evaluation_point[0] == r); REQUIRE(res); } @@ -106,8 +107,8 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { // prove evaluation_point.resize(2); - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 1); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 1); REQUIRE(!res); } @@ -121,8 +122,8 @@ TEST_CASE("we can verify a sumcheck proof up to the polynomial evaluation") { (-0x3_s25 - 0x7_s25) * (0x2_s25 + 0x4_s25), }; expected_sum = 0x3_s25 * -0x2_s25 - 0x7_s25 * 0x4_s25; - auto res = sxt::prfsk::verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, - round_polynomials, 2); + auto res = verify_sumcheck_no_evaluation(expected_sum, evaluation_point, transcript, + round_polynomials, 2); REQUIRE(res); REQUIRE(evaluation_point[0] != 0x0_s25); } diff --git a/sxt/proof/sumcheck/workspace.cc b/sxt/proof/sumcheck/workspace.cc index d356b4af7..997e37cd3 100644 --- a/sxt/proof/sumcheck/workspace.cc +++ b/sxt/proof/sumcheck/workspace.cc @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/sxt/proof/sumcheck/workspace.h b/sxt/proof/sumcheck/workspace.h index edacbd2b5..2a46631e1 100644 --- a/sxt/proof/sumcheck/workspace.h +++ b/sxt/proof/sumcheck/workspace.h @@ -1,6 +1,6 @@ /** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. * - * Copyright 2024-present Space and Time Labs, Inc. + * Copyright 2025-present Space and Time Labs, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/sxt/scalar25/realization/BUILD b/sxt/scalar25/realization/BUILD new file mode 100644 index 000000000..bca94a28a --- /dev/null +++ b/sxt/scalar25/realization/BUILD @@ -0,0 +1,18 @@ +load( + "//bazel:sxt_build_system.bzl", + "sxt_cc_component", +) + +sxt_cc_component( + name = "field", + with_test = False, + deps = [ + "//sxt/base/field:element", + "//sxt/scalar25/operation:add", + "//sxt/scalar25/operation:mul", + "//sxt/scalar25/operation:muladd", + "//sxt/scalar25/operation:neg", + "//sxt/scalar25/operation:sub", + "//sxt/scalar25/type:element", + ], +) diff --git a/sxt/scalar25/realization/field.cc b/sxt/scalar25/realization/field.cc new file mode 100644 index 000000000..a2452d2e1 --- /dev/null +++ b/sxt/scalar25/realization/field.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/scalar25/realization/field.h" diff --git a/sxt/scalar25/realization/field.h b/sxt/scalar25/realization/field.h new file mode 100644 index 000000000..8ce413ac8 --- /dev/null +++ b/sxt/scalar25/realization/field.h @@ -0,0 +1,27 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "sxt/base/field/element.h" +#include "sxt/scalar25/operation/add.h" +#include "sxt/scalar25/operation/mul.h" +#include "sxt/scalar25/operation/muladd.h" +#include "sxt/scalar25/operation/neg.h" +#include "sxt/scalar25/operation/sub.h" +#include "sxt/scalar25/type/element.h" + +static_assert(sxt::basfld::element); diff --git a/sxt/scalar25/type/BUILD b/sxt/scalar25/type/BUILD index 5a463e5ce..cf154a0b9 100644 --- a/sxt/scalar25/type/BUILD +++ b/sxt/scalar25/type/BUILD @@ -3,6 +3,11 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "operation_adl_stub", + with_test = False, +) + sxt_cc_component( name = "element", impl_deps = [ @@ -12,6 +17,7 @@ sxt_cc_component( "//sxt/base/test:unit_test", ], deps = [ + ":operation_adl_stub", "//sxt/base/macro:cuda_callable", ], ) diff --git a/sxt/scalar25/type/element.h b/sxt/scalar25/type/element.h index 3882b5c89..5e78cf638 100644 --- a/sxt/scalar25/type/element.h +++ b/sxt/scalar25/type/element.h @@ -23,6 +23,7 @@ #include #include "sxt/base/macro/cuda_callable.h" +#include "sxt/scalar25/type/operation_adl_stub.h" namespace sxt::s25t { //-------------------------------------------------------------------------------------------------- @@ -33,7 +34,7 @@ namespace sxt::s25t { * L being the order of the main subgroup * (L = 2^252 + 27742317777372353535851937790883648493). */ -class element { +class element : public s25o::operation_adl_stub { public: element() noexcept = default; @@ -53,6 +54,12 @@ class element { static constexpr element identity() noexcept { return element{}; }; + static constexpr element one() noexcept { + element res{}; + res.data_[0] = 1; + return res; + } + private: uint8_t data_[32]; }; diff --git a/sxt/scalar25/type/operation_adl_stub.cc b/sxt/scalar25/type/operation_adl_stub.cc new file mode 100644 index 000000000..95c90e307 --- /dev/null +++ b/sxt/scalar25/type/operation_adl_stub.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/scalar25/type/operation_adl_stub.h" diff --git a/sxt/scalar25/type/operation_adl_stub.h b/sxt/scalar25/type/operation_adl_stub.h new file mode 100644 index 000000000..34a7e577f --- /dev/null +++ b/sxt/scalar25/type/operation_adl_stub.h @@ -0,0 +1,28 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2025-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace sxt::s25o { +//-------------------------------------------------------------------------------------------------- +// operation_adl_stub +//-------------------------------------------------------------------------------------------------- +/** + * A stub class that can be inherited so that functions in the s25o namespace + * will participate in ADL. + */ +struct operation_adl_stub {}; +}; // namespace sxt::s25o