Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for SYCL + MKL GEMM #2062

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,17 +777,6 @@ KOKKOSBLAS2_CGEMV_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false)
namespace KokkosBlas {
namespace Impl {

/// \brief Translate a Kokkos Kernels transpose-mode character into the
///        corresponding oneMKL transpose enumerator (case-insensitive).
///
/// \param mode_kk [in] One of 'N', 'T' or 'C', in either case.
/// \return nontrans for 'N', trans for 'T', conjtrans for 'C'.
/// \throws std::invalid_argument if mode_kk is none of the above.
inline oneapi::mkl::transpose mode_kk_to_onemkl(char mode_kk) {
  // toupper() is only defined for EOF and values representable as
  // unsigned char; plain char may be signed, so convert before the call.
  switch (toupper(static_cast<unsigned char>(mode_kk))) {
    case 'N': return oneapi::mkl::transpose::nontrans;
    case 'T': return oneapi::mkl::transpose::trans;
    case 'C': return oneapi::mkl::transpose::conjtrans;
    default: break;
  }
  throw std::invalid_argument(
      "Invalid mode for oneMKL (should be one of N, T, C)");
}

template <typename T, bool is_complex = false>
struct kokkos_to_std_type_map {
using type = T;
Expand Down Expand Up @@ -829,7 +818,7 @@ struct kokkos_to_std_type_map<T, true> {
bool row_major = std::is_same<Kokkos::LayoutRight, LAYOUT>::value; \
const std::int64_t M = A.extent(0); \
const std::int64_t N = A.extent(1); \
oneapi::mkl::transpose trans = mode_kk_to_onemkl(kk_trans[0]); \
oneapi::mkl::transpose trans = trans_mode_kk_to_onemkl(kk_trans[0]); \
const std::int64_t LDA = row_major ? A.stride(0) : A.stride(1); \
std::string label = "KokkosBlas::gemv[TPL_ONEMKL," + \
Kokkos::ArithTraits<SCALAR>::name() + "]"; \
Expand Down
40 changes: 40 additions & 0 deletions blas/tpls/KokkosBlas3_gemm_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,46 @@ KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_ROCBLAS(Kokkos::complex<float>,
Kokkos::LayoutRight, Kokkos::HIPSpace)

#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)

// Generates a full specialization of gemm_tpl_spec_avail whose `value` is
// true for the Kokkos::Experimental::SYCL execution space.
//
// The specialization matches rank-2 views of SCALAR that all share LAYOUT and
// live on Kokkos::Device<SYCL, MEMSPACE> with the Unmanaged memory trait:
// two views of const SCALAR followed by one view of non-const SCALAR.
#define KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(SCALAR, LAYOUT, MEMSPACE) \
template <> \
struct gemm_tpl_spec_avail< \
Kokkos::Experimental::SYCL, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > > { \
enum : bool { value = true }; \
};

// Availability specializations for SYCLDeviceUSMSpace: double, float,
// Kokkos::complex<double> and Kokkos::complex<float>, first with LayoutLeft.
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(double, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(float, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<double>, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<float>, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)

// Same four scalar types with LayoutRight.
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(double, Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(float, Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<double>,
Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<float>, Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)

#endif

} // namespace Impl
} // namespace KokkosBlas

Expand Down
142 changes: 142 additions & 0 deletions blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,146 @@ KOKKOSBLAS3_CGEMM_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false)
} // namespace KokkosBlas
#endif // KOKKOSKERNELS_ENABLE_TPL_ROCBLAS

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
#include <KokkosBlas_tpl_spec.hpp>
#include <oneapi/mkl/blas.hpp>

namespace KokkosBlas::Impl {

/*!
Generates a full specialization of GEMM for the Kokkos::Experimental::SYCL
execution space whose static gemm() forwards to
oneapi::mkl::blas::column_major::gemm or oneapi::mkl::blas::row_major::gemm.

SCALAR_TYPE is the Kokkos Kernels type
TPL_SCALAR_TYPE is the type MKL accepts for SCALAR_TYPE
LAYOUT is the layout shared by the A, B and C views
MEM_SPACE is the memory space shared by the A, B and C views
ETI_SPEC_AVAIL is passed through as the last template argument of GEMM
*/
#define KOKKOSBLAS3_XGEMM_MKL(SCALAR_TYPE, TPL_SCALAR_TYPE, LAYOUT, MEM_SPACE, \
ETI_SPEC_AVAIL) \
template <> \
struct GEMM< \
Kokkos::Experimental::SYCL, \
Kokkos::View<const SCALAR_TYPE**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const SCALAR_TYPE**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<SCALAR_TYPE**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
true, ETI_SPEC_AVAIL> { \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like it if we can drop this ETI_SPEC_AVAIL in favor of calling the proper struct to check if this value should be true or false based on the actual eti being performed...

typedef SCALAR_TYPE SCALAR; \
typedef Kokkos::View< \
const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > \
AViewType; \
typedef Kokkos::View< \
const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > \
BViewType; \
typedef Kokkos::View< \
SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > \
CViewType; \
/* Forwards the GEMM arguments to oneMKL on space.sycl_queue(). */ \
static void gemm(const typename CViewType::execution_space& space, \
const char transA[], const char transB[], \
typename AViewType::const_value_type& alpha, \
const AViewType& A, const BViewType& B, \
typename CViewType::const_value_type& beta, \
const CViewType& C) { \
Kokkos::Profiling::pushRegion("KokkosBlas::gemm[TPL_MKL," #SCALAR_TYPE \
"]"); \
/* A_t is true for any mode other than N or n; K is extent 0 of A when A_t, else extent 1. */ \
const bool A_t = (transA[0] != 'N') && (transA[0] != 'n'); \
const int64_t M = static_cast<int64_t>(C.extent(0)); \
const int64_t N = static_cast<int64_t>(C.extent(1)); \
const int64_t K = static_cast<int64_t>(A.extent(A_t ? 0 : 1)); \
\
constexpr bool is_lr = std::is_same<Kokkos::LayoutRight, LAYOUT>::value; \
/* Leading dimensions: stride(0) for LayoutRight, stride(1) otherwise; a zero stride is replaced by 1. */ \
const int64_t ast = is_lr ? A.stride(0) : A.stride(1); \
const int64_t lda = ast == 0 ? 1 : ast; \
const int64_t bst = is_lr ? B.stride(0) : B.stride(1); \
const int64_t ldb = bst == 0 ? 1 : bst; \
const int64_t cst = is_lr ? C.stride(0) : C.stride(1); \
const int64_t ldc = cst == 0 ? 1 : cst; \
/* Map the Kokkos Kernels mode characters to oneMKL transpose enumerators. */ \
oneapi::mkl::transpose transa = trans_mode_kk_to_onemkl(transA[0]); \
oneapi::mkl::transpose transb = trans_mode_kk_to_onemkl(transB[0]); \
oneapi::mkl::blas::compute_mode mode = \
oneapi::mkl::blas::compute_mode::standard; \
/* Pick the oneMKL entry point matching LAYOUT at compile time; view data pointers are reinterpreted as TPL_SCALAR_TYPE. */ \
if constexpr (!is_lr) { \
oneapi::mkl::blas::column_major::gemm( \
space.sycl_queue(), transa, transb, M, N, K, alpha, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(A.data()), lda, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(B.data()), ldb, beta, \
reinterpret_cast<TPL_SCALAR_TYPE*>(C.data()), ldc, mode); \
} else { \
oneapi::mkl::blas::row_major::gemm( \
space.sycl_queue(), transa, transb, M, N, K, alpha, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(A.data()), lda, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(B.data()), ldb, beta, \
reinterpret_cast<TPL_SCALAR_TYPE*>(C.data()), ldc, mode); \
} \
\
Kokkos::Profiling::popRegion(); \
} \
};

// Per-precision wrappers around KOKKOSBLAS3_XGEMM_MKL. Each one fixes the
// (SCALAR_TYPE, TPL_SCALAR_TYPE) pair and forwards the remaining arguments.

// double -> double
#define KOKKOSBLAS3_DGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(double, double, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL)

// float -> float
#define KOKKOSBLAS3_SGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(float, float, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL)

// Kokkos::complex<double> -> std::complex<double>
#define KOKKOSBLAS3_ZGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(Kokkos::complex<double>, std::complex<double>, LAYOUT, \
MEM_SPACE, ETI_SPEC_AVAIL)

// Kokkos::complex<float> -> std::complex<float>
#define KOKKOSBLAS3_CGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(Kokkos::complex<float>, std::complex<float>, LAYOUT, \
MEM_SPACE, ETI_SPEC_AVAIL)

