diff --git a/tests/group_functions/group_of.h b/tests/group_functions/group_of.h index 4687574f3..95d18953c 100644 --- a/tests/group_functions/group_of.h +++ b/tests/group_functions/group_of.h @@ -291,10 +291,14 @@ void predicate_function_of_sub_group(sycl::queue& queue) { sycl::range work_group_range = sycl_cts::util::work_group_range(queue); // array to return results: 4 predicates * 3 functions - bool res[test_matrix * test_cases] = {false}; + constexpr int total_case_count = test_matrix * test_cases; + bool res[total_case_count]; + // Initially fill the results array with 'true'. Each sub-group test 'ands' + // with this to ensure every sub-group in the work-group returns the correct + // result. + std::fill(res, res + total_case_count, true); { - sycl::buffer res_sycl(res, - sycl::range<1>(test_matrix * test_cases)); + sycl::buffer res_sycl(res, sycl::range<1>(total_case_count)); queue.submit([&](sycl::handler& cgh) { auto res_acc = res_sycl.get_access(cgh); @@ -306,47 +310,57 @@ void predicate_function_of_sub_group(sycl::queue& queue) { sycl::sub_group sub_group = item.get_sub_group(); T size = sub_group.get_local_linear_range(); - T local_var(item.get_global_linear_id() + 1); + // Use the sub-group local ID (plus 1) as a variable against which to + // test our predicates. Note that this has a well-defined set of values + // [1,2,...,N] where N is the sub-group size. Note that the sub-group + // could also just be of size 1. + T local_var(sub_group.get_local_linear_id() + 1); // predicates + // The variable is never 1 for any member of the sub-group auto none_true = [&](T i) { return i == 0; }; + // Exactly one member of the sub-group has value 1 (the first) auto one_true = [&](T i) { return i == 1; }; + // Some (or all, for sub-groups of size 1) members of the sub-group have + // this value auto some_true = [&](T i) { return i > size / 2; }; + // The variable is less than or equal to the sub-group size for all + // members of the sub-group. auto all_true = [&](T i) { return i <= size; }; - if (sub_group.get_group_id()[0] == 0) { + { ASSERT_RETURN_TYPE( bool, sycl::any_of_group(sub_group, local_var, none_true), "Return type of any_of_group(Sub_group g, bool pred) is wrong\n"); - res_acc[0] = !sycl::any_of_group(sub_group, local_var, none_true); - res_acc[1] = sycl::any_of_group(sub_group, local_var, one_true); - res_acc[2] = sycl::any_of_group(sub_group, local_var, some_true); - res_acc[3] = sycl::any_of_group(sub_group, local_var, all_true); + res_acc[0] &= !sycl::any_of_group(sub_group, local_var, none_true); + res_acc[1] &= sycl::any_of_group(sub_group, local_var, one_true); + res_acc[2] &= sycl::any_of_group(sub_group, local_var, some_true); + res_acc[3] &= sycl::any_of_group(sub_group, local_var, all_true); ASSERT_RETURN_TYPE( bool, sycl::all_of_group(sub_group, local_var, none_true), "Return type of all_of_group(Sub_group g, bool pred) is wrong\n"); - res_acc[4] = !sycl::all_of_group(sub_group, local_var, none_true); + res_acc[4] &= !sycl::all_of_group(sub_group, local_var, none_true); // Note that 'one_true' returns true for the first item. Thus in the // case that the sub-group size is 1, check that all items match; // otherwise check that not all items match. - res_acc[5] = + res_acc[5] &= sycl::all_of_group(sub_group, local_var, one_true) ^ (size != 1); // Note that 'some_true' returns true for the first item if the // sub-group size is 1. In that case, check that all items match; // otherwise check that not all items match. - res_acc[6] = + res_acc[6] &= sycl::all_of_group(sub_group, local_var, some_true) ^ (size != 1); - res_acc[7] = sycl::all_of_group(sub_group, local_var, all_true); + res_acc[7] &= sycl::all_of_group(sub_group, local_var, all_true); ASSERT_RETURN_TYPE( bool, sycl::none_of_group(sub_group, local_var, none_true), "Return type of none_of_group(Sub_group g, bool pred) is " "wrong\n"); - res_acc[8] = sycl::none_of_group(sub_group, local_var, none_true); - res_acc[9] = !sycl::none_of_group(sub_group, local_var, one_true); - res_acc[10] = !sycl::none_of_group(sub_group, local_var, some_true); - res_acc[11] = !sycl::none_of_group(sub_group, local_var, all_true); + res_acc[8] &= sycl::none_of_group(sub_group, local_var, none_true); + res_acc[9] &= !sycl::none_of_group(sub_group, local_var, one_true); + res_acc[10] &= !sycl::none_of_group(sub_group, local_var, some_true); + res_acc[11] &= !sycl::none_of_group(sub_group, local_var, all_true); } }); }); @@ -477,11 +491,15 @@ void bool_function_of_sub_group(sycl::queue& queue) { sycl::range work_group_range = sycl_cts::util::work_group_range(queue); - // array to return results - bool res[test_matrix * test_cases] = {false}; + // array to return results: 4 predicates * 3 functions + constexpr int total_case_count = test_matrix * test_cases; + bool res[total_case_count]; + // Initially fill the results array with 'true'. Each sub-group test 'ands' + // with this to ensure every sub-group in the work-group returns the correct + // result. + std::fill(res, res + total_case_count, true); { - sycl::buffer res_sycl(res, - sycl::range<1>(test_matrix * test_cases)); + sycl::buffer res_sycl(res, sycl::range<1>(total_case_count)); queue.submit([&](sycl::handler& cgh) { auto res_acc = res_sycl.get_access(cgh); @@ -493,23 +511,32 @@ void bool_function_of_sub_group(sycl::queue& queue) { sycl::sub_group sub_group = item.get_sub_group(); T size = sub_group.get_local_linear_range(); + // Use the sub-group local ID (plus 1) as a variable against which to + // test our predicates. Note that this has a well-defined set of values + // [1,2,...,N] where N is the sub-group size. Note that the sub-group + // could also just be of size 1. T local_var(sub_group.get_local_linear_id() + 1); // predicates + // The variable is never 1 for any member of the sub-group auto none_true = [&](T i) { return i == 0; }; + // Exactly one member of the sub-group has value 1 (the first) auto one_true = [&](T i) { return i == 1; }; + // Some (or all, for sub-groups of size 1) members of the sub-group have + // this value auto some_true = [&](T i) { return i > size / 2; }; + // The variable is less than or equal to the sub-group size for all + // members of the sub-group. auto all_true = [&](T i) { return i <= size; }; - // if(sub_group.get_group_linear_id() == 0) { ASSERT_RETURN_TYPE( bool, sycl::any_of_group(sub_group, none_true(local_var)), "Return type of any_of_group(sub_group g, bool pred) is wrong\n"); - res_acc[0] = !sycl::any_of_group(sub_group, none_true(local_var)); - res_acc[1] = sycl::any_of_group(sub_group, one_true(local_var)); - res_acc[2] = sycl::any_of_group(sub_group, some_true(local_var)); - res_acc[3] = sycl::any_of_group(sub_group, all_true(local_var)); + res_acc[0] &= !sycl::any_of_group(sub_group, none_true(local_var)); + res_acc[1] &= sycl::any_of_group(sub_group, one_true(local_var)); + res_acc[2] &= sycl::any_of_group(sub_group, some_true(local_var)); + res_acc[3] &= sycl::any_of_group(sub_group, all_true(local_var)); ASSERT_RETURN_TYPE( bool, sycl::all_of_group(sub_group, none_true(local_var)), @@ -518,23 +545,23 @@ void bool_function_of_sub_group(sycl::queue& queue) { // Note that 'one_true' returns true for the first item. Thus in the // case that the sub-group size is 1, check that all items match; // otherwise check that not all items match. - res_acc[5] = + res_acc[5] &= sycl::all_of_group(sub_group, one_true(local_var)) ^ (size != 1); // Note that 'some_true' returns true for the first item if the // sub-group size is 1. In that case, check that all items match; // otherwise check that not all items match. - res_acc[6] = + res_acc[6] &= sycl::all_of_group(sub_group, some_true(local_var)) ^ (size != 1); - res_acc[7] = sycl::all_of_group(sub_group, all_true(local_var)); + res_acc[7] &= sycl::all_of_group(sub_group, all_true(local_var)); ASSERT_RETURN_TYPE( bool, sycl::none_of_group(sub_group, none_true(local_var)), "Return type of none_of_group(sub_group g, bool pred) is " "wrong\n"); - res_acc[8] = sycl::none_of_group(sub_group, none_true(local_var)); - res_acc[9] = !sycl::none_of_group(sub_group, one_true(local_var)); - res_acc[10] = !sycl::none_of_group(sub_group, some_true(local_var)); - res_acc[11] = !sycl::none_of_group(sub_group, all_true(local_var)); + res_acc[8] &= sycl::none_of_group(sub_group, none_true(local_var)); + res_acc[9] &= !sycl::none_of_group(sub_group, one_true(local_var)); + res_acc[10] &= !sycl::none_of_group(sub_group, some_true(local_var)); + res_acc[11] &= !sycl::none_of_group(sub_group, all_true(local_var)); } }); });