Skip to content

Commit

Permalink
portability
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff Hammond <[email protected]>
  • Loading branch information
jeffhammond committed Apr 15, 2024
1 parent 97a1ffe commit 8de0b0e
Showing 1 changed file with 70 additions and 49 deletions.
119 changes: 70 additions & 49 deletions Cxx11/xgemm-cblas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,79 +60,100 @@
#include "prk_util.h"

#if defined(MKL)
#include <mkl_cblas.h>
#include <mkl_cblas.h>
#define PRK_INT MKL_INT
#define PRK_F16 MKL_F16
#define USE_F16 1
#define PRK_BF16 MKL_BF16
#define USE_BF16 1
#elif defined(ACCELERATE)
// The location of cblas.h is not in the system include path when -framework Accelerate is provided.
#include <Accelerate/Accelerate.h>
// The location of cblas.h is not in the system include path when -framework Accelerate is provided.
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif

#ifndef MKL_INT
#define MKL_INT int
// assume OpenBLAS for now
#include <cblas.h>
#ifdef OPENBLAS_USE64BITINT
#define PRK_INT long
#else
#define PRK_INT int
#endif
#define PRK_BF16 bfloat16
#endif

template <typename TAB, typename TC>
void prk_gemm(const CBLAS_LAYOUT Layout,
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
const MKL_INT M, const MKL_INT N, const MKL_INT K,
const PRK_INT M, const PRK_INT N, const PRK_INT K,
const TC alpha,
const TAB * A, const MKL_INT lda,
const TAB * B, const MKL_INT ldb,
const TAB * A, const PRK_INT lda,
const TAB * B, const PRK_INT ldb,
const TC beta,
TC * C, const MKL_INT ldc)
TC * C, const PRK_INT ldc)
{
std::cerr << "No valid template match for type T" << std::endl;
std::abort();
}