// Every (precision, layout) combination is instantiated twice, once with
// ETI_SPEC_AVAIL = true and once with ETI_SPEC_AVAIL = false, because we want
// to use MKL regardless of whether ETI is available.

// Double precision, real.
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)

// Single precision, real.
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)

// Double precision, complex.
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)

// Single precision, complex.
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
} // namespace KokkosBlas::Impl
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL && KOKKOS_ENABLE_SYCL

#endif
26 changes: 26 additions & 0 deletions blas/tpls/KokkosBlas_tpl_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,30 @@ struct MagmaSingleton {
} // namespace KokkosBlas
#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
#include <sstream>
#include <oneapi/mkl/types.hpp>

namespace KokkosBlas {
namespace Impl {

/// \brief Convert a Kokkos Kernels transpose mode character to the matching
///        oneMKL transpose enumerator (case-insensitive).
///
/// \param mode_kk [in] One of 'N', 'T' or 'C', in either case.
/// \return nontrans for 'N', trans for 'T', conjtrans for 'C'.
/// \throws std::invalid_argument naming the offending character if mode_kk
///         is none of the above.
inline oneapi::mkl::transpose trans_mode_kk_to_onemkl(char mode_kk) {
  // toupper() is only defined for EOF and values representable as
  // unsigned char; plain char may be signed, so convert before the call.
  switch (toupper(static_cast<unsigned char>(mode_kk))) {
    case 'N': return oneapi::mkl::transpose::nontrans;
    case 'T': return oneapi::mkl::transpose::trans;
    case 'C': return oneapi::mkl::transpose::conjtrans;
    default: break;
  }
  // Report the character exactly as it was passed in, not the upper-cased one.
  std::stringstream ss;
  ss << "Invalid mode \"" << mode_kk
     << "\" for oneMKL (should be one of N, T, C)";
  throw std::invalid_argument(ss.str());
}

} // namespace Impl
} // namespace KokkosBlas

#endif // KOKKOSKERNELS_ENABLE_TPL_MKL && KOKKOS_ENABLE_SYCL

#endif // KOKKOSBLAS_TPL_SPEC_HPP_