Skip to content

Commit

Permalink
add sum function for bucket method
Browse files Browse the repository at this point in the history
  • Loading branch information
rnburn committed Feb 27, 2024
1 parent bd8dafb commit 9a931a8
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 0 deletions.
34 changes: 34 additions & 0 deletions sxt/multiexp/bucket_method2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,37 @@ sxt_cc_component(
"//sxt/memory/resource:async_device_resource",
],
)

# Component for the bucket-sum step of the bucket method (sum.h / sum.cc / sum.t.cc).
#
# `deps` lists a label for every project header that sum.h includes directly,
# so the component does not depend on headers that only arrive transitively.
sxt_cc_component(
    name = "sum",
    test_deps = [
        "//sxt/base/curve:example_element",
        "//sxt/base/device:synchronization",
        "//sxt/base/test:unit_test",
        "//sxt/execution/schedule:scheduler",
        "//sxt/memory/resource:managed_device_resource",
    ],
    deps = [
        ":multiproduct_table",
        "//sxt/algorithm/iteration:for_each",
        "//sxt/base/container:span",
        "//sxt/base/container:span_utility",
        "//sxt/base/curve:element",
        "//sxt/base/device:memory_utility",
        "//sxt/base/device:property",
        "//sxt/base/device:stream",
        "//sxt/base/error:assert",
        "//sxt/base/iterator:index_range",
        "//sxt/base/iterator:index_range_iterator",
        "//sxt/base/iterator:index_range_utility",
        "//sxt/base/log",
        "//sxt/base/num:divide_up",
        "//sxt/execution/async:coroutine",
        "//sxt/execution/async:future",
        "//sxt/execution/device:device_viewable",
        "//sxt/execution/device:for_each",
        "//sxt/execution/device:synchronization",
        "//sxt/memory/management:managed_array",
        "//sxt/memory/resource:async_device_resource",
        "//sxt/memory/resource:device_resource",
        "//sxt/memory/resource:pinned_resource",
    ],
)
17 changes: 17 additions & 0 deletions sxt/multiexp/bucket_method2/sum.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/multiexp/bucket_method2/sum.h"
127 changes: 127 additions & 0 deletions sxt/multiexp/bucket_method2/sum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <vector>

#include "sxt/algorithm/iteration/for_each.h"
#include "sxt/base/container/span.h"
#include "sxt/base/container/span_utility.h"
#include "sxt/base/curve/element.h"
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/property.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/iterator/index_range.h"
#include "sxt/base/iterator/index_range_iterator.h"
#include "sxt/base/iterator/index_range_utility.h"
#include "sxt/base/log/log.h"
#include "sxt/base/num/divide_up.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/device/for_each.h"
#include "sxt/execution/device/synchronization.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/device_resource.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/bucket_method2/multiproduct_table.h"