#ifdef MKL_F16
#ifdef PRK_F16
template <>
void prk_gemm(const CBLAS_LAYOUT Layout,
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
const MKL_INT M, const MKL_INT N, const MKL_INT K,
const MKL_F16 alpha,
const MKL_F16 * A, const MKL_INT lda,
const MKL_F16 * B, const MKL_INT ldb,
const MKL_F16 beta,
MKL_F16 * C, const MKL_INT ldc)
const PRK_INT M, const PRK_INT N, const PRK_INT K,
const PRK_F16 alpha,
const PRK_F16 * A, const PRK_INT lda,
const PRK_F16 * B, const PRK_INT ldb,
const PRK_F16 beta,
PRK_F16 * C, const PRK_INT ldc)
{
cblas_hgemm(Layout, TransA, TransB,
M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
#endif

#ifdef MKL_BF16
#ifdef USE_BF16
template <>
void prk_gemm(const CBLAS_LAYOUT Layout,
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
const MKL_INT M, const MKL_INT N, const MKL_INT K,
const PRK_INT M, const PRK_INT N, const PRK_INT K,
const float alpha,
const MKL_BF16 * A, const MKL_INT lda,
const MKL_BF16 * B, const MKL_INT ldb,
const PRK_BF16 * A, const PRK_INT lda,
const PRK_BF16 * B, const PRK_INT ldb,
const float beta,
float * C, const MKL_INT ldc)
float * C, const PRK_INT ldc)
{
// cblas_gemm_bf16bf16f32(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA,
// const CBLAS_TRANSPOSE TransB,
// const MKL_INT M, const MKL_INT N, const MKL_INT K,
// const float alpha, const MKL_BF16 *A, const MKL_INT lda,
// const MKL_BF16 *B, const MKL_INT ldb, const float beta,
// float *C, const MKL_INT ldc);
#ifdef MKL
// MKL
// cblas_gemm_bf16bf16f32(const CBLAS_LAYOUT Layout,
// const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
// const PRK_INT M, const PRK_INT N, const PRK_INT K,
// const float alpha, const PRK_BF16 *A, const PRK_INT lda,
// const PRK_BF16 *B, const PRK_INT ldb,
// const float beta, float *C, const PRK_INT ldc);
cblas_gemm_bf16bf16f32(Layout, TransA, TransB,
M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
#else
// OpenBLAS
// cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order,
// OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
// OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
// OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda,
// OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb,
// OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
cblas_sbgemm(Layout, TransA, TransB,
M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
#endif
}
#endif

template <>
void prk_gemm(const CBLAS_LAYOUT Layout,
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
const MKL_INT M, const MKL_INT N, const MKL_INT K,
const PRK_INT M, const PRK_INT N, const PRK_INT K,
const float alpha,
const float * A, const MKL_INT lda,
const float * B, const MKL_INT ldb,
const float * A, const PRK_INT lda,
const float * B, const PRK_INT ldb,
const float beta,
float * C, const MKL_INT ldc)
float * C, const PRK_INT ldc)
{
cblas_sgemm(Layout, TransA, TransB,
M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
Expand All @@ -141,26 +162,26 @@ void prk_gemm(const CBLAS_LAYOUT Layout,
template <>
void prk_gemm(const CBLAS_LAYOUT Layout,
const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
const MKL_INT M, const MKL_INT N, const MKL_INT K,
const PRK_INT M, const PRK_INT N, const PRK_INT K,
const double alpha,
const double * A, const MKL_INT lda,
const double * B, const MKL_INT ldb,
const double * A, const PRK_INT lda,
const double * B, const PRK_INT ldb,
const double beta,
double * C, const MKL_INT ldc)
double * C, const PRK_INT ldc)
{
cblas_dgemm(Layout, TransA, TransB,
M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}

#ifdef MKL_BF16
#ifdef USE_BF16
void run_BF16(int iterations, int order)
{
double gemm_time{0};

const size_t nelems = (size_t)order * (size_t)order;

auto A = new MKL_BF16[nelems];
auto B = new MKL_BF16[nelems];
auto A = new PRK_BF16[nelems];
auto B = new PRK_BF16[nelems];
auto C = new float[nelems];

for (int i=0; i<order; ++i) {
Expand Down Expand Up @@ -205,7 +226,7 @@ void run_BF16(int iterations, int order)
}
const double residuum = std::abs(checksum - reference) / reference;
const double epsilon{1.0e-8};
if ((residuum < epsilon) || (sizeof(MKL_BF16) < 4)) {
if ((residuum < epsilon) || (sizeof(PRK_BF16) < 4)) {
#if VERBOSE
std::cout << "Reference checksum = " << reference << "\n"
<< "Actual checksum = " << checksum << std::endl;
Expand All @@ -232,14 +253,14 @@ void run(int iterations, int order)
auto is_fp64 = (typeid(T) == typeid(double));
auto is_fp32 = (typeid(T) == typeid(float));
auto is_fp16 =
#ifdef MKL_F16
(typeid(T) == typeid(MKL_F16));
#ifdef USE_F16
(typeid(T) == typeid(PRK_F16));
#else
false;
#endif
auto is_bf16 =
#ifdef MKL_BF16
(typeid(T) == typeid(MKL_BF16));
#ifdef USE_BF16
(typeid(T) == typeid(PRK_BF16));
#else
false;
#endif
Expand Down Expand Up @@ -352,10 +373,10 @@ int main(int argc, char * argv[])
std::cout << "Number of iterations = " << iterations << std::endl;
std::cout << "Matrix order = " << order << std::endl;

#ifdef MKL_F16
run<MKL_F16>(iterations, order);
#ifdef USE_F16
run<PRK_F16>(iterations, order);
#endif
#ifdef MKL_BF16
#ifdef USE_BF16
run_BF16(iterations, order);
#endif
run<float>(iterations, order);
Expand Down

0 comments on commit 8de0b0e

Please sign in to comment.