Skip to content

Commit

Permalink
kokkos: fix use of views in parallel_reduce in unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Apr 15, 2022
1 parent 09f6ddc commit 3628b87
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
30 changes: 22 additions & 8 deletions tests/kokkos-based/matrix_rank1_update_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ inline bool is_same_vector(
const auto size = v1.extent(0);
if (size != v2.extent(0))
return false;
const auto v1_view = KokkosKernelsSTD::Impl::mdspan_to_view(v1);
const auto v2_view = KokkosKernelsSTD::Impl::mdspan_to_view(v2);
int diff = false;
Kokkos::parallel_reduce(size,
KOKKOS_LAMBDA(const std::size_t i, decltype(diff) &d){
d = d || !(v1(i) == v2(i));
d = d || !(v1_view(i) == v2_view(i));
}, diff);
return !diff;
}
Expand Down Expand Up @@ -155,6 +157,18 @@ class value_diff<std::complex<T>>: public value_diff<T> {
}
};

template <typename T>
class value_diff<Kokkos::complex<T>>: public value_diff<T> {
using base = value_diff<T>;
public:
KOKKOS_INLINE_FUNCTION
value_diff(const Kokkos::complex<T> &val1, const Kokkos::complex<T> &val2) {
const T dreal = base(val1.real(), val2.real());
const T dimag = base(val1.imag(), val2.imag());
base::_v = dreal > dimag ? dreal : dimag; // can't use std::max on GPU
}
};

template <typename ElementType,
typename LayoutPolicy1,
typename AccessorPolicy1,
Expand All @@ -170,14 +184,14 @@ inline bool is_same_matrix(
const auto ext1 = A.extent(1);
if (B.extent(0) != ext0 or B.extent(1) != ext1)
return false;
const auto A_view = KokkosKernelsSTD::Impl::mdspan_to_view(A);
const auto B_view = KokkosKernelsSTD::Impl::mdspan_to_view(B);
int diff = false;
const auto size = ext0 * ext1;
Kokkos::parallel_reduce(size,
KOKKOS_LAMBDA(const std::size_t ij, decltype(diff) &d) {
const auto i = ij / ext1;
const auto j = ij - i * ext1;
if (value_diff(A(i, j), B(i, j)) > tolerance)
d = true;
Kokkos::parallel_reduce(ext0,
KOKKOS_LAMBDA(std::size_t i, decltype(diff) &d) {
for (decltype(i) j = 0; j < ext1; ++j) {
d = d || (value_diff(A_view(i, j), B_view(i, j)) > tolerance);
}
}, diff);
return !diff;
}
Expand Down
30 changes: 22 additions & 8 deletions tests/kokkos-based/symmetric_matrix_rank1_update_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ inline bool is_same_vector(
const auto size = v1.extent(0);
if (size != v2.extent(0))
return false;
const auto v1_view = KokkosKernelsSTD::Impl::mdspan_to_view(v1);
const auto v2_view = KokkosKernelsSTD::Impl::mdspan_to_view(v2);
int diff = false;
Kokkos::parallel_reduce(size,
KOKKOS_LAMBDA(const std::size_t i, decltype(diff) &d){
d = d || !(v1(i) == v2(i));
d = d || !(v1_view(i) == v2_view(i));
}, diff);
return !diff;
}
Expand Down Expand Up @@ -155,6 +157,18 @@ class value_diff<std::complex<T>>: public value_diff<T> {
}
};

template <typename T>
class value_diff<Kokkos::complex<T>>: public value_diff<T> {
using base = value_diff<T>;
public:
KOKKOS_INLINE_FUNCTION
value_diff(const Kokkos::complex<T> &val1, const Kokkos::complex<T> &val2) {
const T dreal = base(val1.real(), val2.real());
const T dimag = base(val1.imag(), val2.imag());
base::_v = dreal > dimag ? dreal : dimag; // can't use std::max on GPU
}
};

template <typename ElementType,
typename LayoutPolicy1,
typename AccessorPolicy1,
Expand All @@ -170,14 +184,14 @@ inline bool is_same_matrix(
const auto ext1 = A.extent(1);
if (B.extent(0) != ext0 or B.extent(1) != ext1)
return false;
const auto A_view = KokkosKernelsSTD::Impl::mdspan_to_view(A);
const auto B_view = KokkosKernelsSTD::Impl::mdspan_to_view(B);
int diff = false;
const auto size = ext0 * ext1;
Kokkos::parallel_reduce(size,
KOKKOS_LAMBDA(const std::size_t ij, decltype(diff) &d) {
const auto i = ij / ext1;
const auto j = ij - i * ext1;
if (value_diff(A(i, j), B(i, j)) > tolerance)
d = true;
Kokkos::parallel_reduce(ext0,
KOKKOS_LAMBDA(std::size_t i, decltype(diff) &d) {
for (decltype(i) j = 0; j < ext1; ++j) {
d = d || (value_diff(A_view(i, j), B_view(i, j)) > tolerance);
}
}, diff);
return !diff;
}
Expand Down

0 comments on commit 3628b87

Please sign in to comment.