Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor serial tbsv implementation details and tests #2478

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions batched/dense/impl/KokkosBatched_Pbtrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

Expand Down Expand Up @@ -50,8 +51,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>::inv
SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);

// Solve L**T *X = B, overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

return 0;
}
Expand All @@ -76,8 +78,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>::inv
/**/ ValueType *KOKKOS_RESTRICT x,
const int xs0, const int kd) {
// Solve U**T *X = B, overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

// Solve U*X = B, overwriting B with X.
SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);
Expand Down
45 changes: 26 additions & 19 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

/// \author Yuuichi Asahi ([email protected])

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

namespace KokkosBatched {

namespace Impl {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) {
static_assert(Kokkos::is_view<AViewType>::value, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::tbsv: AViewType must have rank 2.");
static_assert(XViewType::rank == 1, "KokkosBatched::tbsv: XViewType must have rank 1.");

Expand Down Expand Up @@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp
return 0;
}

} // namespace Impl

//// Lower non-transpose ////
/// Serial tbsv entry point for Uplo::Lower with Trans::NoTranspose (unblocked algorithm).
/// invoke() validates (A, x, k) through checkTbsvInput and returns that code unchanged when
/// it is non-zero; otherwise it hands A.extent(1), the raw data pointers and the strides of
/// A and x to the lower-triangular internal kernel and returns the kernel's result.
/// NOTE(review): this span is rendered from a diff view, so each pre-change line is directly
/// followed by its replacement (the variant qualified with Impl::).
template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
// Input validation; a non-zero code is propagated to the caller as-is.
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

// ArgDiag::use_unit_diag is forwarded as the kernel's first argument; k is passed last.
return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -81,11 +84,12 @@ template <typename ArgDiag>
/// Serial tbsv entry point for Uplo::Lower with Trans::Transpose (unblocked algorithm).
/// invoke() validates (A, x, k) through checkTbsvInput, returns a non-zero code unchanged,
/// and otherwise forwards to the lower-transpose internal kernel.
/// NOTE(review): this span is rendered from a diff view. The pre-change call passes
/// `false` for the kernel's do_conj flag; the post-change call instead passes a
/// KokkosBlas::Impl::OpID object as the leading argument (presumably the identity element
/// operation - its definition is not visible here, confirm in KokkosBlas_util.hpp).
struct SerialTbsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
// Input validation; a non-zero code is propagated to the caller as-is.
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

// Pre-change form: (use_unit_diag, do_conj = false, ...).
return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// Post-change form: (op, use_unit_diag, ...).
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -94,11 +98,12 @@ template <typename ArgDiag>
/// Serial tbsv entry point for Uplo::Lower with Trans::ConjTranspose (unblocked algorithm).
/// invoke() validates (A, x, k) through checkTbsvInput, returns a non-zero code unchanged,
/// and otherwise forwards to the same lower-transpose internal kernel used by the plain
/// transpose case, asking it to conjugate the matrix entries.
/// NOTE(review): this span is rendered from a diff view. The pre-change call requests
/// conjugation with do_conj = `true`; the post-change call passes a
/// KokkosBlas::Impl::OpConj object instead (presumably element-wise conjugation - its
/// definition is not visible here, confirm in KokkosBlas_util.hpp).
struct SerialTbsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
// Input validation; a non-zero code is propagated to the caller as-is.
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

// Pre-change form: (use_unit_diag, do_conj = true, ...).
return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// Post-change form: (op, use_unit_diag, ...).
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -107,10 +112,10 @@ template <typename ArgDiag>
/// Serial tbsv entry point for Uplo::Upper with Trans::NoTranspose (unblocked algorithm).
/// invoke() validates (A, x, k) through checkTbsvInput and returns that code unchanged when
/// it is non-zero; otherwise it hands A.extent(1), the raw data pointers and the strides of
/// A and x to the upper-triangular internal kernel and returns the kernel's result.
/// NOTE(review): this span is rendered from a diff view, so each pre-change line is directly
/// followed by its replacement (the variant qualified with Impl::).
struct SerialTbsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
// Input validation; a non-zero code is propagated to the caller as-is.
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

// ArgDiag::use_unit_diag is forwarded as the kernel's first argument; k is passed last.
return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -120,11 +125,12 @@ template <typename ArgDiag>
/// Serial tbsv entry point for Uplo::Upper with Trans::Transpose (unblocked algorithm).
/// invoke() validates (A, x, k) through checkTbsvInput, returns a non-zero code unchanged,
/// and otherwise forwards to the upper-transpose internal kernel.
/// NOTE(review): this span is rendered from a diff view. The pre-change call passes
/// `false` for the kernel's do_conj flag; the post-change call instead passes a
/// KokkosBlas::Impl::OpID object as the leading argument (presumably the identity element
/// operation - its definition is not visible here, confirm in KokkosBlas_util.hpp).
struct SerialTbsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
// Input validation; a non-zero code is propagated to the caller as-is.
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

// Pre-change form: (use_unit_diag, do_conj = false, ...).
return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// Post-change form: (op, use_unit_diag, ...).
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -133,11 +139,12 @@ template <typename ArgDiag>
/// Serial tbsv entry point for Uplo::Upper with Trans::ConjTranspose (unblocked algorithm).
/// invoke() validates (A, x, k) through checkTbsvInput, returns a non-zero code unchanged,
/// and otherwise forwards to the same upper-transpose internal kernel used by the plain
/// transpose case, asking it to conjugate the matrix entries.
/// NOTE(review): this span is rendered from a diff view. The pre-change call requests
/// conjugation with do_conj = `true`; the post-change call passes a
/// KokkosBlas::Impl::OpConj object instead (presumably element-wise conjugation - its
/// definition is not visible here, confirm in KokkosBlas_util.hpp).
struct SerialTbsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
// Input validation; a non-zero code is propagated to the caller as-is.
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

// Pre-change form: (use_unit_diag, do_conj = true, ...).
return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// Post-change form: (op, use_unit_diag, ...).
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand Down
58 changes: 18 additions & 40 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
#include "KokkosBatched_Util.hpp"

namespace KokkosBatched {

namespace Impl {
///
/// Serial Internal Impl
/// ====================

///
/// Lower, Non-Transpose
/// Lower
///

template <typename AlgoType>
Expand Down Expand Up @@ -70,49 +70,37 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invok

/// Primary template of the pointer/stride level kernel for the transposed lower-triangular
/// tbsv solve; AlgoType selects the algorithm variant that provides invoke().
/// NOTE(review): this span is rendered from a diff view and shows two alternative
/// declarations of invoke():
///   - pre-change:  template <ValueType>,     invoke(use_unit_diag, do_conj, an, ...)
///   - post-change: template <Op, ValueType>, invoke(op, use_unit_diag, an, ...)
/// Parameters common to both forms:
///   use_unit_diag  when true the division by the diagonal entry is skipped
///   an             number of entries of x that are solved for
///   A, as0, as1    matrix data pointer and its two strides
///   x, xs0         right-hand side / solution pointer and its stride (overwritten)
///   k              band width used to bound the inner loop
template <typename AlgoType>
struct SerialTbsvInternalLowerTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

/// Unblocked kernel for the transposed lower-triangular tbsv solve; x is overwritten with
/// the solution and the function always returns 0.
/// Rows are processed from j = an - 1 down to 0. For each j the already solved entries
/// x[i], j < i <= min(an - 1, j + k), are eliminated using A[(i - j) * as0 + j * as1], and
/// the result is divided by the diagonal entry A[0 + j * as1] unless use_unit_diag is set.
/// NOTE(review): this span is rendered from a diff view. The pre-change body branches on
/// the runtime flag do_conj and duplicates the loop with and without conj(); the
/// post-change body has a single loop that applies the functor `op` to every matrix entry.
template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
// Backward sweep: x[j] depends only on entries with a larger index.
for (int j = an - 1; j >= 0; --j) {
auto temp = x[j * xs0];

// Pre-change: conjugating branch.
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[0 + j * as1]);
// Pre-change: non-conjugating branch, same access pattern without conj().
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= A[(i - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[0 + j * as1];
// Post-change: one loop, with op() applied to each matrix entry that is read.
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= op(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / op(A[0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

///
/// Upper, Non-Transpose
/// Upper
///

template <typename AlgoType>
Expand Down Expand Up @@ -154,46 +142,36 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invok

/// Primary template of the pointer/stride level kernel for the transposed upper-triangular
/// tbsv solve; AlgoType selects the algorithm variant that provides invoke().
/// NOTE(review): this span is rendered from a diff view and shows two alternative
/// declarations of invoke():
///   - pre-change:  template <ValueType>,     invoke(use_unit_diag, do_conj, an, ...)
///   - post-change: template <Op, ValueType>, invoke(op, use_unit_diag, an, ...)
/// Parameters common to both forms:
///   use_unit_diag  when true the division by the diagonal entry is skipped
///   an             number of entries of x that are solved for
///   A, as0, as1    matrix data pointer and its two strides
///   x, xs0         right-hand side / solution pointer and its stride (overwritten)
///   k              band width; also the row offset of the diagonal in A
template <typename AlgoType>
struct SerialTbsvInternalUpperTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

/// Unblocked kernel for the transposed upper-triangular tbsv solve; x is overwritten with
/// the solution and the function always returns 0.
/// Rows are processed from j = 0 up to an - 1. For each j the already solved entries
/// x[i], max(0, j - k) <= i < j, are eliminated using A[(i + k - j) * as0 + j * as1], and
/// the result is divided by the diagonal entry A[k * as0 + j * as1] unless use_unit_diag
/// is set.
/// NOTE(review): this span is rendered from a diff view. The pre-change body branches on
/// the runtime flag do_conj and duplicates the loop with and without conj(); the
/// post-change body has a single loop that applies the functor `op` to every matrix entry.
template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
// Forward sweep: x[j] depends only on entries with a smaller index.
for (int j = 0; j < an; j++) {
auto temp = x[j * xs0];
// Pre-change: conjugating branch.
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[k * as0 + j * as1]);
// Pre-change: non-conjugating branch, same access pattern without conj().
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= A[(i + k - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[k * as0 + j * as1];
// Post-change: one loop, with op() applied to each matrix entry that is read.
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= op(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / op(A[k * as0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_TBSV_SERIAL_INTERNAL_HPP_
Loading
Loading