namespace sxt::mtxbk2 {
//--------------------------------------------------------------------------------------------------
// sum_bucket
//--------------------------------------------------------------------------------------------------
/**
 * Compute the sum of the generators assigned to a single bucket and store it in sums[index].
 *
 * index is split into a digit index (index / num_buckets_per_digit) and a bucket offset within
 * that digit (index % num_buckets_per_digit). For the selected digit, bucket_prefix_counts holds
 * num_buckets_per_digit running counts and indexes holds a segment of n generator indexes. The
 * bucket's members are the entries of that segment in the half-open range [begin, end), where
 * end is the bucket's own prefix count and begin is the preceding bucket's prefix count (or zero
 * for the first bucket of the digit).
 *
 * An empty bucket is assigned T::identity().
 */
template <bascrv::element T>
CUDA_CALLABLE void sum_bucket(T* __restrict__ sums, const T* __restrict__ generators,
                              const uint16_t* __restrict__ bucket_prefix_counts,
                              const uint16_t* __restrict__ indexes, unsigned num_buckets_per_digit,
                              unsigned n, unsigned index) noexcept {
  const auto digit = index / num_buckets_per_digit;
  const auto bucket = index % num_buckets_per_digit;

  // locate the slices of the table that belong to this digit
  const uint16_t* digit_counts = bucket_prefix_counts + digit * num_buckets_per_digit;
  const uint16_t* digit_indexes = indexes + digit * n;

  // the prefix counts delimit the bucket's entries
  uint16_t pos = 0;
  if (bucket != 0) {
    pos = digit_counts[bucket - 1u];
  }
  const auto end = digit_counts[bucket];

  T& result = sums[index];
  if (pos == end) {
    result = T::identity();
    return;
  }

  // seed the accumulator with the first entry, then fold in the remainder
  T accumulator = generators[digit_indexes[pos]];
  ++pos;
  while (pos != end) {
    // add_inplace receives a mutable copy, so the generator table itself is never modified
    auto term = generators[digit_indexes[pos]];
    add_inplace(accumulator, term);
    ++pos;
  }
  result = accumulator;
}

//--------------------------------------------------------------------------------------------------
// sum_buckets
//--------------------------------------------------------------------------------------------------
/**
 * Asynchronously compute the bucket sums for a collection of exponent sequences.
 *
 * @param sums              output; one entry per bucket. Must point to memory that is accessible
 *                          from the active device (checked with a debug assertion) because the
 *                          kernel writes the results directly through this pointer.
 * @param generators        the generators to sum; copied to the device with a host-to-device
 *                          copy on an internal stream.
 * @param exponents         one pointer per output; forwarded to make_multiproduct_table.
 * @param element_num_bytes number of bytes per exponent.
 * @param bit_width         number of bits per digit; each digit has (1 << bit_width) - 1 buckets.
 *
 * @return a future that completes once the kernel writing sums has finished on the stream.
 *
 * The work proceeds in three steps: build the multiproduct table (bucket prefix counts and
 * generator indexes), copy the generators to the device, and launch one sum_bucket invocation
 * per entry of sums.
 */
template <bascrv::element T>
xena::future<> sum_buckets(basct::span<T> sums, basct::cspan<T> generators,
                           basct::cspan<const uint8_t*> exponents, unsigned element_num_bytes,
                           unsigned bit_width) noexcept {
  auto num_buckets_per_digit = (1u << bit_width) - 1u;
  auto num_digits = basn::divide_up(element_num_bytes * 8u, bit_width);
  auto num_outputs = static_cast<unsigned>(exponents.size());
  auto num_buckets_total = static_cast<unsigned>(sums.size());
  auto n = static_cast<unsigned>(generators.size());
  SXT_DEBUG_ASSERT(basdv::is_active_device_pointer(sums.data()));

  // compute multiproduct table
  memmg::managed_array<uint16_t> bucket_prefix_counts{num_buckets_total,
                                                      memr::get_device_resource()};
  memmg::managed_array<uint16_t> indexes{n * num_digits * num_outputs, memr::get_device_resource()};
  auto fut = make_multiproduct_table(bucket_prefix_counts, indexes, exponents, element_num_bytes,
                                     bit_width, n);

  // copy generators to device
  //
  // The copy is enqueued before awaiting the table so that it can proceed while the table is
  // being built.
  basdv::stream stream;
  memr::async_device_resource resource{stream};
  memmg::managed_array<T> generators_dev{n, &resource};
  basdv::async_copy_host_to_device(generators_dev, generators, stream);

  // sum buckets
  //
  // Results are written straight into `sums`, so no separate device-side output buffer is
  // allocated.
  co_await std::move(fut);
  basl::info("summing {} buckets", num_buckets_total);
  auto f = [
               // clang-format off
    sums = sums.data(),
    generators = generators_dev.data(),
    bucket_prefix_counts = bucket_prefix_counts.data(),
    indexes = indexes.data(),
    num_buckets_per_digit = num_buckets_per_digit,
    n = n
               // clang-format on
  ] __device__
           __host__(unsigned /*num_buckets_total*/, unsigned index) noexcept {
             sum_bucket<T>(sums, generators, bucket_prefix_counts, indexes, num_buckets_per_digit,
                           n, index);
           };
  algi::launch_for_each_kernel(stream, f, num_buckets_total);
  co_await xendv::await_stream(stream);
}
} // namespace sxt::mtxbk2
76 changes: 76 additions & 0 deletions sxt/multiexp/bucket_method2/sum.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/multiexp/bucket_method2/sum.h"

#include <vector>

#include "sxt/base/curve/example_element.h"
#include "sxt/base/device/synchronization.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/execution/schedule/scheduler.h"
#include "sxt/memory/resource/managed_device_resource.h"

using namespace sxt;
using namespace sxt::mtxbk2;

// Exercises sum_buckets with 32-byte exponents split into 8-bit digits, i.e. 255 buckets for
// each of the 32 digits of a single output.
TEST_CASE("we can compute the bucket sums for a chunk") {
  using E = bascrv::element97;
  const unsigned element_num_bytes = 32;
  const unsigned bit_width = 8;

  std::pmr::vector<E> sums(255 * 32, memr::get_managed_device_resource());
  std::vector<E> generators = {3u};
  std::vector<const uint8_t*> scalars;

  std::pmr::vector<E> expected(sums.size());

  // Launch the computation, drive the scheduler until the future is ready, and wait for the
  // device so that `sums` can be inspected from the host.
  auto compute_sums = [&] {
    auto fut = sum_buckets<E>(sums, generators, scalars, element_num_bytes, bit_width);
    xens::get_scheduler().run();
    REQUIRE(fut.ready());
    basdv::synchronize_device();
  };

  SECTION("we can compute bucket sums for a single exponent of zero") {
    std::vector<uint8_t> exponent_bytes(32);
    scalars = {exponent_bytes.data()};
    compute_sums();
    REQUIRE(sums == expected);
  }

  SECTION("we can compute the bucket sums for a single exponent of one") {
    std::vector<uint8_t> exponent_bytes(32);
    exponent_bytes[0] = 1;
    scalars = {exponent_bytes.data()};
    compute_sums();
    expected[0] = generators[0];
    REQUIRE(sums == expected);
  }

  SECTION("we can compute the bucket sums for two exponents of one") {
    generators = {3u, 4u};
    // two consecutive 32-byte exponents, each with its lowest byte set to one
    std::vector<uint8_t> exponent_bytes(64);
    exponent_bytes[0] = 1;
    exponent_bytes[32] = 1;
    scalars = {exponent_bytes.data()};
    compute_sums();
    expected[0] = generators[0].value + generators[1].value;
    REQUIRE(sums == expected);
  }
}

0 comments on commit 9a931a8

Please sign in to comment.