Skip to content

Commit

Permalink
Add comments to explain the test
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Aziz <[email protected]>
  • Loading branch information
0x12CC committed Nov 7, 2023
1 parent 66e7b78 commit 2415bff
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/extension/oneapi_auto_local_range/auto_local_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ static void check_auto_range() {
sycl::buffer<int, Dimensions> output_buffer{N};

{
// Create an input sequence [1, 2, 3, 4, ..., N].
auto input = input_buffer.get_host_access();
std::iota(input.begin(), input.end(), 1);
}
Expand All @@ -53,21 +54,37 @@ static void check_auto_range() {
sycl::accessor output{output_buffer, cgh, sycl::write_only};
sycl::range<Dimensions> auto_range =
sycl::ext::oneapi::experimental::auto_range<Dimensions>();
// Launch a kernel with a global range of N and a local range chosen by the
// SYCL extension implementation.
cgh.parallel_for(sycl::nd_range<Dimensions>{N, auto_range}, [=](auto it) {
sycl::group<Dimensions> g = it.get_group();
int local_accumulator = 0;

// Each work item computes the sum of a subset of the input values and
// stores the result in local_accumulator. The calls to
// get_local_linear_id() and get_local_linear_range() ensure that the set
// of input values is divided between the work items in a group
// regardless of the group size chosen by the auto_range implementation.
for (size_t i = it.get_local_linear_id(); i < N.size();
i += g.get_local_linear_range()) {
// The unlinearize function maps each value of i to a unique value in
// the input. It's needed since multi-dimensional accessors cannot be
// indexed using a scalar.
int value = input[unlinearize(N, i)];
local_accumulator += value;
}

// The total sum of the input values is computed using a reduce operation
// with the partial sum from each work item in a group.
int total =
sycl::reduce_over_group(g, local_accumulator, sycl::plus<>());
output[it.get_global_id()] = total;
});
}).wait();

{
// Compare the output values to the expected sum computed using the formula
// for triangular numbers.
const int expected_sum = (N.size() * (N.size() + 1)) / 2;
auto output = output_buffer.get_host_access();
for (const auto& it : output) {
Expand Down

0 comments on commit 2415bff

Please sign in to comment.