diff --git a/examples/reshape.cpp b/examples/reshape.cpp index 9c680a6d17..e670e7754a 100644 --- a/examples/reshape.cpp +++ b/examples/reshape.cpp @@ -148,7 +148,7 @@ void checkResult(int *ptr, int K, int N, int M) for(int n = 0; n < N; ++n) { for(int m = 0; m < M; ++m) { const int idx = m + M * (n + N * k); - if (std::abs(ptr[idx] - idx) > 0) { + if (ptr[idx] != idx) { status = false; } } diff --git a/include/RAJA/util/View.hpp b/include/RAJA/util/View.hpp index 6134ab07e6..a8239c49dd 100644 --- a/include/RAJA/util/View.hpp +++ b/include/RAJA/util/View.hpp @@ -18,13 +18,11 @@ #ifndef RAJA_VIEW_HPP #define RAJA_VIEW_HPP -#include #include +#include #include "RAJA/config.hpp" - #include "RAJA/pattern/atomic.hpp" - #include "RAJA/util/IndexLayout.hpp" #include "RAJA/util/Layout.hpp" #include "RAJA/util/OffsetLayout.hpp" @@ -297,81 +295,88 @@ RAJA_INLINE AtomicViewWrapper make_atomic_view( return RAJA::AtomicViewWrapper(view); } -struct layout_left{}; -struct layout_right{}; - -template -struct Reshape; +struct layout_left { +}; +struct layout_right { +}; -template +template struct Reshape; -template +namespace detail +{ +template constexpr auto get_last_index(T last) { return last; } -template -constexpr auto get_last_index(T , Args... args) +template +constexpr auto get_last_index(T, Args... args) { return get_last_index(args...); } +} // namespace detail -template -struct Reshape> -{ - template +template +struct Reshape> { + template static auto get(T *ptr, Ts... s) { constexpr int N = sizeof...(Ts); std::array extent{s...}; auto custom_layout = - RAJA::make_permuted_layout(extent, std::array{Is...}); + RAJA::make_permuted_layout(extent, std::array{Is...}); - constexpr auto unit_stride = get_last_index(Is...); + constexpr auto unit_stride = detail::get_last_index(Is...); + using view_t = RAJA::View>; - return RAJA::View> - (ptr, custom_layout); + return view_t(ptr, custom_layout); } }; -template<> -struct Reshape -{ - template +template <> +struct Reshape { + template static auto get(T *ptr, Ts... s) { constexpr int N = sizeof...(Ts); - using view_t = RAJA::View>; + using view_t = RAJA::View>; return view_t(ptr, s...); } }; -template -constexpr std::array make_reverse_array(std::index_sequence) { - return std::array{sizeof...(Is) - 1U - Is ...}; -} +namespace detail +{ -template<> -struct Reshape +template +constexpr std::array make_reverse_array( + std::index_sequence) { - template + return std::array{sizeof...(Is) - 1U - Is...}; +} + +} // namespace detail + +template <> +struct Reshape { + template static auto get(T *ptr, Ts... s) { constexpr int N = sizeof...(Ts); std::array extent{s...}; - constexpr auto reverse_array = make_reverse_array(std::make_index_sequence{}); + constexpr auto reverse_array = + detail::make_reverse_array(std::make_index_sequence{}); auto reverse_layout = RAJA::make_permuted_layout(extent, reverse_array); + using view_t = RAJA::View>; - return RAJA::View>(ptr, reverse_layout); + return view_t(ptr, reverse_layout); } - }; } // namespace RAJA