Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sub-group 'of' tests #807

Merged
merged 5 commits into from
Nov 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 60 additions & 33 deletions tests/group_functions/group_of.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,14 @@ void predicate_function_of_sub_group(sycl::queue& queue) {
sycl::range<D> work_group_range = sycl_cts::util::work_group_range<D>(queue);

// array to return results: 4 predicates * 3 functions
bool res[test_matrix * test_cases] = {false};
constexpr int total_case_count = test_matrix * test_cases;
bool res[total_case_count];
// Initially fill the results array with 'true'. Each sub-group test 'ands'
// with this to ensure every sub-group in the work-group returns the correct
// result.
std::fill(res, res + total_case_count, true);
{
sycl::buffer<bool, 1> res_sycl(res,
sycl::range<1>(test_matrix * test_cases));
sycl::buffer<bool, 1> res_sycl(res, sycl::range<1>(total_case_count));

queue.submit([&](sycl::handler& cgh) {
auto res_acc = res_sycl.get_access<sycl::access::mode::read_write>(cgh);
Expand All @@ -306,47 +310,57 @@ void predicate_function_of_sub_group(sycl::queue& queue) {
sycl::sub_group sub_group = item.get_sub_group();
T size = sub_group.get_local_linear_range();

T local_var(item.get_global_linear_id() + 1);
// Use the sub-group local ID (plus 1) as a variable against which to
// test our predicates. Note that this has a well-defined set of values
// [1,2,...,N] where N is the sub-group size. Note that the sub-group
// could also just be of size 1.
T local_var(sub_group.get_local_linear_id() + 1);

// predicates
// The variable is never 1 for any member of the sub-group
auto none_true = [&](T i) { return i == 0; };
// Exactly one member of the sub-group has value 1 (the first)
auto one_true = [&](T i) { return i == 1; };
// Some (or all, for sub-groups of size 1) members of the sub-group have
// this value
auto some_true = [&](T i) { return i > size / 2; };
// The variable is less than or equal to the sub-group size for all
// members of the sub-group.
auto all_true = [&](T i) { return i <= size; };

if (sub_group.get_group_id()[0] == 0) {
{
ASSERT_RETURN_TYPE(
bool, sycl::any_of_group(sub_group, local_var, none_true),
"Return type of any_of_group(Sub_group g, bool pred) is wrong\n");
res_acc[0] = !sycl::any_of_group(sub_group, local_var, none_true);
res_acc[1] = sycl::any_of_group(sub_group, local_var, one_true);
res_acc[2] = sycl::any_of_group(sub_group, local_var, some_true);
res_acc[3] = sycl::any_of_group(sub_group, local_var, all_true);
res_acc[0] &= !sycl::any_of_group(sub_group, local_var, none_true);
res_acc[1] &= sycl::any_of_group(sub_group, local_var, one_true);
res_acc[2] &= sycl::any_of_group(sub_group, local_var, some_true);
res_acc[3] &= sycl::any_of_group(sub_group, local_var, all_true);

ASSERT_RETURN_TYPE(
bool, sycl::all_of_group(sub_group, local_var, none_true),
"Return type of all_of_group(Sub_group g, bool pred) is wrong\n");
res_acc[4] = !sycl::all_of_group(sub_group, local_var, none_true);
res_acc[4] &= !sycl::all_of_group(sub_group, local_var, none_true);
// Note that 'one_true' returns true for the first item. Thus in the
// case that the sub-group size is 1, check that all items match;
// otherwise check that not all items match.
res_acc[5] =
res_acc[5] &=
sycl::all_of_group(sub_group, local_var, one_true) ^ (size != 1);
// Note that 'some_true' returns true for the first item if the
// sub-group size is 1. In that case, check that all items match;
// otherwise check that not all items match.
res_acc[6] =
res_acc[6] &=
sycl::all_of_group(sub_group, local_var, some_true) ^ (size != 1);
res_acc[7] = sycl::all_of_group(sub_group, local_var, all_true);
res_acc[7] &= sycl::all_of_group(sub_group, local_var, all_true);

ASSERT_RETURN_TYPE(
bool, sycl::none_of_group(sub_group, local_var, none_true),
"Return type of none_of_group(Sub_group g, bool pred) is "
"wrong\n");
res_acc[8] = sycl::none_of_group(sub_group, local_var, none_true);
res_acc[9] = !sycl::none_of_group(sub_group, local_var, one_true);
res_acc[10] = !sycl::none_of_group(sub_group, local_var, some_true);
res_acc[11] = !sycl::none_of_group(sub_group, local_var, all_true);
res_acc[8] &= sycl::none_of_group(sub_group, local_var, none_true);
res_acc[9] &= !sycl::none_of_group(sub_group, local_var, one_true);
res_acc[10] &= !sycl::none_of_group(sub_group, local_var, some_true);
res_acc[11] &= !sycl::none_of_group(sub_group, local_var, all_true);
}
});
});
Expand Down Expand Up @@ -477,11 +491,15 @@ void bool_function_of_sub_group(sycl::queue& queue) {

sycl::range<D> work_group_range = sycl_cts::util::work_group_range<D>(queue);

// array to return results
bool res[test_matrix * test_cases] = {false};
// array to return results: 4 predicates * 3 functions
constexpr int total_case_count = test_matrix * test_cases;
bool res[total_case_count];
// Initially fill the results array with 'true'. Each sub-group test 'ands'
// with this to ensure every sub-group in the work-group returns the correct
// result.
std::fill(res, res + total_case_count, true);
{
sycl::buffer<bool, 1> res_sycl(res,
sycl::range<1>(test_matrix * test_cases));
sycl::buffer<bool, 1> res_sycl(res, sycl::range<1>(total_case_count));

queue.submit([&](sycl::handler& cgh) {
auto res_acc = res_sycl.get_access<sycl::access::mode::read_write>(cgh);
Expand All @@ -493,23 +511,32 @@ void bool_function_of_sub_group(sycl::queue& queue) {
sycl::sub_group sub_group = item.get_sub_group();
T size = sub_group.get_local_linear_range();

// Use the sub-group local ID (plus 1) as a variable against which to
// test our predicates. Note that this has a well-defined set of values
// [1,2,...,N] where N is the sub-group size. Note that the sub-group
// could also just be of size 1.
T local_var(sub_group.get_local_linear_id() + 1);

// predicates
// The variable is never 1 for any member of the sub-group
auto none_true = [&](T i) { return i == 0; };
// Exactly one member of the sub-group has value 1 (the first)
auto one_true = [&](T i) { return i == 1; };
// Some (or all, for sub-groups of size 1) members of the sub-group have
// this value
auto some_true = [&](T i) { return i > size / 2; };
// The variable is less than or equal to the sub-group size for all
// members of the sub-group.
auto all_true = [&](T i) { return i <= size; };

// if(sub_group.get_group_linear_id() == 0)
{
ASSERT_RETURN_TYPE(
bool, sycl::any_of_group(sub_group, none_true(local_var)),
"Return type of any_of_group(sub_group g, bool pred) is wrong\n");
res_acc[0] = !sycl::any_of_group(sub_group, none_true(local_var));
res_acc[1] = sycl::any_of_group(sub_group, one_true(local_var));
res_acc[2] = sycl::any_of_group(sub_group, some_true(local_var));
res_acc[3] = sycl::any_of_group(sub_group, all_true(local_var));
res_acc[0] &= !sycl::any_of_group(sub_group, none_true(local_var));
res_acc[1] &= sycl::any_of_group(sub_group, one_true(local_var));
res_acc[2] &= sycl::any_of_group(sub_group, some_true(local_var));
res_acc[3] &= sycl::any_of_group(sub_group, all_true(local_var));

ASSERT_RETURN_TYPE(
bool, sycl::all_of_group(sub_group, none_true(local_var)),
Expand All @@ -518,23 +545,23 @@ void bool_function_of_sub_group(sycl::queue& queue) {
// Note that 'one_true' returns true for the first item. Thus in the
// case that the sub-group size is 1, check that all items match;
// otherwise check that not all items match.
res_acc[5] =
res_acc[5] &=
sycl::all_of_group(sub_group, one_true(local_var)) ^ (size != 1);
// Note that 'some_true' returns true for the first item if the
// sub-group size is 1. In that case, check that all items match;
// otherwise check that not all items match.
res_acc[6] =
res_acc[6] &=
sycl::all_of_group(sub_group, some_true(local_var)) ^ (size != 1);
res_acc[7] = sycl::all_of_group(sub_group, all_true(local_var));
res_acc[7] &= sycl::all_of_group(sub_group, all_true(local_var));

ASSERT_RETURN_TYPE(
bool, sycl::none_of_group(sub_group, none_true(local_var)),
"Return type of none_of_group(sub_group g, bool pred) is "
"wrong\n");
res_acc[8] = sycl::none_of_group(sub_group, none_true(local_var));
res_acc[9] = !sycl::none_of_group(sub_group, one_true(local_var));
res_acc[10] = !sycl::none_of_group(sub_group, some_true(local_var));
res_acc[11] = !sycl::none_of_group(sub_group, all_true(local_var));
res_acc[8] &= sycl::none_of_group(sub_group, none_true(local_var));
res_acc[9] &= !sycl::none_of_group(sub_group, one_true(local_var));
res_acc[10] &= !sycl::none_of_group(sub_group, some_true(local_var));
res_acc[11] &= !sycl::none_of_group(sub_group, all_true(local_var));
}
});
});
Expand Down